From bef8f593ae820eb8465934de91eb27468edf6444 Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Sun, 2 Feb 2020 20:30:35 +0000 Subject: [PATCH 001/608] Ensure all TXT, SRV, A records are unique Fixes issues with shared records being used where they shouldn't be. PTR records should be shared, but SRV, TXT and A/AAAA records should be unique. Whilst mDNS and DNS-SD in theory support shared records for these types of record, they are not implemented in python-zeroconf at the moment. See zeroconf.check_service() method which verifies the service is unique on the network before registering. --- zeroconf/__init__.py | 57 +++++++++++++++++++++++++++++++++++++------- zeroconf/test.py | 54 ++++++++++++++++++++++++++++++++--------- 2 files changed, 91 insertions(+), 20 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 35fac80a..b3df8180 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -887,6 +887,16 @@ def add_question(self, record: DNSQuestion) -> None: self.questions.append(record) def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: + + """Only support for unique answers""" + if ( + record.type == _TYPE_TXT + or record.type == _TYPE_SRV + or record.type == _TYPE_A + or record.type == _TYPE_AAAA + ): + assert record.unique + """Adds an answer""" if not record.suppressed_by(inp): self.add_answer_at_time(record, 0) @@ -894,6 +904,16 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: """Adds an answer if it does not expire by a certain time""" if record is not None: + + """Only support for unique answers""" + if ( + record.type == _TYPE_TXT + or record.type == _TYPE_SRV + or record.type == _TYPE_A + or record.type == _TYPE_AAAA + ): + assert record.unique + if now == 0 or not record.is_expired(now): self.answers.append((record, now)) @@ -937,6 +957,15 @@ def add_additional_answer(self, record: DNSRecord) -> None: o All address records (type "A" and "AAAA") named in the SRV rdata. """ + """Only support for unique answers""" + if ( + record.type == _TYPE_TXT + or record.type == _TYPE_SRV + or record.type == _TYPE_A + or record.type == _TYPE_AAAA + ): + assert record.unique + self.additionals.append(record) def pack(self, format_: Union[bytes, str], value: Any) -> None: @@ -2244,7 +2273,7 @@ def _broadcast_service(self, info: ServiceInfo) -> None: DNSService( info.name, _TYPE_SRV, - _CLASS_IN, + _CLASS_IN | _CLASS_UNIQUE, info.host_ttl, info.priority, info.weight, @@ -2254,10 +2283,14 @@ def _broadcast_service(self, info: ServiceInfo) -> None: 0, ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, info.other_ttl, info.text), 0) + out.add_answer_at_time( + DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, info.other_ttl, info.text), 0 + ) for address in info.addresses_by_version(IPVersion.All): type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, info.host_ttl, address), 0) + out.add_answer_at_time( + DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, info.host_ttl, address), 0 + ) self.send(out) i += 1 next_time += _REGISTER_TIME @@ -2286,7 +2319,7 @@ def unregister_service(self, info: ServiceInfo) -> None: DNSService( info.name, _TYPE_SRV, - _CLASS_IN, + _CLASS_IN | _CLASS_UNIQUE, 0, info.priority, info.weight, @@ -2295,11 +2328,13 @@ def unregister_service(self, info: ServiceInfo) -> None: ), 0, ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) + out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, 0, info.text), 0) for address in info.addresses_by_version(IPVersion.All): type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) + out.add_answer_at_time( + DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, 0, address), 0 + ) self.send(out) i += 1 next_time += _UNREGISTER_TIME @@ -2322,7 +2357,7 @@ def unregister_all_services(self) -> None: DNSService( info.name, _TYPE_SRV, - _CLASS_IN, + _CLASS_IN | _CLASS_UNIQUE, 0, info.priority, info.weight, @@ -2331,10 +2366,14 @@ def unregister_all_services(self) -> None: ), 0, ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN, 0, info.text), 0) + out.add_answer_at_time( + DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, 0, info.text), 0 + ) for address in info.addresses_by_version(IPVersion.All): type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time(DNSAddress(info.server, type_, _CLASS_IN, 0, address), 0) + out.add_answer_at_time( + DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, 0, address), 0 + ) self.send(out) i += 1 next_time += _UNREGISTER_TIME diff --git a/zeroconf/test.py b/zeroconf/test.py index 5a89c9f6..3b381f60 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -146,7 +146,17 @@ def test_parse_own_packet_question(self): def test_parse_own_packet_response(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) generated.add_answer_at_time( - r.DNSService("æøå.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local."), 0 + r.DNSService( + "æøå.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, ) parsed = r.DNSIncoming(generated.packet()) self.assertEqual(len(generated.answers), 1) @@ -166,13 +176,34 @@ def test_suppress_answer(self): question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) query_generated.add_question(question) answer1 = r.DNSService( - "testname1.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local." + "testname1.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", ) staleanswer2 = r.DNSService( - "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL / 2, 0, 0, 80, "foo.local." + "testname2.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL / 2, + 0, + 0, + 80, + "foo.local.", ) answer2 = r.DNSService( - "testname2.local.", r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, "foo.local." + "testname2.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", ) query_generated.add_answer_at_time(answer1, 0) query_generated.add_answer_at_time(staleanswer2, 0) @@ -444,7 +475,8 @@ def generate_host(zc, host_name, type_): out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) out.add_answer_at_time(r.DNSPointer(type_, r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, name), 0) out.add_answer_at_time( - r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, name), 0 + r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, r._DNS_HOST_TTL, 0, 0, 80, name), + 0, ) zc.send(out) @@ -487,7 +519,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ttl = 0 generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0 + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 ) generated.add_answer_at_time( r.DNSService( @@ -670,7 +702,7 @@ def test_incoming_ipv6(self): addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com packed = socket.inet_pton(socket.AF_INET6, addr) generated = r.DNSOutgoing(0) - answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN, 1, packed) + answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN | r._CLASS_UNIQUE, 1, packed) generated.add_additional_answer(answer) packet = generated.packet() parsed = r.DNSIncoming(packet) @@ -886,6 +918,7 @@ def test_integration_with_listener_class(self): service_added = Event() service_removed = Event() service_updated = Event() + service_updated2 = Event() subtype_name = "My special Subtype" type_ = "_http._tcp.local." @@ -902,7 +935,7 @@ def remove_service(self, zeroconf, type, name): service_removed.set() def update_service(self, zeroconf, type, name): - pass + service_updated2.set() class MySubListener(r.ServiceListener): def add_service(self, zeroconf, type, name): @@ -966,7 +999,7 @@ def update_service(self, zeroconf, type, name): assert info is not None assert info.properties[b'prop_none'] is False - # Begin material test addition + # test TXT record update sublistener = MySubListener() zeroconf_browser.add_service_listener(registration_name, sublistener) properties['prop_blank'] = b'an updated string' @@ -981,7 +1014,6 @@ def update_service(self, zeroconf, type, name): info = zeroconf_browser.get_service_info(type_, registration_name) assert info is not None assert info.properties[b'prop_blank'] == properties['prop_blank'] - # End material test addition zeroconf_registrar.unregister_service(info_service) service_removed.wait(1) @@ -1043,7 +1075,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ttl = 0 generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_name), 0 + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 ) generated.add_answer_at_time( r.DNSService( From ca8e53de55a563f5c7049be2eda14ae0ecd1a7cf Mon Sep 17 00:00:00 2001 From: Andreas Oberritter Date: Wed, 19 Feb 2020 21:35:38 +0100 Subject: [PATCH 002/608] Do not exclude interfaces with host-only netmasks from InterfaceChoice.All (#227) Host-only netmasks do not forbid multicast. Tested on Debian 10 running in Qubes and on Ubuntu 18.04. --- zeroconf/__init__.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b3df8180..5554f9b2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1867,14 +1867,7 @@ def find( def get_all_addresses() -> List[str]: - return list( - set( - addr.ip - for iface in ifaddr.get_adapters() - for addr in iface.ips - if addr.is_IPv4 and addr.network_prefix != 32 # Host only netmask 255.255.255.255 - ) - ) + return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4)) def get_all_addresses_v6() -> List[int]: From f6690d2048cb87cb0fb3a7c3b832cf1a1f40e61a Mon Sep 17 00:00:00 2001 From: Aldo Hoeben Date: Thu, 20 Feb 2020 13:45:38 +0100 Subject: [PATCH 003/608] Fix representation of IPv6 DNSAddress (#230) --- zeroconf/__init__.py | 6 +++++- zeroconf/test.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5554f9b2..1974de83 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -540,7 +540,11 @@ def __ne__(self, other: Any) -> bool: def __repr__(self) -> str: """String representation""" try: - return self.to_string(str(socket.inet_ntoa(self.address))) + return self.to_string( + socket.inet_ntop( + socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address + ) + ) except Exception: # TODO stop catching all Exceptions return self.to_string(str(self.address)) diff --git a/zeroconf/test.py b/zeroconf/test.py index 3b381f60..4ff585cc 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -63,7 +63,17 @@ def test_dns_pointer_repr(self): def test_dns_address_repr(self): address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - repr(address) + assert repr(address).endswith("b'a'") + + address_ipv4 = r.DNSAddress( + 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') + ) + assert repr(address_ipv4).endswith('127.0.0.1') + + address_ipv6 = r.DNSAddress( + 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') + ) + assert repr(address_ipv6).endswith('::1') def test_dns_question_repr(self): question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE) From 5e4f496778d91ccfc65e946d3d94c39ab6388b29 Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Mon, 3 Feb 2020 20:18:05 +0000 Subject: [PATCH 004/608] Refactor out unique assertion --- zeroconf/__init__.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1974de83..60893b39 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -886,21 +886,15 @@ class State(enum.Enum): init = 0 finished = 1 + @staticmethod + def is_type_unique(type_: int) -> bool: + return type_ == _TYPE_TXT or type_ == _TYPE_SRV or type_ == _TYPE_A or type_ == _TYPE_AAAA + def add_question(self, record: DNSQuestion) -> None: """Adds a question""" self.questions.append(record) def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: - - """Only support for unique answers""" - if ( - record.type == _TYPE_TXT - or record.type == _TYPE_SRV - or record.type == _TYPE_A - or record.type == _TYPE_AAAA - ): - assert record.unique - """Adds an answer""" if not record.suppressed_by(inp): self.add_answer_at_time(record, 0) @@ -909,13 +903,7 @@ def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int] """Adds an answer if it does not expire by a certain time""" if record is not None: - """Only support for unique answers""" - if ( - record.type == _TYPE_TXT - or record.type == _TYPE_SRV - or record.type == _TYPE_A - or record.type == _TYPE_AAAA - ): + if self.is_type_unique(record.type): assert record.unique if now == 0 or not record.is_expired(now): @@ -961,13 +949,7 @@ def add_additional_answer(self, record: DNSRecord) -> None: o All address records (type "A" and "AAAA") named in the SRV rdata. """ - """Only support for unique answers""" - if ( - record.type == _TYPE_TXT - or record.type == _TYPE_SRV - or record.type == _TYPE_A - or record.type == _TYPE_AAAA - ): + if self.is_type_unique(record.type): assert record.unique self.additionals.append(record) From d8caa4e2d71025ed42b33abb4d329329437b44fb Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Sun, 16 Feb 2020 15:56:27 +0000 Subject: [PATCH 005/608] Remove duplciate update messages sent to listeners The prior code used to send updates even when the new record was identical to the old. This resulted in duplciate update messages when there was in fact no update (apart from TTL refresh) --- zeroconf/__init__.py | 10 +++++++++- zeroconf/test.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 60893b39..b18b51ff 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2431,8 +2431,15 @@ def handle_response(self, msg: DNSIncoming) -> None: are held in the cache, and listeners are notified.""" now = current_time_millis() for record in msg.answers: + + updated = True + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 for entry in self.cache.entries(): + + if entry == record: + updated = False + if DNSEntry.__eq__(entry, record) and (record.created - entry.created > 1000): self.cache.remove(entry) @@ -2443,7 +2450,8 @@ def handle_response(self, msg: DNSIncoming) -> None: maybe_entry.reset_ttl(record) else: self.cache.add(record) - self.update_record(now, record) + if updated: + self.update_record(now, record) else: if maybe_entry is not None: self.update_record(now, record) diff --git a/zeroconf/test.py b/zeroconf/test.py index 4ff585cc..5202ce97 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1125,6 +1125,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi service_updated_event.clear() service_text = b'path=/~humingchun/' zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(1) assert service_added is True assert service_updated_count == 2 From 1ca023fae4b586679446ceaf3e2e9955ea5bf180 Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Thu, 20 Feb 2020 13:02:43 +0000 Subject: [PATCH 006/608] Support cooperating responders (#224) --- zeroconf/__init__.py | 39 ++++++++++++++++++++++++--------------- zeroconf/test.py | 3 +++ 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b18b51ff..34329c67 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2205,18 +2205,24 @@ def remove_all_service_listeners(self) -> None: self.remove_service_listener(listener) def register_service( - self, info: ServiceInfo, ttl: Optional[int] = None, allow_name_change: bool = False + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, ) -> None: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service. The name of the service may be changed if needed to make - it unique on the network.""" + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" if ttl is not None: # ttl argument is used to maintain backward compatibility # Setting TTLs via ServiceInfo is preferred info.host_ttl = ttl info.other_ttl = ttl - self.check_service(info, allow_name_change) + self.check_service(info, allow_name_change, cooperating_responders) self.services[info.name.lower()] = info if info.type in self.servicetypes: self.servicetypes[info.type] += 1 @@ -2357,7 +2363,9 @@ def unregister_all_services(self) -> None: i += 1 next_time += _UNREGISTER_TIME - def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None: + def check_service( + self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False + ) -> None: """Checks the network for a unique service name, modifying the ServiceInfo passed in if it is not unique.""" @@ -2374,17 +2382,18 @@ def check_service(self, info: ServiceInfo, allow_name_change: bool) -> None: next_time = now i = 0 while i < 3: - # check for a name conflict - while self.cache.current_entry_with_name_and_alias(info.type, info.name): - if not allow_name_change: - raise NonUniqueNameException - - # change the name and look for a conflict - info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) - next_instance_number += 1 - service_type_name(info.name) - next_time = now - i = 0 + if not cooperating_responders: + # check for a name conflict + while self.cache.current_entry_with_name_and_alias(info.type, info.name): + if not allow_name_change: + raise NonUniqueNameException + + # change the name and look for a conflict + info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) + next_instance_number += 1 + service_type_name(info.name) + next_time = now + i = 0 if now < next_time: self.wait(next_time - now) diff --git a/zeroconf/test.py b/zeroconf/test.py index 5202ce97..fb87524b 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -463,6 +463,9 @@ def verify_name_change(self, zc, type_, name, number_hosts): # verify name conflict self.assertRaises(r.NonUniqueNameException, zc.register_service, info_service) + # verify no name conflict https://tools.ietf.org/html/rfc6762#section-6.6 + zc.register_service(info_service, cooperating_responders=True) + zc.register_service(info_service, allow_name_change=True) assert info_service.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) From 37fa0a0d59a5b5d09295a462bf911e82d2d770ed Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 Feb 2020 14:41:10 -0600 Subject: [PATCH 007/608] Optimize handle_response cache check The handle_response loop would encounter a unique record it would search the cache in order to remove keys that matched the DNSEntry for the record. Since the cache is stored as a list of records with the key as the record name, we can avoid searching the entire cache each time and on search for the DNSEntry of the record. In practice this means with 5000 entries and records in the cache we now only need to search 4 or 5. When looping over the cache entries for the name, we now check the expire time first as its cheaper than calling DNSEntry.__eq__ Test environment: Home Assistant running on home networking with a /22 and a significant amount of broadcast traffic Testing was done with py-spy v0.3.3 (https://github.com/benfred/py-spy/releases) # py-spy top --pid Before: ``` Collecting samples from '/usr/local/bin/python3 -m homeassistant --config /config' (python v3.7.6) Total Samples 10200 GIL: 0.00%, Active: 0.00%, Threads: 35 %Own %Total OwnTime TotalTime Function (filename:line) 0.00% 0.00% 18.13s 18.13s _worker (concurrent/futures/thread.py:78) 0.00% 0.00% 2.51s 2.56s run (zeroconf/__init__.py:1221) 0.00% 0.00% 0.420s 0.420s __eq__ (zeroconf/__init__.py:394) 0.00% 0.00% 0.390s 0.390s handle_read (zeroconf/__init__.py:1260) 0.00% 0.00% 0.240s 0.670s handle_response (zeroconf/__init__.py:2452) 0.00% 0.00% 0.230s 0.230s __eq__ (zeroconf/__init__.py:606) 0.00% 0.00% 0.200s 0.810s handle_response (zeroconf/__init__.py:2449) 0.00% 0.00% 0.140s 0.150s __eq__ (zeroconf/__init__.py:632) 0.00% 0.00% 0.130s 0.130s entries (zeroconf/__init__.py:1185) 0.00% 0.00% 0.090s 0.090s notify (threading.py:352) 0.00% 0.00% 0.080s 0.080s read_utf (zeroconf/__init__.py:818) 0.00% 0.00% 0.080s 0.080s __eq__ (zeroconf/__init__.py:678) 0.00% 0.00% 0.070s 0.080s __eq__ (zeroconf/__init__.py:533) 0.00% 0.00% 0.060s 0.060s __eq__ (zeroconf/__init__.py:677) 0.00% 0.00% 0.050s 0.050s get (zeroconf/__init__.py:1146) 0.00% 0.00% 0.050s 0.050s do_commit (sqlalchemy/engine/default.py:541) 0.00% 0.00% 0.040s 2.86s run (zeroconf/__init__.py:1226) ``` After ``` Collecting samples from '/usr/local/bin/python3 -m homeassistant --config /config' (python v3.7.6) Total Samples 10200 GIL: 7.00%, Active: 61.00%, Threads: 35 %Own %Total OwnTime TotalTime Function (filename:line) 47.00% 47.00% 24.84s 24.84s _worker (concurrent/futures/thread.py:78) 5.00% 5.00% 2.97s 2.97s run (zeroconf/__init__.py:1226) 1.00% 1.00% 0.390s 0.390s handle_read (zeroconf/__init__.py:1265) 1.00% 1.00% 0.200s 0.200s read_utf (zeroconf/__init__.py:818) 0.00% 0.00% 0.120s 0.120s unpack (zeroconf/__init__.py:723) 0.00% 1.00% 0.120s 0.320s read_name (zeroconf/__init__.py:834) 0.00% 0.00% 0.100s 0.240s update_record (zeroconf/__init__.py:2440) 0.00% 0.00% 0.090s 0.090s notify (threading.py:352) 0.00% 0.00% 0.070s 0.070s update_record (zeroconf/__init__.py:1469) 0.00% 0.00% 0.060s 0.070s __eq__ (zeroconf/__init__.py:606) 0.00% 0.00% 0.050s 0.050s acquire (logging/__init__.py:843) 0.00% 0.00% 0.050s 0.050s unpack (zeroconf/__init__.py:722) 0.00% 0.00% 0.050s 0.050s read_name (zeroconf/__init__.py:828) 0.00% 0.00% 0.050s 0.050s is_expired (zeroconf/__init__.py:494) 0.00% 0.00% 0.040s 0.040s emit (logging/__init__.py:1028) 1.00% 1.00% 0.040s 0.040s __init__ (zeroconf/__init__.py:386) 0.00% 0.00% 0.040s 0.040s __enter__ (threading.py:241) ``` --- zeroconf/__init__.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 34329c67..2b23c971 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2444,12 +2444,21 @@ def handle_response(self, msg: DNSIncoming) -> None: updated = True if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - for entry in self.cache.entries(): + # Since the cache format is keyed on the lower case record name + # we can avoid iterating everything in the cache and + # only look though entries for the specific name. + # entries_with_name will take care of converting to lowercase + # + # We make a copy of the list that entries_with_name returns + # since we cannot iterate over something we might remove + for entry in self.cache.entries_with_name(record.name).copy(): if entry == record: updated = False - if DNSEntry.__eq__(entry, record) and (record.created - entry.created > 1000): + # Check the time first because it is far cheaper + # than the __eq__ + if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): self.cache.remove(entry) expired = record.is_expired(now) From eac53f45bddb8d3d559b1d4672a926b746435771 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 7 Mar 2020 22:03:34 +0000 Subject: [PATCH 008/608] Resolve memory leak in DNSCache When all the records for a given name were removed from the cache, the name itself that contain the list was never removed. This left an empty list in memory for every device that was no longer broadcasting on the network. --- zeroconf/__init__.py | 5 +++++ zeroconf/test.py | 11 +++++++++++ 2 files changed, 16 insertions(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 2b23c971..1d1cb3a2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1136,6 +1136,11 @@ def remove(self, entry: DNSRecord) -> None: try: list_ = self.cache[entry.key] list_.remove(entry) + # If we remove the last entry in the list + # we remove the key from the dict in order + # to avoid leaking memory + if not list_: + del self.cache[entry.key] except (KeyError, ValueError): pass diff --git a/zeroconf/test.py b/zeroconf/test.py index fb87524b..a371ad89 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -828,6 +828,17 @@ def test_order(self): cached_record = cache.get(entry) self.assertEqual(cached_record, record2) + def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): + record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache + class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): From aba28583f5431f584587770b6c149e4a607a987e Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 8 Mar 2020 00:39:22 +0100 Subject: [PATCH 009/608] Release version 0.24.5 --- README.rst | 13 +++++++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 432d3c4b..8c03ce6b 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,19 @@ See examples directory for more. Changelog ========= +0.24.5 +------ + +* Fixed issues with shared records being used where they shouldn't be (TXT, SRV, A records are + unique now), thanks to Matt Saxon +* Stopped unnecessarily excluding host-only interfaces from InterfaceChoice.all as they don't + forbid multicast, thanks to Andreas Oberritter +* Fixed repr() of IPv6 DNSAddress, thanks to Aldo Hoeben +* Removed duplicate update messages sent to listeners, thanks to Matt Saxon +* Added support for cooperating responders, thanks to Matt Saxon +* Optimized handle_response cache check, thanks to J. Nick Koston +* Fixed memory leak in DNSCache, thanks to J. Nick Koston + 0.24.4 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1d1cb3a2..2b2cccb9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.24.4' +__version__ = '0.24.5' __license__ = 'LGPL' From 8e3adf8300a6f2b0bc0dcc4cde54d8890e0727e9 Mon Sep 17 00:00:00 2001 From: Andreas Oberritter Date: Sun, 8 Mar 2020 00:59:46 +0100 Subject: [PATCH 010/608] Rationalize handling of values in TXT records * Do not interpret received values; use None if a property has no value * When encoding values, use either raw bytes or UTF-8 --- zeroconf/__init__.py | 29 ++++++++--------------------- zeroconf/test.py | 10 +++++----- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 2b2cccb9..41714f14 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1646,20 +1646,12 @@ def _set_properties(self, properties: Union[bytes, Dict]) -> None: if isinstance(key, str): key = key.encode('utf-8') - if value is None: - suffix = b'' - elif isinstance(value, str): - suffix = value.encode('utf-8') - elif isinstance(value, bytes): - suffix = value - elif isinstance(value, int): - if value: - suffix = b'true' - else: - suffix = b'false' - else: - suffix = b'' - list_.append(b'='.join((key, suffix))) + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) for item in list_: result = b''.join((result, int2byte(len(item)), item)) self.text = result @@ -1682,16 +1674,11 @@ def _set_text(self, text: bytes) -> None: for s in strs: parts = s.split(b'=', 1) try: - key, value = parts # type: Tuple[bytes, Union[bool, bytes]] + key, value = parts # type: Tuple[bytes, Optional[bytes]] except ValueError: # No equals sign at all key = s - value = False - else: - if value == b'true': - value = True - elif value == b'false' or not value: - value = False + value = None # Only update non-existent properties if key and result.get(key) is None: diff --git a/zeroconf/test.py b/zeroconf/test.py index a371ad89..2225ac1d 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1009,19 +1009,19 @@ def update_service(self, zeroconf, type, name): # get service info without answer cache info = zeroconf_browser.get_service_info(type_, registration_name) assert info is not None - assert info.properties[b'prop_none'] is False + assert info.properties[b'prop_none'] is None assert info.properties[b'prop_string'] == properties['prop_string'] - assert info.properties[b'prop_float'] is False + assert info.properties[b'prop_float'] == b'1.0' assert info.properties[b'prop_blank'] == properties['prop_blank'] - assert info.properties[b'prop_true'] is True - assert info.properties[b'prop_false'] is False + assert info.properties[b'prop_true'] == b'1' + assert info.properties[b'prop_false'] == b'0' assert info.addresses == addresses[:1] # no V6 by default all_addresses = info.addresses_by_version(r.IPVersion.All) assert all_addresses == addresses, all_addresses info = zeroconf_browser.get_service_info(subtype, registration_name) assert info is not None - assert info.properties[b'prop_none'] is False + assert info.properties[b'prop_none'] is None # test TXT record update sublistener = MySubListener() From a79015e7c4bdc843d97bd5c82ef8ed4eeae01a34 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 3 Apr 2020 11:28:36 +0200 Subject: [PATCH 011/608] Remove uniqueness assertions The assertions, added in [1] and modified in [2] introduced a regression. When browsing in the presence of devices advertising SRV records not marked as unique there would be an undesired crash (from [3]): Exception in thread zeroconf-ServiceBrowser__hap._tcp.local.: Traceback (most recent call last): File "/usr/lib/python3.7/threading.py", line 917, in _bootstrap_inner self.run() File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 1504, in run handler(self.zc) File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 1444, in zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 1322, in fire h(**kwargs) File "browser.py", line 20, in on_service_state_change info = zeroconf.get_service_info(service_type, name) File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 2191, in get_service_info if info.request(self, timeout): File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 1762, in request out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN), now) File "/home/pi/homekit-debugging/venv/lib/python3.7/site-packages/zeroconf/__init__.py", line 907, in add_answer_at_time assert record.unique AssertionError The intention is to bring those assertions back in a way that only enforces uniqueness when sending records, not when receiving them. [1] bef8f593ae82 ("Ensure all TXT, SRV, A records are unique") [2] 5e4f496778d9 ("Refactor out unique assertion") [3] https://github.com/jstasiak/python-zeroconf/issues/236 --- zeroconf/__init__.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 41714f14..e269443a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -902,10 +902,6 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: """Adds an answer if it does not expire by a certain time""" if record is not None: - - if self.is_type_unique(record.type): - assert record.unique - if now == 0 or not record.is_expired(now): self.answers.append((record, now)) @@ -949,9 +945,6 @@ def add_additional_answer(self, record: DNSRecord) -> None: o All address records (type "A" and "AAAA") named in the SRV rdata. """ - if self.is_type_unique(record.type): - assert record.unique - self.additionals.append(record) def pack(self, format_: Union[bytes, str], value: Any) -> None: From e839c40081ba15e228d447969b725ee42f1ef2ad Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 3 Apr 2020 12:56:50 +0200 Subject: [PATCH 012/608] Improve ServiceInfo documentation --- zeroconf/__init__.py | 47 ++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e269443a..f49de1b4 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1497,9 +1497,28 @@ def run(self) -> None: class ServiceInfo(RecordUpdateListener): - text = b'' + """Service information. + + Constructor parameters are as follows: + + * type_: fully qualified service type name + * name: fully qualified service name + * address: IP address as unsigned short, network byte order (deprecated, use addresses) + * port: port that the service runs on + * weight: weight of the service + * priority: priority of the service + * properties: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * server: fully qualified name for service host (defaults to name) + * host_ttl: ttl used for A/SRV records + * other_ttl: ttl used for PTR/TXT records + * addresses: List of IP addresses as unsigned short (IPv4) or unsigned 128 bit number (IPv6), + network byte order + + """ - """Service information""" + text = b'' # FIXME(dtantsur): black 19.3b0 produces code that is not valid syntax on # Python 3.5: https://github.com/python/black/issues/759 @@ -1519,23 +1538,6 @@ def __init__( *, addresses: Optional[List[bytes]] = None ) -> None: - """Create a service description. - - type_: fully qualified service type name - name: fully qualified service name - address: IP address as unsigned short, network byte order (deprecated, use addresses) - port: port that the service runs on - weight: weight of the service - priority: priority of the service - properties: dictionary of properties (or a string holding the - bytes for the text field) - server: fully qualified name for service host (defaults to name) - host_ttl: ttl used for A/SRV records - other_ttl: ttl used for PTR/TXT records - addresses: List of IP addresses as unsigned short (IPv4) or unsigned - 128 bit number (IPv6), network byte order - """ - # Accept both none, or one, but not both. if address is not None and addresses is not None: raise TypeError("address and addresses cannot be provided together") @@ -1610,6 +1612,13 @@ def addresses(self, value: List[bytes]) -> None: @property def properties(self) -> Dict: + """If properties were set in the constructor this property returns the original dictionary + of type `Dict[Union[bytes, str], Any]`. + + If properties are coming from the network, after decoding a TXT record, the keys are always + bytes and the values are either bytes, if there was a value, even empty, or `None`, if there + was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. + """ return self._properties def addresses_by_version(self, version: IPVersion) -> List[bytes]: From 0cbced809989283893e02914e251a94739a41062 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 3 Apr 2020 12:59:35 +0200 Subject: [PATCH 013/608] Release version 0.25.0 --- README.rst | 12 ++++++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 8c03ce6b..a5471f53 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,18 @@ See examples directory for more. Changelog ========= +0.25.0 +------ + +* Reverted uniqueness assertions when browsing, they caused a regression + +Backwards incompatible: + +* Rationalized handling of TXT records. Non-bytes values are converted to str and encoded to bytes + using UTF-8 now, None values mean value-less attributes. When receiving TXT records no decoding + is performed now, keys are always bytes and values are either bytes or None in value-less + attributes. + 0.24.5 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f49de1b4..cf0e012e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.24.5' +__version__ = '0.25.0' __license__ = 'LGPL' From f071f3d49d82ab212b86f889532200c94b36aea6 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 9 Apr 2020 13:00:16 +0200 Subject: [PATCH 014/608] Switch to pytest for test running (#240) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Nose is dead for all intents and purposes (last release in 2015) and pytest provide a very valuable feature of printing relevant extra information in case of assertion failure (from[1]): ================================= FAILURES ================================= _______________________________ test_answer ________________________________ def test_answer(): > assert func(3) == 5 E assert 4 == 5 E + where 4 = func(3) test_sample.py:6: AssertionError ========================= short test summary info ========================== FAILED test_sample.py::test_answer - assert 4 == 5 ============================ 1 failed in 0.12s ============================= This should be helpful in debugging tests intermittently failing on PyPy. Several TestCase.assertEqual() calls have been replaced by plain assertions now that that method no longer provides anything we can't get without it. Few assertions have been modified to not explicitly provide extra information in case of failure – pytest will provide this automatically. Dev dependencies are forced to be the latest versions to make sure we don't fail because of outdated ones on Travis. [1] https://docs.pytest.org/en/latest/getting-started.html#create-your-first-test --- .travis.yml | 4 ++-- Makefile | 5 ++-- requirements-dev.txt | 3 ++- zeroconf/test.py | 54 +++++++++++++++++++++----------------------- 4 files changed, 32 insertions(+), 34 deletions(-) diff --git a/.travis.yml b/.travis.yml index acf47981..c5369538 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,13 +7,13 @@ python: - "pypy3.5" - "pypy3" install: - - pip install -r requirements-dev.txt + - pip install --upgrade -r requirements-dev.txt # mypy can't be installed on pypy - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install mypy ; fi - if [[ "${TRAVIS_PYTHON_VERSION}" != *"3.5"* && "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install black ; fi script: # no IPv6 support in Travis :( - - make TEST_ARGS='-a "!IPv6"' ci + - SKIP_IPV6=1 make ci after_success: - coveralls diff --git a/Makefile b/Makefile index ea5f8c64..af951f26 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,6 @@ MAX_LINE_LENGTH=110 PYTHON_IMPLEMENTATION:=$(shell python -c "import sys;import platform;sys.stdout.write(platform.python_implementation())") PYTHON_VERSION:=$(shell python -c "import sys;sys.stdout.write('%d.%d' % sys.version_info[:2])") -TEST_ARGS= LINT_TARGETS:=flake8 @@ -40,10 +39,10 @@ mypy: mypy examples/*.py zeroconf/*.py test: - nosetests -v $(TEST_ARGS) + pytest -v zeroconf/test.py test_coverage: - nosetests -v --with-coverage --cover-package=zeroconf $(TEST_ARGS) + pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf diff --git a/requirements-dev.txt b/requirements-dev.txt index ec443c0b..127df74e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,5 +5,6 @@ coverage flake8>=3.6.0 flake8-import-order ifaddr -nose pep8-naming!=0.6.0 +pytest +pytest-cov diff --git a/zeroconf/test.py b/zeroconf/test.py index 2225ac1d..802a5f11 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -6,6 +6,7 @@ import copy import logging +import os import socket import struct import time @@ -14,8 +15,6 @@ from typing import Dict, Optional # noqa # used in type hints from typing import cast -from nose.plugins.attrib import attr - import zeroconf as r from zeroconf import ( DNSHinfo, @@ -169,17 +168,17 @@ def test_parse_own_packet_response(self): 0, ) parsed = r.DNSIncoming(generated.packet()) - self.assertEqual(len(generated.answers), 1) - self.assertEqual(len(generated.answers), len(parsed.answers)) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) def test_match_question(self): generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) generated.add_question(question) parsed = r.DNSIncoming(generated.packet()) - self.assertEqual(len(generated.questions), 1) - self.assertEqual(len(generated.questions), len(parsed.questions)) - self.assertEqual(question, parsed.questions[0]) + assert len(generated.questions) == 1 + assert len(generated.questions) == len(parsed.questions) + assert question == parsed.questions[0] def test_suppress_answer(self): query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) @@ -253,8 +252,8 @@ def test_dns_hinfo(self): generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')) parsed = r.DNSIncoming(generated.packet()) answer = cast(r.DNSHinfo, parsed.answers[0]) - self.assertEqual(answer.cpu, u'cpu') - self.assertEqual(answer.os, u'os') + assert answer.cpu == u'cpu' + assert answer.os == u'os' generated = r.DNSOutgoing(0) generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) @@ -267,28 +266,28 @@ def test_transaction_id(self): generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) bytes = generated.packet() id = bytes[0] << 8 | bytes[1] - self.assertEqual(id, 0) + assert id == 0 def test_query_header_bits(self): generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) bytes = generated.packet() flags = bytes[2] << 8 | bytes[3] - self.assertEqual(flags, 0x0) + assert flags == 0x0 def test_response_header_bits(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) bytes = generated.packet() flags = bytes[2] << 8 | bytes[3] - self.assertEqual(flags, 0x8000) + assert flags == 0x8000 def test_numbers(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) bytes = generated.packet() (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - self.assertEqual(num_questions, 0) - self.assertEqual(num_answers, 0) - self.assertEqual(num_authorities, 0) - self.assertEqual(num_additionals, 0) + assert num_questions == 0 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 def test_numbers_questions(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) @@ -297,10 +296,10 @@ def test_numbers_questions(self): generated.add_question(question) bytes = generated.packet() (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - self.assertEqual(num_questions, 10) - self.assertEqual(num_answers, 0) - self.assertEqual(num_authorities, 0) - self.assertEqual(num_additionals, 0) + assert num_questions == 10 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 class Names(unittest.TestCase): @@ -502,7 +501,7 @@ def test_launch_and_close(self): rv.close() @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_launch_and_close_v4_v6(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) rv.close() @@ -510,7 +509,7 @@ def test_launch_and_close_v4_v6(self): rv.close() @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_launch_and_close_v6_only(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) rv.close() @@ -826,7 +825,7 @@ def test_order(self): cache.add(record2) entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN) cached_record = cache.get(entry) - self.assertEqual(cached_record, record2) + assert cached_record == record2 def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') @@ -888,7 +887,7 @@ def test_integration_with_listener_v6_records(self): zeroconf_registrar.close() @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') - @attr('IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_ipv6(self): type_ = "_test-srvc-type._tcp.local." @@ -904,9 +903,9 @@ def test_integration_with_listener_ipv6(self): try: service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) - assert type_ in service_types, service_types + assert type_ in service_types service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types, service_types + assert type_ in service_types finally: zeroconf_registrar.close() @@ -1016,8 +1015,7 @@ def update_service(self, zeroconf, type, name): assert info.properties[b'prop_true'] == b'1' assert info.properties[b'prop_false'] == b'0' assert info.addresses == addresses[:1] # no V6 by default - all_addresses = info.addresses_by_version(r.IPVersion.All) - assert all_addresses == addresses, all_addresses + assert info.addresses_by_version(r.IPVersion.All) == addresses info = zeroconf_browser.get_service_info(subtype, registration_name) assert info is not None From cf0382ba771bcc22284fd719c80a26eaa05ba5cd Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Fri, 10 Apr 2020 19:59:41 +0100 Subject: [PATCH 015/608] Remove unstable IPv6 tests from Travis (#241) --- zeroconf/test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 802a5f11..13df725c 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -863,6 +863,7 @@ def test_integration_with_listener(self): zeroconf_registrar.close() @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_v6_records(self): type_ = "_test-srvc-type._tcp.local." @@ -987,7 +988,7 @@ def update_service(self, zeroconf, type, name): desc = {'path': '/~paulsm/'} # type: Dict desc.update(properties) addresses = [socket.inet_aton("10.0.1.2")] - if socket.has_ipv6: + if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) info_service = ServiceInfo( subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses @@ -1357,7 +1358,7 @@ def test_multiple_addresses(): assert info.addresses == [address, address] - if socket.has_ipv6: + if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.") From 976e3dcf9d6d897b063ab6f0b7831bcfa6ac1814 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 14 Apr 2020 20:07:42 +0200 Subject: [PATCH 016/608] Update Engine to immediately notify its worker thread (#243) --- zeroconf/__init__.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index cf0e012e..a148dd75 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1203,12 +1203,13 @@ def __init__(self, zc: 'Zeroconf') -> None: self.readers = {} # type: Dict[socket.socket, Listener] self.timeout = 5 self.condition = threading.Condition() + self.socketpair = socket.socketpair() self.start() def run(self) -> None: while not self.zc.done: with self.condition: - rs = self.readers.keys() + rs = list(self.readers.keys()) if len(rs) == 0: # No sockets to manage, but we wait for the timeout # or addition of a socket @@ -1216,6 +1217,7 @@ def run(self) -> None: if len(rs) != 0: try: + rs = rs + [self.socketpair[0]] rr, wr, er = select.select(cast(Sequence[Any], rs), [], [], self.timeout) if not self.zc.done: for socket_ in rr: @@ -1223,21 +1225,36 @@ def run(self) -> None: if reader: reader.handle_read(socket_) + if self.socketpair[0] in rr: + # Clear the socket's buffer + self.socketpair[0].recv(128) + except (select.error, socket.error) as e: # If the socket was closed by another thread, during # shutdown, ignore it and exit if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: raise + self.socketpair[0].close() + self.socketpair[1].close() + + def _notify(self) -> None: + self.condition.notify() + try: + self.socketpair[1].send(b'x') + except socket.error: + # The socketpair may already be closed during shutdown, ignore it + if not self.zc.done: + raise def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None: with self.condition: self.readers[socket_] = reader - self.condition.notify() + self._notify() def del_reader(self, socket_: socket.socket) -> None: with self.condition: del self.readers[socket_] - self.condition.notify() + self._notify() class Listener(QuietLogger): From f8fe400e4be833728f015a3d6396bfc3f7c185c0 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 14 Apr 2020 21:01:53 +0200 Subject: [PATCH 017/608] Release version 0.25.1 --- README.rst | 5 +++++ zeroconf/__init__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index a5471f53..e24ddff0 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,11 @@ See examples directory for more. Changelog ========= +0.25.1 +------ + +* Eliminated 5s hangup when calling Zeroconf.close(), thanks to Erik Montnemery + 0.25.0 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a148dd75..4f136eec 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.25.0' +__version__ = '0.25.1' __license__ = 'LGPL' From 552a030eb592a0c07feaa7a01ece1464da4b1d0b Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Sun, 5 Apr 2020 20:03:23 +0100 Subject: [PATCH 018/608] Call UpdateService on SRV & A/AAAA updates as well as TXT (#239) Fix https://github.com/jstasiak/python-zeroconf/issues/235 Contains: * Add lock around handlers list * Reverse DNSCache order to ensure newest records take precedence When there are multiple records in the cache, the behaviour was inconsistent. Whilst the DNSCache.get() method returned the newest, any function which iterated over the entire cache suffered from a last write winds issue. This change makes this behaviour consistent and allows the removal of an (incorrect) wait from one of the unit tests. --- zeroconf/__init__.py | 132 ++++++++++++++++++++++++++++--------------- zeroconf/test.py | 102 +++++++++++++++++++++------------ 2 files changed, 150 insertions(+), 84 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 4f136eec..3949aa96 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -35,6 +35,7 @@ import threading import time import warnings +from collections import OrderedDict from typing import Dict, List, Optional, Sequence, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints @@ -1121,8 +1122,9 @@ def __init__(self) -> None: def add(self, entry: DNSRecord) -> None: """Adds an entry""" - # Insert first in list so get returns newest entry - self.cache.setdefault(entry.key, []).insert(0, entry) + # Insert last in list, get will return newest entry + # iteration will result in last update winning + self.cache.setdefault(entry.key, []).append(entry) def remove(self, entry: DNSRecord) -> None: """Removes an entry""" @@ -1142,7 +1144,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: matching entry.""" try: list_ = self.cache[entry.key] - for cached_entry in list_: + for cached_entry in reversed(list_): if entry.__eq__(cached_entry): return cached_entry return None @@ -1164,7 +1166,7 @@ def entries_with_name(self, name: str) -> List[DNSRecord]: def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() - for record in self.entries_with_name(name): + for record in reversed(self.entries_with_name(name)): if ( record.type == _TYPE_PTR and not record.is_expired(now) @@ -1400,7 +1402,7 @@ def __init__( self.services = {} # type: Dict[str, DNSRecord] self.next_time = current_time_millis() self.delay = delay - self._handlers_to_call = [] # type: List[Callable[[Zeroconf], None]] + self._handlers_to_call = OrderedDict() # type: OrderedDict[str, ServiceStateChange] self._service_state_changed = Signal() @@ -1445,14 +1447,30 @@ def service_state_changed(self) -> SignalRegistrationInterface: def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: """Callback invoked by Zeroconf when new information arrives. - Updates information required by browser in the Zeroconf cache.""" + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unecessary duplicates in the list + + """ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: - self._handlers_to_call.append( - lambda zeroconf: self._service_state_changed.fire( - zeroconf=zeroconf, service_type=self.type, name=name, state_change=state_change + + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + + if ( + state_change is ServiceStateChange.Added + or ( + state_change is ServiceStateChange.Removed + and ( + self._handlers_to_call.get(name) is ServiceStateChange.Updated + or self._handlers_to_call.get(name) is ServiceStateChange.Added + or self._handlers_to_call.get(name) is None + ) ) - ) + or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call) + ): + self._handlers_to_call[name] = state_change if record.type == _TYPE_PTR and record.name == self.type: assert isinstance(record, DNSPointer) @@ -1476,8 +1494,20 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: if expires < self.next_time: self.next_time = expires - elif record.type == _TYPE_TXT and record.name.endswith(self.type): - assert isinstance(record, DNSText) + elif record.type == _TYPE_A or record.type == _TYPE_AAAA: + assert isinstance(record, DNSAddress) + + # Iterate through the DNSCache and callback any services that use this address + for service in zc.cache.entries(): + if ( + isinstance(service, DNSService) + and service.name.endswith(self.type) + and service.server == record.name + and not record.is_expired(now) + ): + enqueue_callback(ServiceStateChange.Updated, service.name) + + elif record.name.endswith(self.type): expired = record.is_expired(now) if not expired: enqueue_callback(ServiceStateChange.Updated, record.name) @@ -1509,8 +1539,11 @@ def run(self) -> None: self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2) if len(self._handlers_to_call) > 0 and not self.zc.done: - handler = self._handlers_to_call.pop(0) - handler(self.zc) + with self.zc._handlers_lock: + handler = self._handlers_to_call.popitem(False) + self._service_state_changed.fire( + zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1] + ) class ServiceInfo(RecordUpdateListener): @@ -2173,6 +2206,8 @@ def __init__( self.debug = None # type: Optional[DNSOutgoing] + self._handlers_lock = threading.Lock() # ensure we process a full message in one go + @property def done(self) -> bool: return self._GLOBAL_DONE @@ -2449,42 +2484,45 @@ def update_record(self, now: float, rec: DNSRecord) -> None: def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" - now = current_time_millis() - for record in msg.answers: - - updated = True - - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # Since the cache format is keyed on the lower case record name - # we can avoid iterating everything in the cache and - # only look though entries for the specific name. - # entries_with_name will take care of converting to lowercase - # - # We make a copy of the list that entries_with_name returns - # since we cannot iterate over something we might remove - for entry in self.cache.entries_with_name(record.name).copy(): - if entry == record: - updated = False + with self._handlers_lock: - # Check the time first because it is far cheaper - # than the __eq__ - if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): - self.cache.remove(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) + now = current_time_millis() + for record in msg.answers: + + updated = True + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + # Since the cache format is keyed on the lower case record name + # we can avoid iterating everything in the cache and + # only look though entries for the specific name. + # entries_with_name will take care of converting to lowercase + # + # We make a copy of the list that entries_with_name returns + # since we cannot iterate over something we might remove + for entry in self.cache.entries_with_name(record.name).copy(): + + if entry == record: + updated = False + + # Check the time first because it is far cheaper + # than the __eq__ + if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): + self.cache.remove(entry) + + expired = record.is_expired(now) + maybe_entry = self.cache.get(record) + if not expired: + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + self.cache.add(record) + if updated: + self.update_record(now, record) else: - self.cache.add(record) - if updated: - self.update_record(now, record) - else: - if maybe_entry is not None: - self.update_record(now, record) - self.cache.remove(maybe_entry) + if maybe_entry is not None: + self.update_record(now, record) + self.cache.remove(maybe_entry) def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if diff --git a/zeroconf/test.py b/zeroconf/test.py index 13df725c..dd6fc21d 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -292,7 +292,7 @@ def test_numbers(self): def test_numbers_questions(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - for i in range(10): + for i in range(10): # pylint: disable=unused-variable generated.add_question(question) bytes = generated.packet() (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) @@ -756,7 +756,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): """Sends an outgoing packet.""" nonlocal nbr_answers, nbr_additionals, nbr_authorities - for answer, time_ in out.answers: + for answer, time_ in out.answers: # pylint: disable=unused-variable nbr_answers += 1 assert answer.ttl == get_ttl(answer.type) for answer in out.additionals: @@ -1053,12 +1053,12 @@ def test_update_record(self): service_name = 'name._type._tcp.local.' service_type = '_type._tcp.local.' - service_server = 'ash-2.local.' - service_text = b'path=/~paulsm/' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' service_address = '10.0.1.2' - service_added = False - service_removed = False + service_added_count = 0 + service_removed_count = 0 service_updated_count = 0 service_add_event = Event() service_removed_event = Event() @@ -1066,49 +1066,44 @@ def test_update_record(self): class MyServiceListener(r.ServiceListener): def add_service(self, zc, type_, name) -> None: - nonlocal service_added - service_added = True + nonlocal service_added_count + service_added_count += 1 service_add_event.set() def remove_service(self, zc, type_, name) -> None: - nonlocal service_added, service_removed - service_added = False - service_removed = True + nonlocal service_removed_count + service_removed_count += 1 service_removed_event.set() def update_service(self, zc, type_, name) -> None: nonlocal service_updated_count service_updated_count += 1 - service_info = zc.get_service_info(type_, name) + assert service_info.addresses[0] == socket.inet_aton(service_address) assert service_info.text == service_text + assert service_info.server == service_server service_updated_event.set() def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - if service_state_change == r.ServiceStateChange.Updated: - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - return r.DNSIncoming(generated.packet()) + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) if service_state_change == r.ServiceStateChange.Removed: ttl = 0 + else: + ttl = 120 generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 ) + generated.add_answer_at_time( r.DNSService( service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server ), 0, ) - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) + generated.add_answer_at_time( r.DNSAddress( service_server, @@ -1120,36 +1115,69 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi 0, ) + generated.add_answer_at_time( + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + ) + return r.DNSIncoming(generated.packet()) zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) try: + wait_time = 3 + # service added zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added)) - service_add_event.wait(1) - service_updated_event.wait(1) - assert service_added is True - assert service_updated_count == 1 - assert service_removed is False + service_add_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 0 + assert service_removed_count == 0 - # service updated. currently only text record can be updated + # service SRV updated service_updated_event.clear() - service_text = b'path=/~humingchun/' + service_server = 'ash-2.local.' zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 1 + assert service_removed_count == 0 + + # service TXT updated + service_updated_event.clear() + service_text = b'path=/~matt2/' zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(1) - assert service_added is True + service_updated_event.wait(wait_time) + assert service_added_count == 1 assert service_updated_count == 2 - assert service_removed is False + assert service_removed_count == 0 + + # service A updated + service_updated_event.clear() + service_address = '10.0.1.3' + zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 3 + assert service_removed_count == 0 + + # service all updated + service_updated_event.clear() + service_server = 'ash-3.local.' + service_text = b'path=/~matt3/' + service_address = '10.0.1.3' + zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 0 # service removed zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) - service_removed_event.wait(1) - assert service_added is False - assert service_updated_count == 2 - assert service_removed is True + service_removed_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 1 finally: service_browser.cancel() From 36941aeb72711f7954d40f0abeab4802174636df Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 26 Apr 2020 02:53:27 +0200 Subject: [PATCH 019/608] Release version 0.26.0 --- README.rst | 11 +++++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index e24ddff0..b23cee06 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,17 @@ See examples directory for more. Changelog ========= +0.26.0 +------ + +* Fixed a regression where service update listener wasn't called on IP address change (it's called + on SRV/A/AAAA record changes now), thanks to Matt Saxon + +Technically backwards incompatible: + +* Service update hook is no longer called on service addition (service added hook is still called), + this is related to the fix above + 0.25.1 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3949aa96..feac8cbd 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.25.1' +__version__ = '0.26.0' __license__ = 'LGPL' From 16431b6cb51f561a4c5d2897e662b254ca4243ec Mon Sep 17 00:00:00 2001 From: mattsaxon Date: Sun, 26 Apr 2020 11:22:30 +0100 Subject: [PATCH 020/608] Update .gitignore for Visual Studio config files (#244) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ddf8a0d7..eac2c170 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ Thumbs.db .cache .mypy_cache/ docs/_build/ +.vscode From 0540342bacd859f38f6d2a3743a7959cd3ae4d02 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 6 May 2020 11:12:56 -0500 Subject: [PATCH 021/608] Avoid iterating the entire cache when an A/AAAA address has not changed (#247) Iterating the cache is an expensive operation when there is 100s of devices generating zeroconf traffic as there can be 1000s of entries in the cache. --- zeroconf/__init__.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index feac8cbd..a4ac31ac 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1496,6 +1496,20 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: elif record.type == _TYPE_A or record.type == _TYPE_AAAA: assert isinstance(record, DNSAddress) + if record.is_expired(now): + return + + address_changed = False + for service in zc.cache.entries_with_name(record.name): + if isinstance(service, DNSAddress) and service.address != record.address: + address_changed = True + break + + # Avoid iterating the entire DNSCache if the address has not changed + # as this is an expensive operation when there many hosts + # generating zeroconf traffic. + if not address_changed: + return # Iterate through the DNSCache and callback any services that use this address for service in zc.cache.entries(): @@ -1503,7 +1517,6 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: isinstance(service, DNSService) and service.name.endswith(self.type) and service.server == record.name - and not record.is_expired(now) ): enqueue_callback(ServiceStateChange.Updated, service.name) From 0dd6fe44ca3895375ba447fed5f138042ab12ebf Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 6 May 2020 18:17:29 +0200 Subject: [PATCH 022/608] Remove unwanted pylint directives Those are results of a bad conflict resolution I did when merging [1]. [1] 552a030eb592 ("Call UpdateService on SRV & A/AAAA updates as well as TXT (#239)") --- zeroconf/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index dd6fc21d..4e54c1d1 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -292,7 +292,7 @@ def test_numbers(self): def test_numbers_questions(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - for i in range(10): # pylint: disable=unused-variable + for i in range(10): generated.add_question(question) bytes = generated.packet() (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) @@ -756,7 +756,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): """Sends an outgoing packet.""" nonlocal nbr_answers, nbr_additionals, nbr_authorities - for answer, time_ in out.answers: # pylint: disable=unused-variable + for answer, time_ in out.answers: nbr_answers += 1 assert answer.ttl == get_ttl(answer.type) for answer in out.additionals: From 4c359e2e7cdf104efca90ffd9912ea7c7792e3bf Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 6 May 2020 18:25:51 +0200 Subject: [PATCH 023/608] Release version 0.26.1 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b23cee06..f70aabb2 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.26.1 +------ + +* Fixed a performance regression introduced in 0.26.0, thanks to J. Nick Koston (this is close in + spirit to an optimization made in 0.24.5 by the same author) + 0.26.0 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a4ac31ac..1a83c579 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.26.0' +__version__ = '0.26.1' __license__ = 'LGPL' From 4b1d953979287e08f914857867da1000634ca3af Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 22 May 2020 13:55:19 -0500 Subject: [PATCH 024/608] Fix flake8 E741 in setup.py (#252) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a011e602..57b8b0be 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ readme = f.read() version = ( - [l for l in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in l][0] + [ln for ln in open(join(PROJECT_ROOT, 'zeroconf', '__init__.py')) if '__version__' in ln][0] .split('=')[-1] .strip() .strip('\'"') From 24a06191ea35469948d12124a07429207b3c1b3b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 22 May 2020 17:12:17 +0000 Subject: [PATCH 025/608] Fix race condition where a listener gets a message before the lock is created. --- zeroconf/__init__.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1a83c579..6e7befe4 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2208,6 +2208,11 @@ def __init__( self.condition = threading.Condition() + # Ensure we create the lock before + # we add the listener as we could get + # a message before the lock is created. + self._handlers_lock = threading.Lock() # ensure we process a full message in one go + self.engine = Engine(self) self.listener = Listener(self) if not unicast: @@ -2219,8 +2224,6 @@ def __init__( self.debug = None # type: Optional[DNSOutgoing] - self._handlers_lock = threading.Lock() # ensure we process a full message in one go - @property def done(self) -> bool: return self._GLOBAL_DONE From a6ad100a60e8434cef6b411208eef98f68d594d3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 13 May 2020 19:39:46 +0000 Subject: [PATCH 026/608] Add support for multiple types to ServiceBrowsers As each ServiceBrowser runs in its own thread there is a scale problem when listening for many types. ServiceBrowser can now accept a list of types in addition to a single type. --- examples/browser.py | 4 +- zeroconf/__init__.py | 89 +++++++++++++++++++++++++------------------- zeroconf/test.py | 73 ++++++++++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+), 40 deletions(-) diff --git a/examples/browser.py b/examples/browser.py index bf3ebfbd..c4ddac39 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -55,7 +55,9 @@ def on_service_state_change( zeroconf = Zeroconf(ip_version=ip_version) print("\nBrowsing services, press Ctrl-C to exit...\n") - browser = ServiceBrowser(zeroconf, "_http._tcp.local.", handlers=[on_service_state_change]) + browser = ServiceBrowser( + zeroconf, ["_http._tcp.local.", "_hap._tcp.local."], handlers=[on_service_state_change] + ) try: while True: diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 6e7befe4..4b93e453 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1379,7 +1379,7 @@ class ServiceBrowser(RecordUpdateListener, threading.Thread): def __init__( self, zc: 'Zeroconf', - type_: str, + type_: Union[str, list], # NOTE: Callable quoting needed on Python 3.5.2, see # https://github.com/jstasiak/python-zeroconf/issues/208 for details. handlers: Optional[Union[ServiceListener, List['Callable[..., None]']]] = None, @@ -1390,19 +1390,23 @@ def __init__( ) -> None: """Creates a browser for a specific type""" assert handlers or listener, 'You need to specify at least one handler' - if not type_.endswith(service_type_name(type_, allow_underscores=True)): - raise BadTypeInNameException - threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + type_) + self.types = set(type_ if isinstance(type_, list) else [type_]) + for check_type_ in self.types: + if not check_type_.endswith(service_type_name(check_type_, allow_underscores=True)): + raise BadTypeInNameException + threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + '-'.join(self.types)) self.daemon = True self.zc = zc - self.type = type_ self.addr = addr self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - self.services = {} # type: Dict[str, DNSRecord] - self.next_time = current_time_millis() - self.delay = delay - self._handlers_to_call = OrderedDict() # type: OrderedDict[str, ServiceStateChange] + self._services = { + check_type_: {} for check_type_ in self.types + } # type: Dict[str, Dict[str, DNSRecord]] + current_time = current_time_millis() + self._next_time = {check_type_: current_time for check_type_ in self.types} + self._delay = {check_type_: delay for check_type_ in self.types} + self._handlers_to_call = OrderedDict() # type: OrderedDict[str, Tuple[str, ServiceStateChange]] self._service_state_changed = Signal() @@ -1453,7 +1457,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: """ - def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: + def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> None: # Code to ensure we only do a single update message # Precedence is; Added, Remove, Update @@ -1470,29 +1474,29 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: ) or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call) ): - self._handlers_to_call[name] = state_change + self._handlers_to_call[name] = (type_, state_change) - if record.type == _TYPE_PTR and record.name == self.type: + if record.type == _TYPE_PTR and record.name in self.types: assert isinstance(record, DNSPointer) expired = record.is_expired(now) service_key = record.alias.lower() try: - old_record = self.services[service_key] + old_record = self._services[record.name][service_key] except KeyError: if not expired: - self.services[service_key] = record - enqueue_callback(ServiceStateChange.Added, record.alias) + self._services[record.name][service_key] = record + enqueue_callback(ServiceStateChange.Added, record.name, record.alias) else: if not expired: old_record.reset_ttl(record) else: - del self.services[service_key] - enqueue_callback(ServiceStateChange.Removed, record.alias) + del self._services[record.name][service_key] + enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) return expires = record.get_expiration_time(75) - if expires < self.next_time: - self.next_time = expires + if expires < self._next_time[record.name]: + self._next_time[record.name] = expires elif record.type == _TYPE_A or record.type == _TYPE_AAAA: assert isinstance(record, DNSAddress) @@ -1513,17 +1517,16 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None: # Iterate through the DNSCache and callback any services that use this address for service in zc.cache.entries(): - if ( - isinstance(service, DNSService) - and service.name.endswith(self.type) - and service.server == record.name - ): - enqueue_callback(ServiceStateChange.Updated, service.name) + if not isinstance(service, DNSService) or not service.server == record.name: + continue + for type_ in self.types: + if service.name.endswith(type_): + enqueue_callback(ServiceStateChange.Updated, type_, service.name) - elif record.name.endswith(self.type): - expired = record.is_expired(now) - if not expired: - enqueue_callback(ServiceStateChange.Updated, record.name) + elif not record.is_expired(now): + for type_ in self.types: + if record.name.endswith(type_): + enqueue_callback(ServiceStateChange.Updated, type_, record.name) def cancel(self) -> None: self.done = True @@ -1531,31 +1534,39 @@ def cancel(self) -> None: self.join() def run(self) -> None: - self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN)) + for type_ in self.types: + self.zc.add_listener(self, DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) while True: now = current_time_millis() - if len(self._handlers_to_call) == 0 and self.next_time > now: - self.zc.wait(self.next_time - now) + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + if len(self._handlers_to_call) == 0 and next_time > now: + self.zc.wait(next_time - now) if self.zc.done or self.done: return now = current_time_millis() - if self.next_time <= now: + for type_ in self.types: + if self._next_time[type_] > now: + continue out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) - out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN)) - for record in self.services.values(): + out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) + for record in self._services[type_].values(): if not record.is_stale(now): out.add_answer_at_time(record, now) self.zc.send(out, addr=self.addr, port=self.port) - self.next_time = now + self.delay - self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2) + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) if len(self._handlers_to_call) > 0 and not self.zc.done: with self.zc._handlers_lock: - handler = self._handlers_to_call.popitem(False) + (name, service_type_state_change) = self._handlers_to_call.popitem(False) self._service_state_changed.fire( - zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1] + zeroconf=self.zc, + service_type=service_type_state_change[0], + name=name, + state_change=service_type_state_change[1], ) diff --git a/zeroconf/test.py b/zeroconf/test.py index 4e54c1d1..c1993a58 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1185,6 +1185,79 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi zeroconf.close() +class TestServiceBrowserMultipleTypes(unittest.TestCase): + def test_update_record(self): + + service_names = ['name._type._tcp.local.', 'name._type._udp.local'] + service_types = ['_type._tcp.local.', '_type._udp.local.'] + + service_added_count = 0 + service_removed_count = 0 + service_add_event = Event() + service_removed_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + if service_added_count == 2: + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + if service_removed_count == 2: + service_removed_event.set() + + def mock_incoming_msg( + service_state_change: r.ServiceStateChange, service_type: str, service_name: str + ) -> r.DNSIncoming: + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + else: + ttl = 120 + + generated.add_answer_at_time( + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + ) + return r.DNSIncoming(generated.packet()) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) + + try: + wait_time = 3 + + # both services added + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0]) + ) + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1]) + ) + service_add_event.wait(wait_time) + assert service_added_count == 2 + assert service_removed_count == 0 + + # both services removed + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0]) + ) + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1]) + ) + service_removed_event.wait(wait_time) + assert service_added_count == 2 + assert service_removed_count == 2 + + finally: + service_browser.cancel() + zeroconf.remove_all_service_listeners() + zeroconf.close() + + def test_backoff(): got_query = Event() From aa9de4de7202b3ab0a60f14532d227f63d7d981b Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Sun, 24 May 2020 21:54:06 +0200 Subject: [PATCH 027/608] Improve readability of logged incoming data (#254) --- zeroconf/__init__.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 4b93e453..041dd6b2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -719,6 +719,20 @@ def __init__(self, data: bytes) -> None: except (IndexError, struct.error, IncomingDecodeError): self.log_exception_warning(('Choked at offset %d while unpacking %r', self.offset, data)) + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'id=%s' % self.id, + 'flags=%s' % self.flags, + 'n_q=%s' % self.num_questions, + 'n_ans=%s' % self.num_answers, + 'n_auth=%s' % self.num_authorities, + 'n_add=%s' % self.num_additionals, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + ] + ) + def unpack(self, format_: bytes) -> tuple: length = struct.calcsize(format_) info = struct.unpack(format_, self.data[self.offset : self.offset + length]) @@ -1279,10 +1293,13 @@ def handle_read(self, socket_: socket.socket) -> None: self.log_exception_warning() return - log.debug('Received from %r:%r: %r ', addr, port, data) - self.data = data msg = DNSIncoming(data) + if msg.valid: + log.debug('Received from %r:%r: %r (%d bytes) as [%r]', addr, port, msg, len(data), data) + else: + log.debug('Received from %r:%r: (%d bytes) [%r]', addr, port, len(data), data) + if not msg.valid: pass @@ -2695,7 +2712,7 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P if len(packet) > _MAX_MSG_ABSOLUTE: self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return - log.debug('Sending %r (%d bytes) as %r...', out, len(packet), packet) + log.debug('Sending %r (%d bytes) as [%r]', out, len(packet), packet) for s in self._respond_sockets: if self._GLOBAL_DONE: return From 1c4d3fcbf34b09364e52a773783dc9c924a7b17a Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 22 May 2020 21:41:42 +0200 Subject: [PATCH 028/608] Merge 0.26.2 release commit I accidentally only pushed 0.26.2 tag (commit ffb42e5836bd) without pushing the commit to master and now I merged aa9de4de7202 so this is the best I can do without force-pushing to master. Tag 0.26.2 will continue to point to that dangling commit. --- README.rst | 7 +++++++ zeroconf/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f70aabb2..c335e2c8 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,13 @@ See examples directory for more. Changelog ========= +0.26.2 +------ + +* Added support for multiple types to ServiceBrowser, thanks to J. Nick Koston +* Fixed a race condition where a listener gets a message before the lock is created, thanks to + J. Nick Koston + 0.26.1 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 041dd6b2..e21a7f39 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.26.1' +__version__ = '0.26.2' __license__ = 'LGPL' From 445d7f5dbe38947bd0bd1e3a5b8d649c1819c21f Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 24 May 2020 22:48:03 +0200 Subject: [PATCH 029/608] Use equality comparison instead of identity comparison for ints Integers aren't guaranteed to have the same identity even though they may be equal. --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e21a7f39..e2732a74 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2064,7 +2064,7 @@ def new_socket( if not err.errno == errno.ENOPROTOOPT: raise - if port is _MDNS_PORT: + if port == _MDNS_PORT: ttl = struct.pack(b'B', 255) loop = struct.pack(b'B', 1) if ip_version != IPVersion.V6Only: From 54d116fd69a66062f91be04d84ceaebcfb13cc43 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 25 May 2020 23:27:07 +0200 Subject: [PATCH 030/608] Give threads unique names (#257) --- zeroconf/__init__.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e2732a74..a7b41ca4 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1213,7 +1213,7 @@ class Engine(threading.Thread): """ def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self, name='zeroconf-Engine') + threading.Thread.__init__(self) self.daemon = True self.zc = zc self.readers = {} # type: Dict[socket.socket, Listener] @@ -1221,6 +1221,7 @@ def __init__(self, zc: 'Zeroconf') -> None: self.condition = threading.Condition() self.socketpair = socket.socketpair() self.start() + self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) def run(self) -> None: while not self.zc.done: @@ -1324,10 +1325,11 @@ class Reaper(threading.Thread): have expired.""" def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self, name='zeroconf-Reaper') + threading.Thread.__init__(self) self.daemon = True self.zc = zc self.start() + self.name = "zeroconf-Reaper_%s" % (getattr(self, 'native_id', self.ident),) def run(self) -> None: while True: @@ -1411,7 +1413,7 @@ def __init__( for check_type_ in self.types: if not check_type_.endswith(service_type_name(check_type_, allow_underscores=True)): raise BadTypeInNameException - threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + '-'.join(self.types)) + threading.Thread.__init__(self) self.daemon = True self.zc = zc self.addr = addr @@ -1460,6 +1462,10 @@ def on_change( self.service_state_changed.register_handler(h) self.start() + self.name = "zeroconf-ServiceBrowser_%s_%s" % ( + '-'.join(self.types), + getattr(self, 'native_id', self.ident), + ) @property def service_state_changed(self) -> SignalRegistrationInterface: From fe865667e4610d57067a8f710f4d818eaa5e14dc Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Tue, 26 May 2020 22:50:13 +0200 Subject: [PATCH 031/608] Don't call callbacks when holding _handlers_lock (#258) Closes #255 Background: #239 adds the lock _handlers_lock: python-zeroconf/zeroconf/__init__.py self._handlers_lock = threading.Lock() # ensure we process a full message in one go Which is used in the engine thread: def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" with self._handlers_lock: And also by the service browser when issuing the state change callbacks: if len(self._handlers_to_call) > 0 and not self.zc.done: with self.zc._handlers_lock: handler = self._handlers_to_call.popitem(False) self._service_state_changed.fire( zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1] ) Both pychromecast and Home Assistant calls Zeroconf.get_service_info from the service callbacks which means the lock may be held for several seconds which will starve the engine thread. --- zeroconf/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a7b41ca4..85a187ea 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1585,12 +1585,12 @@ def run(self) -> None: if len(self._handlers_to_call) > 0 and not self.zc.done: with self.zc._handlers_lock: (name, service_type_state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=service_type_state_change[0], - name=name, - state_change=service_type_state_change[1], - ) + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=service_type_state_change[0], + name=name, + state_change=service_type_state_change[1], + ) class ServiceInfo(RecordUpdateListener): From fbcefca592632304579c1b3f9c7bd3dd342e1618 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 26 May 2020 23:01:48 +0200 Subject: [PATCH 032/608] Release version 0.26.3 --- README.rst | 9 +++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index c335e2c8..84c45906 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,15 @@ See examples directory for more. Changelog ========= +0.26.3 +------ + +* Improved readability of logged incoming data, thanks to Erik Montnemery +* Threads are given unique names now to aid debugging, thanks to Erik Montnemery +* Fixed a regression where get_service_info() called within a listener add_service method + would deadlock, timeout and incorrectly return None, fix thanks to Erik Montnemery, but + Matt Saxon and Hmmbob were also involved in debugging it. + 0.26.2 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 85a187ea..76c5a431 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.26.2' +__version__ = '0.26.3' __license__ = 'LGPL' From ab72aa8e5a6a83e50d24d7fb187e8fa8a549a847 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 26 May 2020 23:05:27 +0200 Subject: [PATCH 033/608] Remove deprecated ServiceInfo address parameter/property (#260) --- zeroconf/__init__.py | 30 ------------ zeroconf/test.py | 111 ++++++++++++++++++++++++++----------------- 2 files changed, 68 insertions(+), 73 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 76c5a431..00bf1c6d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -34,7 +34,6 @@ import sys import threading import time -import warnings from collections import OrderedDict from typing import Dict, List, Optional, Sequence, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints @@ -1600,7 +1599,6 @@ class ServiceInfo(RecordUpdateListener): * type_: fully qualified service type name * name: fully qualified service name - * address: IP address as unsigned short, network byte order (deprecated, use addresses) * port: port that the service runs on * weight: weight of the service * priority: priority of the service @@ -1624,7 +1622,6 @@ def __init__( self, type_: str, name: str, - address: Optional[Union[bytes, List[bytes]]] = None, port: Optional[int] = None, weight: int = 0, priority: int = 0, @@ -1635,22 +1632,12 @@ def __init__( *, addresses: Optional[List[bytes]] = None ) -> None: - # Accept both none, or one, but not both. - if address is not None and addresses is not None: - raise TypeError("address and addresses cannot be provided together") - if not type_.endswith(service_type_name(name, allow_underscores=True)): raise BadTypeInNameException self.type = type_ self.name = name if addresses is not None: self._addresses = addresses - elif address is not None: - warnings.warn("address is deprecated, use addresses instead", DeprecationWarning) - if isinstance(address, list): - self._addresses = address - else: - self._addresses = [address] else: self._addresses = [] # This results in an ugly error when registering, better check now @@ -1672,23 +1659,6 @@ def __init__( self.other_ttl = other_ttl # fmt: on - @property - def address(self) -> Optional[bytes]: - warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) - try: - # Return the first V4 address for compatibility - return self.addresses[0] - except IndexError: - return None - - @address.setter - def address(self, value: bytes) -> None: - warnings.warn("ServiceInfo.address is deprecated, use addresses instead", DeprecationWarning) - if value is None: - self._addresses = [] - else: - self._addresses = [value] - @property def addresses(self) -> List[bytes]: """IPv4 addresses of this service. diff --git a/zeroconf/test.py b/zeroconf/test.py index c1993a58..2bc642c8 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -108,7 +108,7 @@ def test_service_info_dunder(self): name = "xxxyyy" registration_name = "%s.%s" % (name, type_) info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, b'', "ash-2.local." + type_, registration_name, 80, 0, 0, b'', "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) assert not info != info @@ -121,7 +121,7 @@ def test_service_info_text_properties_not_given(self): info = ServiceInfo( type_=type_, name=registration_name, - address=socket.inet_aton("10.0.1.2"), + addresses=[socket.inet_aton("10.0.1.2")], port=80, server="ash-2.local.", ) @@ -456,7 +456,14 @@ def on_service_state_change(zeroconf, service_type, state_change, name): def verify_name_change(self, zc, type_, name, number_hosts): desc = {'path': '/~paulsm/'} info_service = ServiceInfo( - type_, '%s.%s' % (name, type_), socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) # verify name conflict @@ -736,7 +743,14 @@ def test_ttl(self): desc = {'path': '/~paulsm/'} info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) # we are going to monkey patch the zeroconf send to check packet sizes @@ -849,7 +863,14 @@ def test_integration_with_listener(self): zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) desc = {'path': '/~paulsm/'} info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) @@ -874,7 +895,14 @@ def test_integration_with_listener_v6_records(self): zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) desc = {'path': '/~paulsm/'} info = ServiceInfo( - type_, registration_name, socket.inet_pton(socket.AF_INET6, addr), 80, 0, 0, desc, "ash-2.local." + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) zeroconf_registrar.register_service(info) @@ -898,7 +926,14 @@ def test_integration_with_listener_ipv6(self): zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) desc = {'path': '/~paulsm/'} info = ServiceInfo( - type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) @@ -922,7 +957,14 @@ def test_integration_with_subtype_and_listener(self): zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) desc = {'path': '/~paulsm/'} info = ServiceInfo( - discovery_type, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + discovery_type, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) @@ -1028,7 +1070,14 @@ def update_service(self, zeroconf, type, name): properties['prop_blank'] = b'an updated string' desc.update(properties) info_service = ServiceInfo( - subtype, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local." + subtype, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.update_service(info_service) service_updated.wait(1) @@ -1385,7 +1434,9 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) desc = {'path': '/~paulsm/'} - info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.") + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) zeroconf_registrar.register_service(info) try: @@ -1424,45 +1475,17 @@ def test_multiple_addresses(): address_parsed = "10.0.1.2" address = socket.inet_aton(address_parsed) - # Old way - info = ServiceInfo(type_, registration_name, address, 80, 0, 0, desc, "ash-2.local.") - - assert info.address == address - assert info.addresses == [address] - - # Updating works - address2 = socket.inet_aton("10.0.1.3") - info.address = address2 - - assert info.address == address2 - assert info.addresses == [address2] - - info.address = None - - assert info.address is None - assert info.addresses == [] - - info.addresses = [address2] - - assert info.address == address2 - assert info.addresses == [address2] - - # Compatibility way - info = ServiceInfo(type_, registration_name, [address, address], 80, 0, 0, desc, "ash-2.local.") - - assert info.addresses == [address, address] - # New kwarg way - info = ServiceInfo( - type_, registration_name, None, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address] - ) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) assert info.addresses == [address, address] if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) - info = ServiceInfo(type_, registration_name, [address, address_v6], 80, 0, 0, desc, "ash-2.local.") + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address_v6], + ) assert info.addresses == [address] assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] assert info.addresses_by_version(r.IPVersion.V4Only) == [address] @@ -1483,7 +1506,9 @@ def test_ptr_optimization(): registration_name = "%s.%s" % (name, type_) desc = {'path': '/~paulsm/'} - info = ServiceInfo(type_, registration_name, socket.inet_aton("10.0.1.2"), 80, 0, 0, desc, "ash-2.local.") + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) # we are going to monkey patch the zeroconf send to check packet sizes old_send = zc.send From 87a0fe27a7be9d96af08f8a007f37a16105c64a0 Mon Sep 17 00:00:00 2001 From: gjbadros Date: Tue, 26 May 2020 14:07:47 -0700 Subject: [PATCH 034/608] Separately send large mDNS responses to comply with RFC 6762 (#248) This fixes issue #245 Split up large multi-response packets into separate packets instead of relying on IP Fragmentation. IP Fragmentation of mDNS packets causes ChromeCast Audios to crash their mDNS responder processes and RFC 6762 (https://tools.ietf.org/html/rfc6762) section 17 states some requirements for Multicast DNS Message Size, and the fourth paragraph reads: "A Multicast DNS packet larger than the interface MTU, which is sent using fragments, MUST NOT contain more than one resource record." This change makes this implementation conform with this MUST NOT clause. --- zeroconf/__init__.py | 178 ++++++++++++++++++++++++++++++------------- zeroconf/test.py | 34 +++++---- 2 files changed, 146 insertions(+), 66 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 00bf1c6d..9fa6ae5c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -874,9 +874,13 @@ def __init__(self, flags: int, multicast: bool = True) -> None: self.id = 0 self.multicast = multicast self.flags = flags + self.packets_data = [] # type: List[bytes] + + # these 3 are per-packet -- see also reset_for_next_packet() self.names = {} # type: Dict[str, int] self.data = [] # type: List[bytes] self.size = 12 + self.state = self.State.init self.questions = [] # type: List[DNSQuestion] @@ -884,6 +888,11 @@ def __init__(self, flags: int, multicast: bool = True) -> None: self.authorities = [] # type: List[DNSPointer] self.additionals = [] # type: List[DNSRecord] + def reset_for_next_packet(self) -> None: + self.names = {} + self.data = [] + self.size = 12 + def __repr__(self) -> str: return '' % ', '.join( [ @@ -1058,11 +1067,13 @@ def write_question(self, question: DNSQuestion) -> None: self.write_short(question.type) self.write_short(question.class_) - def write_record(self, record: DNSRecord, now: float) -> int: + def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) -> bool: """Writes a record (answer, authoritative answer, additional) to - the packet""" + the packet. Returns True on success, or False if we did not (either + because the packet was already finished or because the record does + not fit.""" if self.state == self.State.finished: - return 1 + return False start_data_length, start_size = len(self.data), self.size self.write_name(record.name) @@ -1086,44 +1097,102 @@ def write_record(self, record: DNSRecord, now: float) -> int: # Here is the short we adjusted for self.insert_short(index, length) + len_limit = _MAX_MSG_ABSOLUTE if allow_long else _MAX_MSG_TYPICAL + # if we go over, then rollback and quit - if self.size > _MAX_MSG_ABSOLUTE: + if self.size > len_limit: while len(self.data) > start_data_length: self.data.pop() self.size = start_size - self.state = self.State.finished - return 1 - return 0 + return False + return True def packet(self) -> bytes: - """Returns a string containing the packet's bytes + """Returns a bytestring containing the first packet's bytes. + + Generally, you want to use packets() in case the response + does not fit in a single packet, but this exists for + backward compatibility.""" + packets = self.packets() + if len(packets) > 0: + if len(packets[0]) > _MAX_MSG_ABSOLUTE: + QuietLogger.log_warning_once( + "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] + ) + return packets[0] + else: + return b'' - No further parts should be added to the packet once this - is done.""" + def packets(self) -> List[bytes]: + """Returns a list of bytestrings containing the packets' bytes - overrun_answers, overrun_authorities, overrun_additionals = 0, 0, 0 + No further parts should be added to the packet once this + is done. The packets are each restricted to _MAX_MSG_TYPICAL + or less in length, except for the case of a single answer which + will be written out to a single oversized packet no more than + _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP + fragmentation potentially). """ - if self.state != self.State.finished: + if self.state == self.State.finished: + return self.packets_data + + answer_offset = 0 + authority_offset = 0 + additional_offset = 0 + + # we have to at least write out the question + first_time = True + + while ( + first_time + or answer_offset < len(self.answers) + or authority_offset < len(self.authorities) + or additional_offset < len(self.additionals) + ): + first_time = False + log.debug("offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset) + log.debug("lengths = %d, %d, %d", len(self.answers), len(self.authorities), len(self.additionals)) + + additionals_written = 0 + authorities_written = 0 + answers_written = 0 + questions_written = 0 for question in self.questions: self.write_question(question) - for answer, time_ in self.answers: - overrun_answers += self.write_record(answer, time_) - for authority in self.authorities: - overrun_authorities += self.write_record(authority, 0) - for additional in self.additionals: - overrun_additionals += self.write_record(additional, 0) - self.state = self.State.finished - - self.insert_short(0, len(self.additionals) - overrun_additionals) - self.insert_short(0, len(self.authorities) - overrun_authorities) - self.insert_short(0, len(self.answers) - overrun_answers) - self.insert_short(0, len(self.questions)) + questions_written += 1 + allow_long = True # at most one answer is allowed to be a long packet + for answer, time_ in self.answers[answer_offset:]: + if self.write_record(answer, time_, allow_long): + answers_written += 1 + allow_long = False + for authority in self.authorities[authority_offset:]: + if self.write_record(authority, 0): + authorities_written += 1 + for additional in self.additionals[additional_offset:]: + if self.write_record(additional, 0): + additionals_written += 1 + + self.insert_short(0, additionals_written) + self.insert_short(0, authorities_written) + self.insert_short(0, answers_written) + self.insert_short(0, questions_written) self.insert_short(0, self.flags) if self.multicast: self.insert_short(0, 0) else: self.insert_short(0, self.id) - return b''.join(self.data) + self.packets_data.append(b''.join(self.data)) + self.reset_for_next_packet() + + answer_offset += answers_written + authority_offset += authorities_written + additional_offset += additionals_written + log.debug("now offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset) + if answers_written == 0 and authorities_written == 0 and additional_offset == 0: + log.warning("packets() made no progress adding records; returning") + break + self.state = self.State.finished + return self.packets_data class DNSCache: @@ -2684,36 +2753,39 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: """Sends an outgoing packet.""" - packet = out.packet() - if len(packet) > _MAX_MSG_ABSOLUTE: - self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) - return - log.debug('Sending %r (%d bytes) as [%r]', out, len(packet), packet) - for s in self._respond_sockets: - if self._GLOBAL_DONE: + packets = out.packets() + packet_num = 0 + for packet in packets: + packet_num += 1 + if len(packet) > _MAX_MSG_ABSOLUTE: + self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return - try: - if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR - elif not can_send_to(s, addr): - continue + log.debug('Sending (%d bytes #%d) %r as %r...', len(packet), packet_num, out, packet) + for s in self._respond_sockets: + if self._GLOBAL_DONE: + return + try: + if addr is None: + real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR + elif not can_send_to(s, addr): + continue + else: + real_addr = addr + bytes_sent = s.sendto(packet, 0, (real_addr, port)) + except Exception as exc: # TODO stop catching all Exceptions + if ( + isinstance(exc, OSError) + and exc.errno == errno.ENETUNREACH + and s.family == socket.AF_INET6 + ): + # with IPv6 we don't have a reliable way to determine if an interface actually has + # IPV6 support, so we have to try and ignore errors. + continue + # on send errors, log the exception and keep going + self.log_exception_warning() else: - real_addr = addr - bytes_sent = s.sendto(packet, 0, (real_addr, port)) - except Exception as exc: # TODO stop catching all Exceptions - if ( - isinstance(exc, OSError) - and exc.errno == errno.ENETUNREACH - and s.family == socket.AF_INET6 - ): - # with IPv6 we don't have a reliable way to determine if an interface actually has IPv6 - # support, so we have to try and ignore errors. - continue - # on send errors, log the exception and keep going - self.log_exception_warning() - else: - if bytes_sent != len(packet): - self.log_warning_once('!!! sent %d out of %d bytes to %r' % (bytes_sent, len(packet), s)) + if bytes_sent != len(packet): + self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) def close(self) -> None: """Ends the background threads, and prevent this instance from diff --git a/zeroconf/test.py b/zeroconf/test.py index 2bc642c8..3440e5e8 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -318,6 +318,13 @@ def test_exceedingly_long_name(self): generated.add_question(question) r.DNSIncoming(generated.packet()) + def test_extra_exceedingly_long_name(self): + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + name = "%slocal." % ("part." * 4000) + question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + r.DNSIncoming(generated.packet()) + def test_exceedingly_long_name_part(self): name = "%s.local." % ("a" * 1000) generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) @@ -355,12 +362,12 @@ def test_lots_of_names(self): def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): """Sends an outgoing packet.""" - packet = out.packet() - nonlocal longest_packet_len, longest_packet - if longest_packet_len < len(packet): - longest_packet_len = len(packet) - longest_packet = out - old_send(out, addr=addr, port=port) + for packet in out.packets(): + nonlocal longest_packet_len, longest_packet + if longest_packet_len < len(packet): + longest_packet_len = len(packet) + longest_packet = out + old_send(out, addr=addr, port=port) # monkey patch the zeroconf send setattr(zc, "send", send) @@ -374,6 +381,9 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # wait until the browse request packet has maxed out in size sleep_count = 0 + # we will never get to this large of a packet given the application-layer + # splitting of packets, but we still want to track the longest_packet_len + # for the debug message below while sleep_count < 100 and longest_packet_len < r._MAX_MSG_ABSOLUTE - 100: sleep_count += 1 time.sleep(0.1) @@ -386,8 +396,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) # now the browser has sent at least one request, verify the size - assert longest_packet_len <= r._MAX_MSG_ABSOLUTE - assert longest_packet_len >= r._MAX_MSG_ABSOLUTE - 100 + assert longest_packet_len <= r._MAX_MSG_TYPICAL + assert longest_packet_len >= r._MAX_MSG_TYPICAL - 100 # mock zeroconf's logger warning() and debug() from unittest.mock import patch @@ -407,13 +417,11 @@ def on_service_state_change(zeroconf, service_type, state_change, name): call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count # try to send an oversized packet zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] + 1 - assert mocked_log_debug.call_count == call_counts[0] + assert mocked_log_warn.call_count == call_counts[0] zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] + 1 - assert mocked_log_debug.call_count == call_counts[0] + 1 + assert mocked_log_warn.call_count == call_counts[0] - # force a receive of an oversized packet + # force a receive of a packet packet = out.packet() s = zc._respond_sockets[0] From 488ee1e85762dc5856d8e132da54762e5e712c5a Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 25 May 2020 23:43:06 +0200 Subject: [PATCH 035/608] Warn on every call to missing update_service() listener method This is in order to provide visibility to the library users that this method exists - without it the client code may be missing data. --- zeroconf/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9fa6ae5c..48760f67 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1521,6 +1521,12 @@ def on_change( elif state_change is ServiceStateChange.Updated: if hasattr(listener, 'update_service'): listener.update_service(*args) + else: + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) else: raise NotImplementedError(state_change) From 178cec75bd9a065b150b3542dfdb40682f6745b6 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 26 May 2020 23:29:51 +0200 Subject: [PATCH 036/608] Restore missing warnings import --- zeroconf/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 48760f67..6df3a806 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -34,6 +34,7 @@ import sys import threading import time +import warnings from collections import OrderedDict from typing import Dict, List, Optional, Sequence, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints From 781ac834da38708d95bfe6e5f5ec7dd0f31efc54 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 26 May 2020 23:55:08 +0200 Subject: [PATCH 037/608] Add --find option to example/browser.py (#263, rebased #175) Co-authored-by: Perry Kundert --- examples/browser.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/browser.py b/examples/browser.py index c4ddac39..624aab9f 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 -""" Example of browsing for a service (in this case, HTTP) """ +""" Example of browsing for a service. + +The default is HTTP and HAP; use --find to search for all available services in the network +""" import argparse import logging @@ -8,7 +11,7 @@ from time import sleep from typing import cast -from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf +from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf, ZeroconfServiceTypes def on_service_state_change( @@ -18,6 +21,7 @@ def on_service_state_change( if state_change is ServiceStateChange.Added: info = zeroconf.get_service_info(service_type, name) + print("Info from zeroconf.get_service_info: %r" % (info)) if info: addresses = ["%s:%d" % (socket.inet_ntoa(addr), cast(int, info.port)) for addr in info.addresses] print(" Addresses: %s" % ", ".join(addresses)) @@ -39,6 +43,7 @@ def on_service_state_change( parser = argparse.ArgumentParser() parser.add_argument('--debug', action='store_true') + parser.add_argument('--find', action='store_true', help='Browse all available services') version_group = parser.add_mutually_exclusive_group() version_group.add_argument('--v6', action='store_true') version_group.add_argument('--v6-only', action='store_true') @@ -54,10 +59,13 @@ def on_service_state_change( ip_version = IPVersion.V4Only zeroconf = Zeroconf(ip_version=ip_version) - print("\nBrowsing services, press Ctrl-C to exit...\n") - browser = ServiceBrowser( - zeroconf, ["_http._tcp.local.", "_hap._tcp.local."], handlers=[on_service_state_change] - ) + + services = ["_http._tcp.local.", "_hap._tcp.local."] + if args.find: + services = list(ZeroconfServiceTypes.find(zc=zeroconf)) + + print("\nBrowsing %d service(s), press Ctrl-C to exit...\n" % len(services)) + browser = ServiceBrowser(zeroconf, services, handlers=[on_service_state_change]) try: while True: From d881abaf591f260ad019f4ff86e7f70a6f018a64 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 27 May 2020 20:17:31 +0200 Subject: [PATCH 038/608] Remove no longer needed typing dependency We don't support Python older than 3.5. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 57b8b0be..09293206 100755 --- a/setup.py +++ b/setup.py @@ -43,5 +43,5 @@ 'Programming Language :: Python :: Implementation :: PyPy', ], keywords=['Bonjour', 'Avahi', 'Zeroconf', 'Multicast DNS', 'Service Discovery', 'mDNS'], - install_requires=['ifaddr', 'typing;python_version<"3.5"'], + install_requires=['ifaddr'], ) From 0502f1904b0a8b9134ea2a09333232b30b3b6897 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 27 May 2020 22:02:25 +0200 Subject: [PATCH 039/608] Release version 0.27.0 --- README.rst | 14 ++++++++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 84c45906..d46d64b0 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,20 @@ See examples directory for more. Changelog ========= +0.27.0 +------ + +* Large multi-resource responses are now split into separate packets which fixes a bad + mdns-repeater/ChromeCast Audio interaction ending with ChromeCast Audio crash (and possibly + some others) and improves RFC 6762 compliance, thanks to Greg Badros +* Added a warning presented when the listener passed to ServiceBrowser lacks update_service() + callback +* Added support for finding all services available in the browser example, thanks to Perry Kunder + +Backwards incompatible: + +* Removed previously deprecated ServiceInfo address constructor parameter and property + 0.26.3 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 6df3a806..f22a7749 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.26.3' +__version__ = '0.27.0' __license__ = 'LGPL' From 6f876a7f14f0b172860005b0d6d959d82f7c1bbf Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 28 May 2020 19:54:46 +0200 Subject: [PATCH 040/608] Remove old Python 2-specific code --- zeroconf/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f22a7749..a1af17b8 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2111,8 +2111,7 @@ def new_socket( else: try: s.setsockopt(socket.SOL_SOCKET, reuseport, 1) - except (OSError, socket.error) as err: - # OSError on python 3, socket.error on python 2 + except OSError as err: if not err.errno == errno.ENOPROTOOPT: raise From 8045191ae6300da47d38e5cd82957965139359d2 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 28 May 2020 19:56:28 +0200 Subject: [PATCH 041/608] Improve ImportError message (wrong supported Python version) --- zeroconf/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a1af17b8..a83e7635 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -58,11 +58,11 @@ "IPVersion", ] -if sys.version_info <= (3, 3): +if sys.version_info <= (3, 4): raise ImportError( ''' -Python version > 3.3 required for python-zeroconf. -If you need support for Python 2 or Python 3.3 please use version 19.1 +Python version > 3.4 required for python-zeroconf. +If you need support for Python 2 or Python 3.3-3.4 please use version 19.1 ''' ) From d6593af2a3811b262d70bbc75c2c91613de41b21 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 29 May 2020 01:00:47 +0200 Subject: [PATCH 042/608] Simplify DNSHinfo constructor, cpu and os are always text (#266) --- zeroconf/__init__.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a83e7635..5f632be1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -554,18 +554,10 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" - def __init__( - self, name: str, type_: int, class_: int, ttl: int, cpu: Union[bytes, str], os: Union[bytes, str] - ) -> None: + def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: DNSRecord.__init__(self, name, type_, class_, ttl) - try: - self.cpu = cast(bytes, cpu).decode('utf-8') - except AttributeError: - self.cpu = cast(str, cpu) - try: - self.os = cast(bytes, os).decode('utf-8') - except AttributeError: - self.os = cast(str, os) + self.cpu = cpu + self.os = os def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -807,7 +799,12 @@ def read_others(self) -> None: ) elif type_ == _TYPE_HINFO: rec = DNSHinfo( - domain, type_, class_, ttl, self.read_character_string(), self.read_character_string() + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), ) elif type_ == _TYPE_AAAA: rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) From beff99897f0a5ece17e224a7ea9b12ebd420044f Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 31 May 2020 14:49:45 +0200 Subject: [PATCH 043/608] Improve logging (mainly include sockets in some messages) (#271) --- zeroconf/__init__.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5f632be1..17e6ef2f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -353,7 +353,7 @@ class QuietLogger: _seen_logs = {} # type: Dict[str, Union[int, tuple]] @classmethod - def log_exception_warning(cls, logger_data: Optional[Tuple] = None) -> None: + def log_exception_warning(cls, *logger_data: Any) -> None: exc_info = sys.exc_info() exc_str = str(exc_info[1]) if exc_str not in cls._seen_logs: @@ -362,9 +362,7 @@ def log_exception_warning(cls, logger_data: Optional[Tuple] = None) -> None: logger = log.warning else: logger = log.debug - if logger_data is not None: - logger(*logger_data) - logger('Exception occurred:', exc_info=True) + logger(*(logger_data or ['Exception occurred']), exc_info=True) @classmethod def log_warning_once(cls, *args: Any) -> None: @@ -709,7 +707,7 @@ def __init__(self, data: bytes) -> None: self.valid = True except (IndexError, struct.error, IncomingDecodeError): - self.log_exception_warning(('Choked at offset %d while unpacking %r', self.offset, data)) + self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) def __repr__(self) -> str: return '' % ', '.join( @@ -1357,15 +1355,30 @@ def handle_read(self, socket_: socket.socket) -> None: try: data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) except Exception: - self.log_exception_warning() + self.log_exception_warning('Error reading from socket %d', socket_.fileno()) return self.data = data msg = DNSIncoming(data) if msg.valid: - log.debug('Received from %r:%r: %r (%d bytes) as [%r]', addr, port, msg, len(data), data) + log.debug( + 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', + addr, + port, + socket_.fileno(), + msg, + len(data), + data, + ) else: - log.debug('Received from %r:%r: (%d bytes) [%r]', addr, port, len(data), data) + log.debug( + 'Received from %r:%r (socket %d): (%d bytes) [%r]', + addr, + port, + socket_.fileno(), + len(data), + data, + ) if not msg.valid: pass @@ -2139,7 +2152,7 @@ def add_multicast_member( ) -> Optional[socket.socket]: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, int) - log.debug('Adding %r to multicast group', interface) + log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: iface_bin = struct.pack('@I', cast(int, interface)) @@ -2173,7 +2186,7 @@ def add_multicast_member( respond_socket = new_socket( ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p ) - log.debug('Configuring %s with multicast interface %s', respond_socket, interface) + log.debug('Configuring socket %d with multicast interface %s', respond_socket, interface) if is_v6: respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) else: @@ -2785,7 +2798,7 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P # IPV6 support, so we have to try and ignore errors. continue # on send errors, log the exception and keep going - self.log_exception_warning() + self.log_exception_warning('Error sending through socket %d', s.fileno()) else: if bytes_sent != len(packet): self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) From 10065b976247ae9247cddaff8f3e9d7b331e66d7 Mon Sep 17 00:00:00 2001 From: gjbadros Date: Tue, 2 Jun 2020 01:20:07 -0700 Subject: [PATCH 044/608] Fix false warning (#273) When there is nothing to write, we don't need to warn about not making progress. --- zeroconf/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 17e6ef2f..44831f93 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1184,7 +1184,9 @@ def packets(self) -> List[bytes]: authority_offset += authorities_written additional_offset += additionals_written log.debug("now offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset) - if answers_written == 0 and authorities_written == 0 and additional_offset == 0: + if (answers_written + authorities_written + additionals_written) == 0 and ( + len(self.answers) + len(self.authorities) + len(self.additionals) + ) > 0: log.warning("packets() made no progress adding records; returning") break self.state = self.State.finished From 0538abf135f5502d94dd883475bcb2781ce5ddd2 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 5 Jun 2020 11:09:58 +0200 Subject: [PATCH 045/608] Release version 0.27.1 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index d46d64b0..b2318d6a 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.27.1 +------ + +* Improved the logging situation (includes fixing a false-positive "packets() made no progress + adding records", thanks to Greg Badros) + 0.27.0 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 44831f93..b212d1b8 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -43,7 +43,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.27.0' +__version__ = '0.27.1' __license__ = 'LGPL' From c31ae7fd519df04f41939d3c60c2b88960737fd6 Mon Sep 17 00:00:00 2001 From: Sandy Patterson Date: Fri, 5 Jun 2020 16:41:53 -0400 Subject: [PATCH 046/608] Support Windows when using socket errno checks (#274) Windows reports errno.WSAEINVAL(10022) instead of errno.EINVAL(22). This issue is triggered when a device has two IP's assigned under windows. This fixes #189 --- zeroconf/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b212d1b8..25d81990 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2154,6 +2154,9 @@ def add_multicast_member( ) -> Optional[socket.socket]: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, int) + err_einval = {errno.EINVAL} + if sys.platform == 'win32': + err_einval |= {errno.WSAEINVAL} log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: @@ -2179,7 +2182,7 @@ def add_multicast_member( interface, ) return None - elif _errno == errno.EINVAL: + elif _errno in err_einval: log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) return None else: From 0a9aa8d31bffec5d7b7291b84fbc95222b10d189 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 26 May 2020 13:34:43 +0200 Subject: [PATCH 047/608] Add support for passing text addresses to ServiceInfo Not sure if parsed_addresses is the best way to name the parameter, but we already have a parsed_addresses property so for the sake of consistency let's stick to that. --- zeroconf/__init__.py | 18 +++++++++++++++--- zeroconf/test.py | 45 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 25d81990..ed5b98f0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -219,6 +219,12 @@ def _is_v6_address(addr: bytes) -> bool: return len(addr) == 16 +def _encode_address(address: str) -> bytes: + is_ipv6 = ':' in address + address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + return socket.inet_pton(address_family, address) + + def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -1696,8 +1702,8 @@ class ServiceInfo(RecordUpdateListener): * server: fully qualified name for service host (defaults to name) * host_ttl: ttl used for A/SRV records * other_ttl: ttl used for PTR/TXT records - * addresses: List of IP addresses as unsigned short (IPv4) or unsigned 128 bit number (IPv6), - network byte order + * addresses and parsed_addresses: List of IP addresses (either as bytes, network byte order, or in parsed + form as text; at most one of those parameters can be provided) """ @@ -1718,14 +1724,20 @@ def __init__( host_ttl: int = _DNS_HOST_TTL, other_ttl: int = _DNS_OTHER_TTL, *, - addresses: Optional[List[bytes]] = None + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None ) -> None: + # Accept both none, or one, but not both. + if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") if not type_.endswith(service_type_name(name, allow_underscores=True)): raise BadTypeInNameException self.type = type_ self.name = name if addresses is not None: self._addresses = addresses + elif parsed_addresses is not None: + self._addresses = [_encode_address(a) for a in parsed_addresses] else: self._addresses = [] # This results in an ugly error when registering, better check now diff --git a/zeroconf/test.py b/zeroconf/test.py index 3440e5e8..06db1ed6 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1488,19 +1488,44 @@ def test_multiple_addresses(): assert info.addresses == [address, address] + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_parsed], + ) + assert info.addresses == [address, address] + if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address_v6], - ) - assert info.addresses == [address] - assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] - assert info.addresses_by_version(r.IPVersion.V4Only) == [address] - assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] - assert info.parsed_addresses() == [address_parsed, address_v6_parsed] - assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] + infos = [ + ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address_v6], + ), + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_v6_parsed], + ), + ] + for info in infos: + assert info.addresses == [address] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.V4Only) == [address] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed] + assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] def test_ptr_optimization(): From 328abfc54138e68e36a9f5381650bd6997701e73 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 12 Jun 2020 22:43:48 +0200 Subject: [PATCH 048/608] Fix one log format string (we use a socket object here) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ed5b98f0..daf723ac 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2203,7 +2203,7 @@ def add_multicast_member( respond_socket = new_socket( ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p ) - log.debug('Configuring socket %d with multicast interface %s', respond_socket, interface) + log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) if is_v6: respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) else: From 3b6906ab94f8d9ebeb1c97b6026ab7f9be226eab Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 17 Jun 2020 01:23:48 +0200 Subject: [PATCH 049/608] Log listen and respond sockets just in case --- zeroconf/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index daf723ac..e2018f34 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2304,6 +2304,7 @@ def __init__( self._listen_socket, self._respond_sockets = create_sockets( interfaces, unicast, ip_version, apple_p2p=apple_p2p ) + log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) self.listeners = [] # type: List[RecordUpdateListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] From 023e72d821faed9513ee0ef3a22a00231d87389e Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 17 Jun 2020 01:37:11 +0200 Subject: [PATCH 050/608] Exclude a problematic pep8-naming version --- requirements-dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 127df74e..2d0490ae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,6 +5,7 @@ coverage flake8>=3.6.0 flake8-import-order ifaddr -pep8-naming!=0.6.0 +# 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152 +pep8-naming!=0.6.0,!=0.11.0 pytest pytest-cov From 64056ab4aa55eb11c185c9879462ba1f82c7e886 Mon Sep 17 00:00:00 2001 From: PhilippSelenium <31542906+PhilippSelenium@users.noreply.github.com> Date: Mon, 29 Jun 2020 15:25:05 +0200 Subject: [PATCH 051/608] Use Adapter.index from ifaddr. (#280) Co-authored-by: PhilippSelenium --- setup.py | 2 +- zeroconf/__init__.py | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 09293206..f90d3175 100755 --- a/setup.py +++ b/setup.py @@ -43,5 +43,5 @@ 'Programming Language :: Python :: Implementation :: PyPy', ], keywords=['Bonjour', 'Avahi', 'Zeroconf', 'Multicast DNS', 'Service Discovery', 'mDNS'], - install_requires=['ifaddr'], + install_requires=['ifaddr>=0.1.7'], ) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e2018f34..dc33082e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -25,7 +25,6 @@ import ipaddress import itertools import logging -import os import platform import re import select @@ -2032,17 +2031,12 @@ def get_all_addresses_v6() -> List[int]: def ip_to_index(adapters: List[Any], ip: str) -> int: - if os.name != 'posix': - # Adapter names that ifaddr reports are not compatible with what if_nametoindex expects on Windows. - # We need https://github.com/pydron/ifaddr/pull/21 but it seems stuck on review. - raise RuntimeError('Converting from IP addresses to indexes is not supported on non-POSIX systems') - ipaddr = ipaddress.ip_address(ip) for adapter in adapters: for adapter_ip in adapter.ips: # IPv6 addresses are represented as tuples if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: - return socket.if_nametoindex(adapter.name) + return adapter.index raise RuntimeError('No adapter found for IP address %s' % ip) @@ -2050,8 +2044,7 @@ def ip_to_index(adapters: List[Any], ip: str) -> int: def ip6_addresses_to_indexes(interfaces: List[Union[str, int]]) -> List[int]: """Convert IPv6 interface addresses to interface indexes. - IPv4 addresses are ignored. The conversion currently only works on POSIX - systems. + IPv4 addresses are ignored. :param interfaces: List of IP addresses and indexes. :returns: List of indexes. @@ -2271,7 +2264,6 @@ def __init__( (IPv4 and IPv6) and interface indexes (IPv6 only). IPv6 notes for non-POSIX systems: - * IPv6 addresses are not supported, use indexes instead. * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` on Python versions before 3.8. From 4381784150e07625b4acd2034b253bf2ed320c5f Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 12:25:42 +0200 Subject: [PATCH 052/608] Make Mypy happy (#281) Otherwise it'd complain: % make mypy mypy examples/*.py zeroconf/*.py zeroconf/__init__.py:2039: error: Returning Any from function declared to return "int" Found 1 error in 1 file (checked 6 source files) make: *** [mypy] Error 1 --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dc33082e..cef46ab1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2036,7 +2036,7 @@ def ip_to_index(adapters: List[Any], ip: str) -> int: for adapter_ip in adapter.ips: # IPv6 addresses are represented as tuples if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: - return adapter.index + return cast(int, adapter.index) raise RuntimeError('No adapter found for IP address %s' % ip) From a7f9823cbed254b506a09cc514d86d9f5dc61ad3 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 12:50:06 +0200 Subject: [PATCH 053/608] Stop using socket.if_nameindex (#282) This improves Windows compatibility --- zeroconf/__init__.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index cef46ab1..ed79350d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2018,16 +2018,7 @@ def get_all_addresses() -> List[str]: def get_all_addresses_v6() -> List[int]: # IPv6 multicast uses positive indexes for interfaces - try: - nameindex = socket.if_nameindex - except AttributeError: - # Requires Python 3.8 on Windows. Fall back to Default. - QuietLogger.log_warning_once( - 'if_nameindex is not available, falling back to using the default IPv6 interface' - ) - return [0] - - return [tpl[0] for tpl in nameindex()] + return [adapter.index for adapter in ifaddr.get_adapters()] def ip_to_index(adapters: List[Any], ip: str) -> int: From fc92b1e2635868792aa7ebe937a9cfef2e2f0418 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 13:11:44 +0200 Subject: [PATCH 054/608] Fix an OS X edge case (#270, #188) This contains two major changes: * Listen on data from respond_sockets in addition to listen_socket * Do not bind respond sockets to 0.0.0.0 or ::/0 The description of the original change by Emil: <<< Without either of these changes, I get no replies at all when browsing for services using the browser example. I'm on a corporate network, and when connecting to a different network it works without these changes, so maybe it's something about the network configuration in this particular network that breaks the previous behavior. Unfortunately, I have no idea how this affects other platforms, or what the changes really mean. However, it works for me and it seems reasonable to get replies back on the same socket where they are sent. >>> The tests pass and it's been confirmed to a reasonable degree that this doesn't break the previously working use cases. Additionally this removes a memory leak where data sent to some of the respond sockets would not be ever read from them (#171). Co-authored-by: Emil Styrke --- zeroconf/__init__.py | 88 ++++++++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 27 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ed79350d..3a1c5f66 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -189,7 +189,7 @@ class InterfaceChoice(enum.Enum): All = 2 -InterfacesType = Union[List[Union[str, int]], InterfaceChoice] +InterfacesType = Union[List[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice] @enum.unique @@ -2016,23 +2016,39 @@ def get_all_addresses() -> List[str]: return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4)) -def get_all_addresses_v6() -> List[int]: +def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: # IPv6 multicast uses positive indexes for interfaces - return [adapter.index for adapter in ifaddr.get_adapters()] + # TODO: What about multi-address interfaces? + return list( + set((addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6) + ) -def ip_to_index(adapters: List[Any], ip: str) -> int: +def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: ipaddr = ipaddress.ip_address(ip) for adapter in adapters: for adapter_ip in adapter.ips: # IPv6 addresses are represented as tuples if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: - return cast(int, adapter.index) + return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index)) raise RuntimeError('No adapter found for IP address %s' % ip) -def ip6_addresses_to_indexes(interfaces: List[Union[str, int]]) -> List[int]: +def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: + for adapter in adapters: + if adapter.index == index: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple): + return cast(Tuple[str, int, int], adapter_ip.ip) + + raise RuntimeError('No adapter found for index %s' % index) + + +def ip6_addresses_to_indexes( + interfaces: List[Union[str, int, Tuple[Tuple[str, int, int], int]]] +) -> List[Tuple[Tuple[str, int, int], int]]: """Convert IPv6 interface addresses to interface indexes. IPv4 addresses are ignored. @@ -2045,27 +2061,27 @@ def ip6_addresses_to_indexes(interfaces: List[Union[str, int]]) -> List[int]: for iface in interfaces: if isinstance(iface, int): - result.append(iface) + result.append((interface_index_to_ip6_address(adapters, iface), iface)) elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: - result.append(ip_to_index(adapters, iface)) + result.append(ip6_to_address_and_index(adapters, iface)) return result def normalize_interface_choice( choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> List[Union[str, int]]: +) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: """Convert the interfaces choice into internal representation. :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). :param ip_address: IP version to use (ignored if `choice` is a list). :returns: List of IP addresses (for IPv4) and indexes (for IPv6). """ - result = [] # type: List[Union[str, int]] + result = [] # type: List[Union[str, Tuple[Tuple[str, int, int], int]]] if choice is InterfaceChoice.Default: if ip_version != IPVersion.V4Only: # IPv6 multicast uses interface 0 to mean the default - result.append(0) + result.append((('', 0, 0), 0)) if ip_version != IPVersion.V6Only: result.append('0.0.0.0') elif choice is InterfaceChoice.All: @@ -2088,8 +2104,18 @@ def normalize_interface_choice( def new_socket( - port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, apple_p2p: bool = False + bind_addr: Union[Tuple[str], Tuple[str, int, int]], + port: int = _MDNS_PORT, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, ) -> socket.socket: + log.debug( + 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r', + port, + ip_version, + apple_p2p, + bind_addr, + ) if ip_version == IPVersion.V4Only: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) else: @@ -2141,22 +2167,25 @@ def new_socket( # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) - s.bind(('', port)) + s.bind((bind_addr[0], port, *bind_addr[1:])) + log.debug('Created socket %s', s) return s def add_multicast_member( - listen_socket: socket.socket, interface: Union[str, int], apple_p2p: bool = False + listen_socket: socket.socket, + interface: Union[str, Tuple[Tuple[str, int, int], int]], + apple_p2p: bool = False, ) -> Optional[socket.socket]: # This is based on assumptions in normalize_interface_choice - is_v6 = isinstance(interface, int) + is_v6 = isinstance(interface, tuple) err_einval = {errno.EINVAL} if sys.platform == 'win32': err_einval |= {errno.WSAEINVAL} log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: - iface_bin = struct.pack('@I', cast(int, interface)) + iface_bin = struct.pack('@I', cast(int, interface[1])) _value = _MDNS_ADDR6_BYTES + iface_bin listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) else: @@ -2185,7 +2214,9 @@ def add_multicast_member( raise respond_socket = new_socket( - ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p + ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), + apple_p2p=apple_p2p, + bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),), ) log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) if is_v6: @@ -2206,17 +2237,22 @@ def create_sockets( if unicast: listen_socket = None else: - listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p) + listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',)) - interfaces = normalize_interface_choice(interfaces, ip_version) + normalized_interfaces = normalize_interface_choice(interfaces, ip_version) respond_sockets = [] - for i in interfaces: + for i in normalized_interfaces: if not unicast: respond_socket = add_multicast_member(cast(socket.socket, listen_socket), i, apple_p2p=apple_p2p) else: - respond_socket = new_socket(port=0, ip_version=ip_version, apple_p2p=apple_p2p) + respond_socket = new_socket( + port=0, + ip_version=ip_version, + apple_p2p=apple_p2p, + bind_addr=i[0] if isinstance(i, tuple) else (i,), + ) if respond_socket is not None: respond_sockets.append(respond_socket) @@ -2307,9 +2343,8 @@ def __init__( self.listener = Listener(self) if not unicast: self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) - else: - for s in self._respond_sockets: - self.engine.add_reader(self.listener, s) + for s in self._respond_sockets: + self.engine.add_reader(self.listener, s) self.reaper = Reaper(self) self.debug = None # type: Optional[DNSOutgoing] @@ -2817,9 +2852,8 @@ def close(self) -> None: if not self.unicast: self.engine.del_reader(cast(socket.socket, self._listen_socket)) cast(socket.socket, self._listen_socket).close() - else: - for s in self._respond_sockets: - self.engine.del_reader(s) + for s in self._respond_sockets: + self.engine.del_reader(s) self.engine.join() # shutdown the rest From 02bcad902c516a5a2d2aa3302bca9871900da6e3 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 13:21:57 +0200 Subject: [PATCH 055/608] Advertise Python 3.8 compatibility --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index f90d3175..ac24ca7f 100755 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', ], From 0fdbf5e197a9f76e9e9c91a5e0908a0c66370dbd Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 13:21:44 +0200 Subject: [PATCH 056/608] Release version 0.28.0 --- README.rst | 14 ++++++++++++++ zeroconf/__init__.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b2318d6a..27925eb9 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,20 @@ See examples directory for more. Changelog ========= +0.28.0 +====== + +* Improved Windows support when using socket errno checks, thanks to Sandy Patterson. +* Added support for passing text addresses to ServiceInfo. +* Improved logging (includes fixing an incorrect logging call) +* Improved Windows compatibility by using Adapter.index from ifaddr, thanks to PhilippSelenium. +* Improved Windows compatibility by stopping using socket.if_nameindex. +* Fixed an OS X edge case which should also eliminate a memory leak, thanks to Emil Styrke. + +Technically backwards incompatible: + +* ``ifaddr`` 0.1.7 or newer is required now. + 0.27.1 ------ diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3a1c5f66..07e62845 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.27.1' +__version__ = '0.28.0' __license__ = 'LGPL' From 19e33a6829846008b50f408c77ac3e8e73176529 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 7 Jul 2020 13:25:26 +0200 Subject: [PATCH 057/608] Gitignore some build artifacts --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index eac2c170..0af9ce1e 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ Thumbs.db .mypy_cache/ docs/_build/ .vscode +/dist/ +/zeroconf.egg-info/ From c9f3c91da568fdbd26d571eed8a636a49e527b15 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 17 Aug 2020 15:21:04 -0500 Subject: [PATCH 058/608] Ensure all listeners are cleaned up on ServiceBrowser cancelation (#290) When creating listeners for a ServiceBrowser with multiple types they would not all be removed on cancelation. This led to a build up of stale listeners when ServiceBrowsers were frequently added and removed. --- zeroconf/__init__.py | 32 ++++++++++++++--------- zeroconf/test.py | 61 +++++++++++++++++++++++++++++--------------- 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 07e62845..f1f18741 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -174,6 +174,10 @@ _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') _HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') +_EXPIRE_FULL_TIME_PERCENT = 100 +_EXPIRE_STALE_TIME_PERCENT = 50 +_EXPIRE_REFRESH_TIME_PERCENT = 75 + try: _IPPROTO_IPV6 = socket.IPPROTO_IPV6 except AttributeError: @@ -459,8 +463,8 @@ def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) - DNSEntry.__init__(self, name, type_, class_) self.ttl = ttl self.created = current_time_millis() - self._expiration_time = self.get_expiration_time(100) - self._stale_time = self.get_expiration_time(50) + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) + self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) def __eq__(self, other: Any) -> bool: """Abstract method""" @@ -506,8 +510,8 @@ def reset_ttl(self, other: 'DNSRecord') -> None: another record.""" self.created = other.created self.ttl = other.ttl - self._expiration_time = self.get_expiration_time(100) - self._stale_time = self.get_expiration_time(50) + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) + self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) def write(self, out: 'DNSOutgoing') -> None: """Abstract method""" @@ -1609,7 +1613,7 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) return - expires = record.get_expiration_time(75) + expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) if expires < self._next_time[record.name]: self._next_time[record.name] = expires @@ -1649,8 +1653,8 @@ def cancel(self) -> None: self.join() def run(self) -> None: - for type_ in self.types: - self.zc.add_listener(self, DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) + questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] + self.zc.add_listener(self, questions) while True: now = current_time_millis() @@ -2595,16 +2599,20 @@ def check_service( i += 1 next_time += _CHECK_TIME - def add_listener(self, listener: RecordUpdateListener, question: Optional[DNSQuestion]) -> None: + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to - answer the question.""" + answer the question(s).""" now = current_time_millis() self.listeners.append(listener) if question is not None: - for record in self.cache.entries_with_name(question.name): - if question.answered_by(record) and not record.is_expired(now): - listener.update_record(self, now, record) + questions = [question] if isinstance(question, DNSQuestion) else question + for single_question in questions: + for record in self.cache.entries_with_name(single_question.name): + if single_question.answered_by(record) and not record.is_expired(now): + listener.update_record(self, now, record) self.notify_all() def remove_listener(self, listener: RecordUpdateListener) -> None: diff --git a/zeroconf/test.py b/zeroconf/test.py index 06db1ed6..96be62cf 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -24,6 +24,7 @@ ServiceStateChange, Zeroconf, ZeroconfServiceTypes, + _EXPIRE_REFRESH_TIME_PERCENT, ) log = logging.getLogger('zeroconf') @@ -1237,7 +1238,9 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi assert service_removed_count == 1 finally: + assert len(zeroconf.listeners) == 1 service_browser.cancel() + assert len(zeroconf.listeners) == 0 zeroconf.remove_all_service_listeners() zeroconf.close() @@ -1245,8 +1248,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi class TestServiceBrowserMultipleTypes(unittest.TestCase): def test_update_record(self): - service_names = ['name._type._tcp.local.', 'name._type._udp.local'] - service_types = ['_type._tcp.local.', '_type._udp.local.'] + service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] + service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] service_added_count = 0 service_removed_count = 0 @@ -1257,25 +1260,19 @@ class MyServiceListener(r.ServiceListener): def add_service(self, zc, type_, name) -> None: nonlocal service_added_count service_added_count += 1 - if service_added_count == 2: + if service_added_count == 3: service_add_event.set() def remove_service(self, zc, type_, name) -> None: nonlocal service_removed_count service_removed_count += 1 - if service_removed_count == 2: + if service_removed_count == 3: service_removed_event.set() def mock_incoming_msg( - service_state_change: r.ServiceStateChange, service_type: str, service_name: str + service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int ) -> r.DNSIncoming: generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - else: - ttl = 120 - generated.add_answer_at_time( r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 ) @@ -1287,30 +1284,54 @@ def mock_incoming_msg( try: wait_time = 3 - # both services added + # all three services added + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120) + ) zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0]) + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120) ) zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1]) + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) ) + + called_with_refresh_time_check = False + + def _mock_get_expiration_time(self, percent): + nonlocal called_with_refresh_time_check + if percent == _EXPIRE_REFRESH_TIME_PERCENT: + called_with_refresh_time_check = True + return 0 + return self.created + (percent * self.ttl * 10) + + # Set an expire time that will force a refresh + with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) + ) service_add_event.wait(wait_time) - assert service_added_count == 2 + assert called_with_refresh_time_check is True + assert service_added_count == 3 assert service_removed_count == 0 - # both services removed + # all three services removed + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0) + ) zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0]) + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0) ) zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1]) + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0) ) service_removed_event.wait(wait_time) - assert service_added_count == 2 - assert service_removed_count == 2 + assert service_added_count == 3 + assert service_removed_count == 3 finally: + assert len(zeroconf.listeners) == 1 service_browser.cancel() + assert len(zeroconf.listeners) == 0 zeroconf.remove_all_service_listeners() zeroconf.close() From 3c5d3856e286824611712de13aa0fcbe94e4313f Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 17 Aug 2020 22:23:35 +0200 Subject: [PATCH 059/608] Release version 0.28.1 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 27925eb9..10a0b435 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.1 +====== + +* Fixed a resource leak connected to using ServiceBrowser with multiple types, thanks to + J. Nick Koston. + 0.28.0 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f1f18741..1feeaf87 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.0' +__version__ = '0.28.1' __license__ = 'LGPL' From 0f7366423fab8369700be086f3007c20897fde1f Mon Sep 17 00:00:00 2001 From: Erik Date: Tue, 18 Aug 2020 16:07:03 +0200 Subject: [PATCH 060/608] Remove initial delay before querying for service info --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1feeaf87..895083a9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1894,7 +1894,7 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: """ now = current_time_millis() delay = _LISTENER_TIME - next_ = now + delay + next_ = now last = now + timeout record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)] From fca090db06a0d481ad7f608c4fde3e936ad2f80e Mon Sep 17 00:00:00 2001 From: Paul Daumlechner Date: Wed, 19 Aug 2020 10:33:41 +0200 Subject: [PATCH 061/608] Don't ask already answered questions (#292) Fixes GH-288. Co-authored-by: Erik --- zeroconf/__init__.py | 27 +++--- zeroconf/test.py | 225 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 895083a9..f49c263c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1916,19 +1916,24 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: return False if next_ <= now: out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN), now) - - out.add_question(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN), now) + cached_entry = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if not cached_entry: + out.add_question(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN)) + out.add_answer_at_time(cached_entry, now) + cached_entry = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if not cached_entry: + out.add_question(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN)) + out.add_answer_at_time(cached_entry, now) if self.server is not None: - out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN)) - out.add_answer_at_time(zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN), now) - out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN)) - out.add_answer_at_time( - zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN), now - ) + cached_entry = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) + if not cached_entry: + out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN)) + out.add_answer_at_time(cached_entry, now) + cached_entry = zc.cache.get_by_details(self.name, _TYPE_AAAA, _CLASS_IN) + if not cached_entry: + out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN)) + out.add_answer_at_time(cached_entry, now) zc.send(out) next_ = now + delay delay *= 2 diff --git a/zeroconf/test.py b/zeroconf/test.py index 96be62cf..6b7a31cd 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -9,6 +9,7 @@ import os import socket import struct +import threading import time import unittest from threading import Event @@ -1245,6 +1246,230 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi zeroconf.close() +class TestServiceInfo(unittest.TestCase): + def test_get_info_partial(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # monkey patch the zeroconf send + setattr(zc, "send", send) + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packet()) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for SRV, A, AAAA + last_sent = None + send_event.clear() + zc.handle_response( + mock_incoming_msg( + [r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text)] + ) + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 3 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for A, AAAA + last_sent = None + send_event.clear() + zc.handle_response( + mock_incoming_msg( + [ + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ) + ] + ) + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 2 + assert r.DNSQuestion(service_server, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + last_sent = None + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + zc.handle_response( + mock_incoming_msg( + [ + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ) + ] + ) + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + def test_get_info_single(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # monkey patch the zeroconf send + setattr(zc, "send", send) + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packet()) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + zc.handle_response( + mock_incoming_msg( + [ + r.DNSText( + service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text + ), + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + ] + ) + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + class TestServiceBrowserMultipleTypes(unittest.TestCase): def test_update_record(self): From 3be96b014d61c94d71ae3aa23ba223eead4f4cb7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 22 Aug 2020 08:33:25 -0500 Subject: [PATCH 062/608] Increase test coverage for dns cache --- zeroconf/test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index 6b7a31cd..1046bf77 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -862,6 +862,20 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): cache.remove(record2) assert 'a' not in cache.cache + def test_cache_empty_multiple_calls_does_not_throw(self): + record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + # Ensure multiple removes does not throw + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache + class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): From f64768a7253829f9d8f7796a6a5c8129b92f2aad Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 27 Aug 2020 00:22:14 +0200 Subject: [PATCH 063/608] Release version 0.28.2 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 10a0b435..704c39fe 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.2 +====== + +* Stopped asking questions we already have answers for in cache, thanks to Paul Daumlechner. +* Removed initial delay before querying for service info, thanks to Erik Montnemery. + 0.28.1 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f49c263c..78cafc3f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.1' +__version__ = '0.28.2' __license__ = 'LGPL' From 57d89d85e52dea1f8cb7f6d4b02c0281d5ba0540 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 27 Aug 2020 00:29:57 +0200 Subject: [PATCH 064/608] Reformat using the latest black (20.8b1) --- zeroconf/__init__.py | 4 ++-- zeroconf/test.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 78cafc3f..b9ffa968 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -938,7 +938,7 @@ def add_authorative_answer(self, record: DNSPointer) -> None: self.authorities.append(record) def add_additional_answer(self, record: DNSRecord) -> None: - """ Adds an additional answer + """Adds an additional answer From: RFC 6763, DNS-Based Service Discovery, February 2013 @@ -1136,7 +1136,7 @@ def packets(self) -> List[bytes]: or less in length, except for the case of a single answer which will be written out to a single oversized packet no more than _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP - fragmentation potentially). """ + fragmentation potentially).""" if self.state == self.State.finished: return self.packets_data diff --git a/zeroconf/test.py b/zeroconf/test.py index 1046bf77..33757cd1 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -110,7 +110,14 @@ def test_service_info_dunder(self): name = "xxxyyy" registration_name = "%s.%s" % (name, type_) info = ServiceInfo( - type_, registration_name, 80, 0, 0, b'', "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], + type_, + registration_name, + 80, + 0, + 0, + b'', + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], ) assert not info != info @@ -1765,7 +1772,14 @@ def test_multiple_addresses(): address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) infos = [ ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address_v6], + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[address, address_v6], ), ServiceInfo( type_, From 5a359bb0931fbda8444e30d07a50e59cf4ccca8e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 9 Aug 2020 15:44:54 +0000 Subject: [PATCH 065/608] Reduce the time window that the handlers lock is held Only hold the lock if we have an update. --- zeroconf/__init__.py | 79 ++++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 36 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b9ffa968..26770326 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2638,45 +2638,52 @@ def update_record(self, now: float, rec: DNSRecord) -> None: def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" + updates = [] # type: List[Tuple[float, DNSRecord, Optional[DNSRecord]]] + now = current_time_millis() + for record in msg.answers: - with self._handlers_lock: + updated = True - now = current_time_millis() - for record in msg.answers: - - updated = True - - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # Since the cache format is keyed on the lower case record name - # we can avoid iterating everything in the cache and - # only look though entries for the specific name. - # entries_with_name will take care of converting to lowercase - # - # We make a copy of the list that entries_with_name returns - # since we cannot iterate over something we might remove - for entry in self.cache.entries_with_name(record.name).copy(): - - if entry == record: - updated = False - - # Check the time first because it is far cheaper - # than the __eq__ - if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): - self.cache.remove(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) - else: - self.cache.add(record) - if updated: - self.update_record(now, record) + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + # Since the cache format is keyed on the lower case record name + # we can avoid iterating everything in the cache and + # only look though entries for the specific name. + # entries_with_name will take care of converting to lowercase + # + # We make a copy of the list that entries_with_name returns + # since we cannot iterate over something we might remove + for entry in self.cache.entries_with_name(record.name).copy(): + + if entry == record: + updated = False + + # Check the time first because it is far cheaper + # than the __eq__ + if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): + self.cache.remove(entry) + + expired = record.is_expired(now) + maybe_entry = self.cache.get(record) + if not expired: + if maybe_entry is not None: + maybe_entry.reset_ttl(record) else: - if maybe_entry is not None: - self.update_record(now, record) - self.cache.remove(maybe_entry) + self.cache.add(record) + if updated: + updates.append((now, record, None)) + elif maybe_entry is not None: + updates.append((now, record, maybe_entry)) + + if not updates: + return + + # Only hold the lock if we have updates + with self._handlers_lock: + for update in updates: + now, record, entry_to_remove = update + self.update_record(update[0], update[1]) + if entry_to_remove: + self.cache.remove(entry_to_remove) def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if From 0e49aeca6497ede18a3f0c71ea69f2343934ba19 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 31 Aug 2020 12:57:18 +0200 Subject: [PATCH 066/608] Release version 0.28.3 --- README.rst | 5 +++++ zeroconf/__init__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 704c39fe..82464cb3 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,11 @@ See examples directory for more. Changelog ========= +0.28.3 +====== + +* Reduced a time an internal lock is held which should eliminate deadlocks in high-traffic networks. + 0.28.2 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 26770326..279d85f6 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.2' +__version__ = '0.28.3' __license__ = 'LGPL' From 9e27d126d75c73466584c417ab35c1d6cf47ca8b Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 31 Aug 2020 12:58:57 +0200 Subject: [PATCH 067/608] Add an author in the last changelog entry --- README.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 82464cb3..da99ed00 100644 --- a/README.rst +++ b/README.rst @@ -137,7 +137,8 @@ Changelog 0.28.3 ====== -* Reduced a time an internal lock is held which should eliminate deadlocks in high-traffic networks. +* Reduced a time an internal lock is held which should eliminate deadlocks in high-traffic networks, + thanks to J. Nick Koston. 0.28.2 ====== From 1e4aaeaa10c306b9447dacefa03b89ce1e9d7493 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 3 Sep 2020 10:35:58 -0500 Subject: [PATCH 068/608] Avoid copying the entires cache and reduce frequency of Reaper The cache reaper was running at least every 10 seconds, making a copy of the cache, and iterated all the entries to check if they were expired so they could be removed. In practice the reaper was actually running much more frequently because it used self.zc.wait which would unblock any time a record was updated, a listener was added, or when a listener was removed. This change ensures the reaper frequency is only every 10s, and will first attempt to iterate the cache before falling back to making a copy. Previously it made sense to expire the cache more frequently because we had places were we frequently had to enumerate all the cache entries. With #247 and #232 we no longer have to account for this concern. On a mostly idle RPi running HomeAssistant and a busy network the total time spent reaping the cache was more than the total time spent processing the mDNS traffic. Top 10 functions, idle RPi (before) %Own %Total OwnTime TotalTime Function (filename:line) 0.00% 0.00% 2.69s 2.69s handle_read (zeroconf/__init__.py:1367) <== Incoming mDNS 0.00% 0.00% 1.51s 2.98s run (zeroconf/__init__.py:1431) <== Reaper 0.00% 0.00% 1.42s 1.42s is_expired (zeroconf/__init__.py:502) <== Reaper 0.00% 0.00% 1.12s 1.12s entries (zeroconf/__init__.py:1274) <== Reaper 0.00% 0.00% 0.620s 0.620s do_execute (sqlalchemy/engine/default.py:593) 0.00% 0.00% 0.620s 0.620s read_utf (zeroconf/__init__.py:837) 0.00% 0.00% 0.610s 0.610s do_commit (sqlalchemy/engine/default.py:546) 0.00% 0.00% 0.540s 1.16s read_name (zeroconf/__init__.py:853) 0.00% 0.00% 0.380s 0.380s do_close (sqlalchemy/engine/default.py:549) 0.00% 0.00% 0.340s 0.340s write (asyncio/selector_events.py:908) After this change, the Reaper code paths do not show up in the top 10 function sample. %Own %Total OwnTime TotalTime Function (filename:line) 4.00% 4.00% 2.72s 2.72s handle_read (zeroconf/__init__.py:1378) <== Incoming mDNS 4.00% 4.00% 1.81s 1.81s read_utf (zeroconf/__init__.py:837) 1.00% 5.00% 1.68s 3.51s read_name (zeroconf/__init__.py:853) 0.00% 0.00% 1.32s 1.32s do_execute (sqlalchemy/engine/default.py:593) 0.00% 0.00% 0.960s 0.960s readinto (socket.py:669) 0.00% 0.00% 0.950s 0.950s create_connection (urllib3/util/connection.py:74) 0.00% 0.00% 0.910s 0.910s do_commit (sqlalchemy/engine/default.py:546) 1.00% 1.00% 0.880s 0.880s write (asyncio/selector_events.py:908) 0.00% 0.00% 0.700s 0.810s __eq__ (zeroconf/__init__.py:606) 2.00% 2.00% 0.670s 0.670s unpack (zeroconf/__init__.py:737) --- zeroconf/__init__.py | 54 +++++++++++++++++++++++++++++++++++--------- zeroconf/test.py | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 279d85f6..0e8b2a9f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -35,7 +35,7 @@ import time import warnings from collections import OrderedDict -from typing import Dict, List, Optional, Sequence, Union, cast +from typing import Dict, Iterable, List, Optional, Sequence, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints import ifaddr @@ -1268,10 +1268,21 @@ def entries(self) -> List[DNSRecord]: """Returns a list of all entries""" if not self.cache: return [] - else: - # avoid size change during iteration by copying the cache - values = list(self.cache.values()) - return list(itertools.chain.from_iterable(values)) + + # avoid size change during iteration by copying the cache + return list(itertools.chain.from_iterable(list(self.cache.values()))) + + def iterable_entries(self) -> Iterable[DNSRecord]: + """Returns an iterable of all entries. + + This function is provided to avoid copying + the entries but is not threadsafe as the + contents of the cache can change during iteration. + + Callers should trap RuntimeError and fallback + to calling entries. + """ + return itertools.chain.from_iterable(self.cache.values()) class Engine(threading.Thread): @@ -1422,15 +1433,29 @@ def __init__(self, zc: 'Zeroconf') -> None: self.name = "zeroconf-Reaper_%s" % (getattr(self, 'native_id', self.ident),) def run(self) -> None: + """Perodic removal of expired entries from the cache.""" while True: - self.zc.wait(10 * 1000) + with self.zc.reaper_condition: + self.zc.reaper_condition.wait(10) + if self.zc.done: return - now = current_time_millis() - for record in self.zc.cache.entries(): - if record.is_expired(now): - self.zc.update_record(now, record) - self.zc.cache.remove(record) + try: + # We try to iterate the cache without copying the whole + # cache as this can be quite an expensive operation. + self._cleanup_cache(self.zc.cache.iterable_entries()) + except RuntimeError: + # If the cache changes during iteration, we fallback + # to making a copy before iteraiton. + self._cleanup_cache(self.zc.cache.entries()) + + def _cleanup_cache(self, entries: Iterable[DNSRecord]) -> None: + """Remove expired entries from the cache.""" + now = current_time_millis() + for record in entries: + if record.is_expired(now): + self.zc.update_record(now, record) + self.zc.cache.remove(record) class Signal: @@ -2342,6 +2367,7 @@ def __init__( self.cache = DNSCache() self.condition = threading.Condition() + self.reaper_condition = threading.Condition() # Ensure we create the lock before # we add the listener as we could get @@ -2373,6 +2399,11 @@ def notify_all(self) -> None: with self.condition: self.condition.notify_all() + def notify_reaper(self) -> None: + """Notifies reaper""" + with self.reaper_condition: + self.reaper_condition.notify_all() + def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, @@ -2878,6 +2909,7 @@ def close(self) -> None: # shutdown the rest self.notify_all() + self.notify_reaper() self.reaper.join() for s in self._respond_sockets: s.close() diff --git a/zeroconf/test.py b/zeroconf/test.py index 33757cd1..e4ee1bfa 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -884,6 +884,54 @@ def test_cache_empty_multiple_calls_does_not_throw(self): assert 'a' not in cache.cache +class TestReaper(unittest.TestCase): + def test_reaper(self): + zeroconf = Zeroconf(interfaces=['127.0.0.1']) + original_entries = zeroconf.cache.entries() + record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + zeroconf.cache.add(record_with_10s_ttl) + zeroconf.cache.add(record_with_1s_ttl) + entries_with_cache = zeroconf.cache.entries() + time.sleep(1.05) + zeroconf.notify_reaper() + time.sleep(0.05) + entries = zeroconf.cache.entries() + + try: + iterable_entries = list(zeroconf.cache.iterable_entries()) + finally: + zeroconf.close() + + assert entries != original_entries + assert entries_with_cache != original_entries + assert record_with_10s_ttl in entries + assert record_with_1s_ttl not in entries + assert record_with_10s_ttl in iterable_entries + assert record_with_1s_ttl not in iterable_entries + + def test_reaper_with_dict_change_during_iteration(self): + zeroconf = Zeroconf(interfaces=['127.0.0.1']) + original_entries = zeroconf.cache.entries() + record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + zeroconf.cache.add(record_with_10s_ttl) + zeroconf.cache.add(record_with_1s_ttl) + entries_with_cache = zeroconf.cache.entries() + with unittest.mock.patch("zeroconf.DNSCache.iterable_entries", side_effect=RuntimeError): + time.sleep(1.05) + zeroconf.notify_reaper() + time.sleep(0.05) + + entries = zeroconf.cache.entries() + zeroconf.close() + + assert entries != original_entries + assert entries_with_cache != original_entries + assert record_with_10s_ttl in entries + assert record_with_1s_ttl not in entries + + class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): From 0265a9d57630a4a19bcd3638a6bb3f4b18eba01b Mon Sep 17 00:00:00 2001 From: Justin Nesselrotte Date: Sun, 6 Sep 2020 14:05:18 -0600 Subject: [PATCH 069/608] Add ServiceListener to __all__ for Zeroconf module (#298) It's part of the public API. --- zeroconf/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0e8b2a9f..664aa554 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -51,6 +51,7 @@ "Zeroconf", "ServiceInfo", "ServiceBrowser", + "ServiceListener", "Error", "InterfaceChoice", "ServiceStateChange", From fb876d6013979cdaa8c0ddebe81e7520e9ee8cc9 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 6 Sep 2020 22:10:07 +0200 Subject: [PATCH 070/608] Release version 0.28.4 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index da99ed00..442ebf43 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.4 +====== + +* Improved cache reaper performance significantly, thanks to J. Nick Koston. +* Added ServiceListener to __all__ as it's part of the public API, thanks to Justin Nesselrotte. + 0.28.3 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 664aa554..3fa9b15a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.3' +__version__ = '0.28.4' __license__ = 'LGPL' From 1f81e0bcad1cae735ba532758d167368925c8ede Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 9 Sep 2020 18:02:54 +0200 Subject: [PATCH 071/608] Test with the development version of Python 3.9 (#300) There've been reports of test failures on Python 3.9, let's verify this. Allowing failures for now until it goes stable. --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index c5369538..6937b7cd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,12 @@ python: - "3.6" - "3.7" - "3.8" + - "3.9-dev" - "pypy3.5" - "pypy3" +matrix: + allow_failures: + - python: "3.9-dev" install: - pip install --upgrade -r requirements-dev.txt # mypy can't be installed on pypy From f3219326e65f4410d45ace05f88082354a2f7525 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 10 Sep 2020 03:36:09 -0500 Subject: [PATCH 072/608] Ignore duplicate messages (#299) When watching packet captures, I noticed that zeroconf was processing incoming data 3x on a my Home Assistant OS install because there are three interfaces. We can skip processing duplicate packets in order to reduce the overhead of decoding data we have already processed. Before Idle cpu ~8.3% recvfrom 4 times 267 recvfrom(7, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("192.168.210.102")}, [16]) = 71 267 recvfrom(7, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("172.30.32.1")}, [16]) = 71 267 recvfrom(8, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("192.168.210.102")}, [16]) = 71 267 recvfrom(8, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("172.30.32.1")}, [16]) = 71 sendto 8 times 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 267 sendto(8, "\0\0\204\0\0\0\0\1\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300K\0\1\200\1\0\0\0x\0\4\300\250\325\232", 335, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 335 After Idle cpu ~4.1% recvfrom 4 times (no change): 267 recvfrom(7, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("192.168.210.102")}, [16]) = 71 267 recvfrom(9, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("192.168.210.102")}, [16]) = 71 267 recvfrom(7, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("172.30.32.1")}, [16]) = 71 267 recvfrom(9, "\0\0\204\0\0\0\0\1\0\0\0\0\v_esphomelib\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\31\26masterbed_tvcabinet_32\300\f", 8966, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("172.30.32.1")}, [16]) = 71 sendto 2 times (reduced by 4x): 267 sendto(9, "\0\0\204\0\0\0\0\2\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\t_services\7_dns-sd\4_udp\300!\0\f\0\1\0\0\21\224\0\2\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300p\0\1\200\1\0\0\0x\0\4\300\250\325\232", 372, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 372 267 sendto(9, "\0\0\204\0\0\0\0\2\0\0\0\3\17_home-assistant\4_tcp\5local\0\0\f\0\1\0\0\21\224\0\7\4Home\300\f\t_services\7_dns-sd\4_udp\300!\0\f\0\1\0\0\21\224\0\2\300\f\3002\0!\200\1\0\0\0x\0)\0\0\0\0\37\273 66309dfc726446799c8a2c0f1cb0480f\300!\3002\0\20\200\1\0\0\21\224\0\305\22location_name=Home%uuid=66309dfc726446799c8a2c0f1cb0480f\24version=0.116.0.dev0\rexternal_url=(internal_url=http://192.168.213.154:8123$base_url=http://192.168.213.154:8123\32requires_api_password=True\300p\0\1\200\1\0\0\0x\0\4\300\250\325\232", 372, 0, {sa_family=AF_INET, sin_port=htons(5353), sin_addr=inet_addr("224.0.0.251")}, 16) = 372 With debug logging on for ~5 minutes bash-5.0# grep 'Received from' home-assistant.log |wc 11458 499196 19706165 bash-5.0# grep 'Ignoring' home-assistant.log |wc 9357 210562 9299687 --- zeroconf/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3fa9b15a..ad4855c1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1381,6 +1381,17 @@ def handle_read(self, socket_: socket.socket) -> None: self.log_exception_warning('Error reading from socket %d', socket_.fileno()) return + if self.data == data: + log.debug( + 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', + addr, + port, + socket_.fileno(), + len(data), + data, + ) + return + self.data = data msg = DNSIncoming(data) if msg.valid: From 2db7fff033937a929cdfee1fc7c93c594872799e Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 11 Sep 2020 03:10:50 +0200 Subject: [PATCH 073/608] Fix AttributeError: module 'unittest' has no attribute 'mock' (#302) We only had module-level unittest import before now, but code accessing mock through unittest.mock was working because we have a test-level import from unittest.mock which causes unittest to gain the mock attribute and if the test was run before other tests (those using unittest.mock.patch) all was good. If the test was not run before them, though, they'd fail. Closes GH-295. --- zeroconf/test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index e4ee1bfa..2b1fb908 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -12,6 +12,7 @@ import threading import time import unittest +import unittest.mock from threading import Event from typing import Dict, Optional # noqa # used in type hints from typing import cast From eda1b3dd17329c40a59b628b4bbca15c42af43b7 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Fri, 11 Sep 2020 03:14:09 +0200 Subject: [PATCH 074/608] Release version 0.28.5 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 442ebf43..f0c9da16 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.5 +====== + +* Enabled ignoring duplicated messages which decreases CPU usage, thanks to J. Nick Koston. +* Fixed spurious AttributeError: module 'unittest' has no attribute 'mock' in tests. + 0.28.4 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ad4855c1..d0a922d9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.4' +__version__ = '0.28.5' __license__ = 'LGPL' From 6ab0cd0a0446f158a1d8a64a3bc548cf9e103179 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 11 Oct 2020 06:43:48 -0500 Subject: [PATCH 075/608] Loosen validation to ensure get_service_info can handle production devices (#307) Validation of names was too strict and rejected devices that are otherwise functional. A partial list of devices that unexpectedly triggered a BadTypeInNameException: Bose Soundtouch Yeelights Rachio Sprinklers iDevices --- zeroconf/__init__.py | 90 ++++++++++++++++++++++++++++---------------- zeroconf/test.py | 56 ++++++++++++++++++++------- 2 files changed, 101 insertions(+), 45 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d0a922d9..d35df9d0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -179,6 +179,10 @@ _EXPIRE_STALE_TIME_PERCENT = 50 _EXPIRE_REFRESH_TIME_PERCENT = 75 +_LOCAL_TRAILER = '.local.' +_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' +_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' + try: _IPPROTO_IPV6 = socket.IPPROTO_IPV6 except AttributeError: @@ -229,7 +233,7 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) -def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: +def service_type_name(type_: str, *, allow_underscores: bool = False, strict: bool = True) -> str: """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -246,9 +250,11 @@ def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: This is true because we are implementing mDNS and since the 'm' means multi-cast, the 'local.' domain is mandatory. - 2) local is preceded with either '_udp.' or '_tcp.' + 2) local is preceded with either '_udp.' or '_tcp.' unless + strict is False - 3) service name precedes <_tcp|_udp> + 3) service name precedes <_tcp|_udp> unless + strict is False The rules for Service Names [RFC6335] state that they may be no more than fifteen characters long (not counting the mandatory underscore), @@ -269,45 +275,65 @@ def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: :param type_: Type, SubType or service name to validate :return: fully qualified service name (eg: _http._tcp.local.) """ - if not (type_.endswith('._tcp.local.') or type_.endswith('._udp.local.')): - raise BadTypeInNameException("Type '%s' must end with '._tcp.local.' or '._udp.local.'" % type_) - remaining = type_[: -len('._tcp.local.')].split('.') - name = remaining.pop() - if not name: - raise BadTypeInNameException("No Service name found") + if type_.endswith(_TCP_PROTOCOL_LOCAL_TRAILER) or type_.endswith(_NONTCP_PROTOCOL_LOCAL_TRAILER): + remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] + has_protocol = True + elif strict: + raise BadTypeInNameException( + "Type '%s' must end with '%s' or '%s'" + % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER) + ) + elif type_.endswith(_LOCAL_TRAILER): + remaining = type_[: -len(_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_LOCAL_TRAILER) + 1 :] + has_protocol = False + else: + raise BadTypeInNameException("Type '%s' must end with '%s'" % (type_, _LOCAL_TRAILER)) - if len(remaining) == 1 and len(remaining[0]) == 0: - raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) + if strict or has_protocol: + service_name = remaining.pop() + if not service_name: + raise BadTypeInNameException("No Service name found") - if name[0] != '_': - raise BadTypeInNameException("Service name (%s) must start with '_'" % name) + if len(remaining) == 1 and len(remaining[0]) == 0: + raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) - # remove leading underscore - name = name[1:] + if service_name[0] != '_': + raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name) - if len(name) > 15: - raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % name) + test_service_name = service_name[1:] - if '--' in name: - raise BadTypeInNameException("Service name (%s) must not contain '--'" % name) + if len(test_service_name) > 15: + raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) - if '-' in (name[0], name[-1]): - raise BadTypeInNameException("Service name (%s) may not start or end with '-'" % name) + if '--' in test_service_name: + raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name) - if not _HAS_A_TO_Z.search(name): - raise BadTypeInNameException("Service name (%s) must contain at least one letter (eg: 'A-Z')" % name) + if '-' in (test_service_name[0], test_service_name[-1]): + raise BadTypeInNameException( + "Service name (%s) may not start or end with '-'" % test_service_name + ) - allowed_characters_re = ( - _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores else _HAS_ONLY_A_TO_Z_NUM_HYPHEN - ) + if not _HAS_A_TO_Z.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name + ) - if not allowed_characters_re.search(name): - raise BadTypeInNameException( - "Service name (%s) must contain only these characters: " - "A-Z, a-z, 0-9, hyphen ('-')%s" % (name, ", underscore ('_')" if allow_underscores else "") + allowed_characters_re = ( + _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores else _HAS_ONLY_A_TO_Z_NUM_HYPHEN ) + if not allowed_characters_re.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain only these characters: " + "A-Z, a-z, 0-9, hyphen ('-')%s" + % (test_service_name, ", underscore ('_')" if allow_underscores else "") + ) + else: + service_name = '' + if remaining and remaining[-1] == '_sub': remaining.pop() if len(remaining) == 0 or len(remaining[0]) == 0: @@ -326,7 +352,7 @@ def service_type_name(type_: str, *, allow_underscores: bool = False) -> str: "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] ) - return '_' + name + type_[-len('._tcp.local.') :] + return service_name + trailer # Exceptions @@ -1770,7 +1796,7 @@ def __init__( # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: raise TypeError("addresses and parsed_addresses cannot be provided together") - if not type_.endswith(service_type_name(name, allow_underscores=True)): + if not type_.endswith(service_type_name(name, strict=False, allow_underscores=True)): raise BadTypeInNameException self.type = type_ self.name = name diff --git a/zeroconf/test.py b/zeroconf/test.py index 2b1fb908..596f666a 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -657,14 +657,43 @@ def test_bad_service_names(self): for name in bad_names_to_try: self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name) + def test_bad_local_names_for_get_service_info(self): + bad_names_to_try = ( + 'homekitdev._nothttp._tcp.local.', + 'homekitdev._http._udp.local.', + ) + for name in bad_names_to_try: + self.assertRaises( + r.BadTypeInNameException, self.browser.get_service_info, '_http._tcp.local.', name + ) + def test_good_instance_names(self): + assert r.service_type_name('.._x._tcp.local.') == '_x._tcp.local.' + assert r.service_type_name('x.sub._http._tcp.local.') == '_http._tcp.local.' + assert ( + r.service_type_name('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') + == '_http._tcp.local.' + ) + + def test_good_instance_names_without_protocol(self): good_names_to_try = ( - '.._x._tcp.local.', - 'x.sub._http._tcp.local.', - '6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.', + "Rachio-C73233.local.", + 'YeelightColorBulb-3AFD.local.', + 'YeelightTunableBulb-7220.local.', + "AlexanderHomeAssistant 74651D.local.", + 'iSmartGate-152.local.', + 'MyQ-FGA.local.', + 'lutron-02c4392a.local.', + 'WICED-hap-3E2734.local.', + 'MyHost.local.', + 'MyHost.sub.local.', ) for name in good_names_to_try: - r.service_type_name(name) + assert r.service_type_name(name, strict=False) == 'local.' + + for name in good_names_to_try: + # Raises without strict=False + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) def test_bad_types(self): bad_names_to_try = ( @@ -687,17 +716,18 @@ def test_bad_sub_types(self): def test_good_service_names(self): good_names_to_try = ( - '_x._tcp.local.', - '_x._udp.local.', - '_12345-67890-abc._udp.local.', - 'x._sub._http._tcp.local.', - 'a' * 63 + '._sub._http._tcp.local.', - 'a' * 61 + u'â._sub._http._tcp.local.', + ('_x._tcp.local.', '_x._tcp.local.'), + ('_x._udp.local.', '_x._udp.local.'), + ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'), + ('x._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 61 + u'â._sub._http._tcp.local.', '_http._tcp.local.'), ) - for name in good_names_to_try: - r.service_type_name(name) - r.service_type_name('_one_two._tcp.local.', allow_underscores=True) + for name, result in good_names_to_try: + assert r.service_type_name(name) == result + + assert r.service_type_name('_one_two._tcp.local.', allow_underscores=True) == '_one_two._tcp.local.' def test_invalid_addresses(self): type_ = "_test-srvc-type._tcp.local." From 6a0c5dd4e84c30264747847e8f1045ece2a14288 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 13 Oct 2020 20:06:40 +0200 Subject: [PATCH 076/608] Merge strict and allow_underscores (#309) Those really serve the same purpose -- are we receiving data (and want to be flexible) or registering services (and want to be strict). --- zeroconf/__init__.py | 11 +++++------ zeroconf/test.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d35df9d0..30b26a2c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -233,7 +233,7 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) -def service_type_name(type_: str, *, allow_underscores: bool = False, strict: bool = True) -> str: +def service_type_name(type_: str, *, strict: bool = True) -> str: """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -322,14 +322,13 @@ def service_type_name(type_: str, *, allow_underscores: bool = False, strict: bo ) allowed_characters_re = ( - _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE if allow_underscores else _HAS_ONLY_A_TO_Z_NUM_HYPHEN + _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE ) if not allowed_characters_re.search(test_service_name): raise BadTypeInNameException( "Service name (%s) must contain only these characters: " - "A-Z, a-z, 0-9, hyphen ('-')%s" - % (test_service_name, ", underscore ('_')" if allow_underscores else "") + "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')") ) else: service_name = '' @@ -1564,7 +1563,7 @@ def __init__( assert handlers or listener, 'You need to specify at least one handler' self.types = set(type_ if isinstance(type_, list) else [type_]) for check_type_ in self.types: - if not check_type_.endswith(service_type_name(check_type_, allow_underscores=True)): + if not check_type_.endswith(service_type_name(check_type_, strict=False)): raise BadTypeInNameException threading.Thread.__init__(self) self.daemon = True @@ -1796,7 +1795,7 @@ def __init__( # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: raise TypeError("addresses and parsed_addresses cannot be provided together") - if not type_.endswith(service_type_name(name, strict=False, allow_underscores=True)): + if not type_.endswith(service_type_name(name, strict=False)): raise BadTypeInNameException self.type = type_ self.name = name diff --git a/zeroconf/test.py b/zeroconf/test.py index 596f666a..9d1544e7 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -727,7 +727,7 @@ def test_good_service_names(self): for name, result in good_names_to_try: assert r.service_type_name(name) == result - assert r.service_type_name('_one_two._tcp.local.', allow_underscores=True) == '_one_two._tcp.local.' + assert r.service_type_name('_one_two._tcp.local.', strict=False) == '_one_two._tcp.local.' def test_invalid_addresses(self): type_ = "_test-srvc-type._tcp.local." From 474442750d5d529436a118fda98a0b5f4680dc4d Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Tue, 13 Oct 2020 20:09:25 +0200 Subject: [PATCH 077/608] Release version 0.28.6 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f0c9da16..e1581b41 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.6 +====== + +* Loosened service name validation when receiving from the network this lets us handle + some real world devices previously causing errors, thanks to J. Nick Koston. + 0.28.5 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 30b26a2c..40bacfa0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.5' +__version__ = '0.28.6' __license__ = 'LGPL' From 4da1612b728acbcf2ab0c4bee09891c46f387bfb Mon Sep 17 00:00:00 2001 From: Alexey Vazhnov Date: Mon, 19 Oct 2020 22:42:10 +0000 Subject: [PATCH 078/608] Restore IPv6 addresses output Before this change, script `examples/browser.py` printed IPv4 only, even with `--v6` argument. With this change, `examples/browser.py` prints both IPv4 + IPv6 by default, and IPv6 only with `--v6-only` argument. I took the idea from the fork https://github.com/ad3angel1s/python-zeroconf/blob/master/examples/browser.py --- examples/browser.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/browser.py b/examples/browser.py index 624aab9f..2f264439 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -7,7 +7,6 @@ import argparse import logging -import socket from time import sleep from typing import cast @@ -23,7 +22,7 @@ def on_service_state_change( info = zeroconf.get_service_info(service_type, name) print("Info from zeroconf.get_service_info: %r" % (info)) if info: - addresses = ["%s:%d" % (socket.inet_ntoa(addr), cast(int, info.port)) for addr in info.addresses] + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) print(" Server: %s" % (info.server,)) From 41368588e5fcc6ec9596f306e39e2eaac2a9ec18 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 25 Oct 2020 11:44:05 -1000 Subject: [PATCH 079/608] Prevent crash when a service is added or removed during handle_response Services are now modified under a lock. The service processing is now done in a try block to ensure RuntimeError is caught which prevents the zeroconf engine from unexpectedly terminating. --- zeroconf/__init__.py | 144 +++++++++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 66 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 40bacfa0..d9faa960 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2410,6 +2410,7 @@ def __init__( # we add the listener as we could get # a message before the lock is created. self._handlers_lock = threading.Lock() # ensure we process a full message in one go + self._service_lock = threading.Lock() # add and remove services thread safe self.engine = Engine(self) self.listener = Listener(self) @@ -2487,11 +2488,12 @@ def register_service( info.host_ttl = ttl info.other_ttl = ttl self.check_service(info, allow_name_change, cooperating_responders) - self.services[info.name.lower()] = info - if info.type in self.servicetypes: - self.servicetypes[info.type] += 1 - else: - self.servicetypes[info.type] = 1 + with self._service_lock: + self.services[info.name.lower()] = info + if info.type in self.servicetypes: + self.servicetypes[info.type] += 1 + else: + self.servicetypes[info.type] = 1 self._broadcast_service(info) @@ -2500,9 +2502,10 @@ def update_service(self, info: ServiceInfo) -> None: Zeroconf will then respond to requests for information for that service.""" - assert self.services[info.name.lower()] is not None + with self._service_lock: + assert self.services[info.name.lower()] is not None - self.services[info.name.lower()] = info + self.services[info.name.lower()] = info self._broadcast_service(info) @@ -2546,14 +2549,12 @@ def _broadcast_service(self, info: ServiceInfo) -> None: def unregister_service(self, info: ServiceInfo) -> None: """Unregister a service.""" - try: + with self._service_lock: del self.services[info.name.lower()] if self.servicetypes[info.type] > 1: self.servicetypes[info.type] -= 1 else: del self.servicetypes[info.type] - except Exception as e: # TODO stop catching all Exceptions - log.exception('Unknown error, possibly benign: %r', e) now = current_time_millis() next_time = now i = 0 @@ -2600,7 +2601,7 @@ def unregister_all_services(self) -> None: now = current_time_millis() continue out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - for info in self.services.values(): + for info in list(self.services.values()): out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0) out.add_answer_at_time( DNSService( @@ -2766,69 +2767,77 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None out.add_question(question) for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name == "_services._dns-sd._udp.local.": - for stype in self.servicetypes.keys(): - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer( - "_services._dns-sd._udp.local.", _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype - ), - ) - for service in self.services.values(): - if question.name == service.type: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name), - ) - - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer( - DNSService( - service.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, + try: + if question.type == _TYPE_PTR: + if question.name == "_services._dns-sd._udp.local.": + for stype in self.servicetypes.keys(): + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + out.add_answer( + msg, + DNSPointer( + "_services._dns-sd._udp.local.", + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, + ), ) - ) - out.add_additional_answer( - DNSText( - service.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, + for service in self.services.values(): + if question.name == service.type: + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + out.add_answer( + msg, + DNSPointer( + service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name + ), ) - ) - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. out.add_additional_answer( - DNSAddress( - service.server, - type_, + DNSService( + service.name, + _TYPE_SRV, _CLASS_IN | _CLASS_UNIQUE, service.host_ttl, - address, + service.priority, + service.weight, + cast(int, service.port), + service.server, ) ) - else: - try: + out.add_additional_answer( + DNSText( + service.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + service.other_ttl, + service.text, + ) + ) + for address in service.addresses_by_version(IPVersion.All): + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_additional_answer( + DNSAddress( + service.server, + type_, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + address, + ) + ) + else: if out is None: out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know if question.type in (_TYPE_A, _TYPE_ANY): for service in self.services.values(): - if service.server == question.name.lower(): + if service.server == name_to_find: for address in service.addresses_by_version(IPVersion.All): type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A out.add_answer( @@ -2842,10 +2851,9 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None ), ) - name_to_find = question.name.lower() - if name_to_find not in self.services: + service = self.services.get(name_to_find) # type: ignore + if service is None: continue - service = self.services[name_to_find] if question.type in (_TYPE_SRV, _TYPE_ANY): out.add_answer( @@ -2884,8 +2892,12 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None address, ) ) - except Exception: # TODO stop catching all Exceptions - self.log_exception_warning() + except Exception: # TODO stop catching all Exceptions + # RuntimeError is expected because the service + # could be added/removed while iterating services + # and we cannot lock here because it would be too + # expensive. + self.log_exception_warning() if out is not None and out.answers: out.id = msg.id From 2708fef6052f7e6e6eb36a157438b316e6d38b21 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 31 Oct 2020 03:36:22 -1000 Subject: [PATCH 080/608] Refactor to move service registration into a registry This permits removing the broad exception catch that was expanded to avoid a crash in when adding or removing a service --- zeroconf/__init__.py | 471 ++++++++++++++++++++++--------------------- zeroconf/test.py | 37 ++++ 2 files changed, 283 insertions(+), 225 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d9faa960..9fd70b5c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -381,6 +381,10 @@ class BadTypeInNameException(Error): pass +class ServiceNameAlreadyRegistered(Error): + pass + + # implementation classes @@ -2341,6 +2345,96 @@ def can_send_to(sock: socket.socket, address: str) -> bool: return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) +class ServiceRegistry: + """A registry to keep track of services. + + This class exists to ensure services can + be safely added and removed with thread + safety. + """ + + def __init__( + self, + ) -> None: + """Create the ServiceRegistry class.""" + self.services = {} # type: Dict[str, ServiceInfo] + self.types = {} # type: Dict[str, List] + self.servers = {} # type: Dict[str, List] + self._lock = threading.Lock() # add and remove services thread safe + + def add(self, info: ServiceInfo) -> None: + """Add a new service to the registry.""" + + with self._lock: + self._add(info) + + def remove(self, info: ServiceInfo) -> None: + """Remove a new service from the registry.""" + + with self._lock: + self._remove(info) + + def update(self, info: ServiceInfo) -> None: + """Update new service in the registry.""" + + with self._lock: + self._remove(info) + self._add(info) + + def get_service_infos(self) -> List[ServiceInfo]: + """Return all ServiceInfo.""" + return list(self.services.values()) + + def get_info_name(self, name: str) -> Optional[ServiceInfo]: + """Return all ServiceInfo for the name.""" + return self.services.get(name) + + def get_types(self) -> List[str]: + """Return all types.""" + return list(self.types.keys()) + + def get_infos_type(self, type_: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching type.""" + return self._get_by_index("types", type_) + + def get_infos_server(self, server: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching server.""" + return self._get_by_index("servers", server) + + def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching the index.""" + service_infos = [] + + for name in getattr(self, attr).get(key, [])[:]: + info = self.services.get(name) + # Since we do not get under a lock since it would be + # a performance issue, its possible + # the service can be unregistered during the get + # so we must check if info is None + if info is not None: + service_infos.append(info) + + return service_infos + + def _add(self, info: ServiceInfo) -> None: + """Add a new service under the lock.""" + lower_name = info.name.lower() + if lower_name in self.services: + raise ServiceNameAlreadyRegistered + + self.services[lower_name] = info + self.types.setdefault(info.type, []).append(lower_name) + self.servers.setdefault(info.server, []).append(lower_name) + + def _remove(self, info: ServiceInfo) -> None: + """Remove a service under the lock.""" + lower_name = info.name.lower() + old_service_info = self.services[lower_name] + self.types[old_service_info.type].remove(lower_name) + self.servers[old_service_info.server].remove(lower_name) + del self.services[lower_name] + + class Zeroconf(QuietLogger): """Implementation of Zeroconf Multicast DNS Service Discovery @@ -2398,8 +2492,7 @@ def __init__( self.listeners = [] # type: List[RecordUpdateListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] - self.services = {} # type: Dict[str, ServiceInfo] - self.servicetypes = {} # type: Dict[str, int] + self.registry = ServiceRegistry() self.cache = DNSCache() @@ -2410,7 +2503,6 @@ def __init__( # we add the listener as we could get # a message before the lock is created. self._handlers_lock = threading.Lock() # ensure we process a full message in one go - self._service_lock = threading.Lock() # add and remove services thread safe self.engine = Engine(self) self.listener = Listener(self) @@ -2488,29 +2580,18 @@ def register_service( info.host_ttl = ttl info.other_ttl = ttl self.check_service(info, allow_name_change, cooperating_responders) - with self._service_lock: - self.services[info.name.lower()] = info - if info.type in self.servicetypes: - self.servicetypes[info.type] += 1 - else: - self.servicetypes[info.type] = 1 - - self._broadcast_service(info) + self.registry.add(info) + self._broadcast_service(info, _REGISTER_TIME, None) def update_service(self, info: ServiceInfo) -> None: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service.""" - with self._service_lock: - assert self.services[info.name.lower()] is not None - - self.services[info.name.lower()] = info - - self._broadcast_service(info) - - def _broadcast_service(self, info: ServiceInfo) -> None: + self.registry.update(info) + self._broadcast_service(info, _REGISTER_TIME, None) + def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: now = current_time_millis() next_time = now i = 0 @@ -2519,42 +2600,51 @@ def _broadcast_service(self, info: ServiceInfo) -> None: self.wait(next_time - now) now = current_time_millis() continue + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - info.host_ttl, - info.priority, - info.weight, - cast(int, info.port), - info.server, - ), - 0, - ) + self._add_broadcast_answer(out, info, ttl) + self.send(out) + i += 1 + next_time += interval + + def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]) -> None: + """Add answers to broadcast a service.""" + other_ttl = info.other_ttl if override_ttl is None else override_ttl + host_ttl = info.host_ttl if override_ttl is None else override_ttl + out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, other_ttl, info.name), 0) + out.add_answer_at_time( + DNSService( + info.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + host_ttl, + info.priority, + info.weight, + cast(int, info.port), + info.server, + ), + 0, + ) + out.add_answer_at_time( + DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, other_ttl, info.text), 0 + ) + for address in info.addresses_by_version(IPVersion.All): + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A out.add_answer_at_time( - DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, info.other_ttl, info.text), 0 + DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, host_ttl, address), 0 ) - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time( - DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, info.host_ttl, address), 0 - ) - self.send(out) - i += 1 - next_time += _REGISTER_TIME def unregister_service(self, info: ServiceInfo) -> None: """Unregister a service.""" - with self._service_lock: - del self.services[info.name.lower()] - if self.servicetypes[info.type] > 1: - self.servicetypes[info.type] -= 1 - else: - del self.servicetypes[info.type] + self.registry.remove(info) + self._broadcast_service(info, _UNREGISTER_TIME, 0) + + def unregister_all_services(self) -> None: + """Unregister all registered services.""" + service_infos = self.registry.get_service_infos() + if not service_infos: + return now = current_time_millis() next_time = now i = 0 @@ -2564,70 +2654,12 @@ def unregister_service(self, info: ServiceInfo) -> None: now = current_time_millis() continue out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - 0, - info.priority, - info.weight, - cast(int, info.port), - info.name, - ), - 0, - ) - out.add_answer_at_time(DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, 0, info.text), 0) - - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time( - DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, 0, address), 0 - ) + for info in service_infos: + self._add_broadcast_answer(out, info, 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME - def unregister_all_services(self) -> None: - """Unregister all registered services.""" - if len(self.services) > 0: - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - for info in list(self.services.values()): - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, 0, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - 0, - info.priority, - info.weight, - cast(int, info.port), - info.server, - ), - 0, - ) - out.add_answer_at_time( - DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, 0, info.text), 0 - ) - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time( - DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, 0, address), 0 - ) - self.send(out) - i += 1 - next_time += _UNREGISTER_TIME - def check_service( self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False ) -> None: @@ -2767,137 +2799,126 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None out.add_question(question) for question in msg.questions: - try: - if question.type == _TYPE_PTR: - if question.name == "_services._dns-sd._udp.local.": - for stype in self.servicetypes.keys(): - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer( - "_services._dns-sd._udp.local.", - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - for service in self.services.values(): - if question.name == service.type: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer( - service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name - ), - ) - - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer( - DNSService( - service.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, - ) - ) - out.add_additional_answer( - DNSText( - service.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ) - ) - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( - DNSAddress( - service.server, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ) - ) - else: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - name_to_find = question.name.lower() - - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.services.values(): - if service.server == name_to_find: - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer( - msg, - DNSAddress( - question.name, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ), - ) - - service = self.services.get(name_to_find) # type: ignore - if service is None: - continue - - if question.type in (_TYPE_SRV, _TYPE_ANY): + if question.type == _TYPE_PTR: + if question.name == "_services._dns-sd._udp.local.": + for stype in self.registry.get_types(): + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) out.add_answer( msg, - DNSService( - question.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, + DNSPointer( + "_services._dns-sd._udp.local.", + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, ), ) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer( - msg, - DNSText( - question.name, - _TYPE_TXT, + for service in self.registry.get_infos_type(question.name): + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + out.add_answer( + msg, + DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name), + ) + + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + out.add_additional_answer( + DNSService( + service.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + service.priority, + service.weight, + cast(int, service.port), + service.server, + ) + ) + out.add_additional_answer( + DNSText( + service.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + service.other_ttl, + service.text, + ) + ) + for address in service.addresses_by_version(IPVersion.All): + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_additional_answer( + DNSAddress( + service.server, + type_, _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ), + service.host_ttl, + address, + ) ) - if question.type == _TYPE_SRV: + else: + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + + name_to_find = question.name.lower() + + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): for address in service.addresses_by_version(IPVersion.All): type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( + out.add_answer( + msg, DNSAddress( - service.server, + question.name, type_, _CLASS_IN | _CLASS_UNIQUE, service.host_ttl, address, - ) + ), ) - except Exception: # TODO stop catching all Exceptions - # RuntimeError is expected because the service - # could be added/removed while iterating services - # and we cannot lock here because it would be too - # expensive. - self.log_exception_warning() + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + continue + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer( + msg, + DNSService( + question.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + service.priority, + service.weight, + cast(int, service.port), + service.server, + ), + ) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer( + msg, + DNSText( + question.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + service.other_ttl, + service.text, + ), + ) + if question.type == _TYPE_SRV: + for address in service.addresses_by_version(IPVersion.All): + type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A + out.add_additional_answer( + DNSAddress( + service.server, + type_, + _CLASS_IN | _CLASS_UNIQUE, + service.host_ttl, + address, + ) + ) if out is not None and out.answers: out.id = msg.id diff --git a/zeroconf/test.py b/zeroconf/test.py index 9d1544e7..cfbd2a79 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -878,6 +878,43 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): nbr_answers = nbr_additionals = nbr_authorities = 0 +class TestServiceRegistry(unittest.TestCase): + def test_only_register_once(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) + registry.remove(info) + registry.add(info) + + def test_lookups(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + + assert registry.get_service_infos() == [info] + assert registry.get_info_name(registration_name) == info + assert registry.get_infos_type(type_) == [info] + assert registry.get_infos_server("ash-2.local.") == [info] + assert registry.get_types() == [type_] + + class TestDNSCache(unittest.TestCase): def test_order(self): record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') From 8f7effd2f89c542162d0e5ac257c561501690d16 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Sun, 13 Dec 2020 02:08:20 +0100 Subject: [PATCH 081/608] Release version 0.28.7 --- README.rst | 7 +++++++ zeroconf/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index e1581b41..0bf8fcdc 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,13 @@ See examples directory for more. Changelog ========= +0.28.7 +====== + +* Fixed the IPv6 address rendering in the browser example, thanks to Alexey Vazhnov. +* Fixed a crash happening when a service is added or removed during handle_response + and improved exception handling, thanks to J. Nick Koston. + 0.28.6 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9fd70b5c..50359458 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.6' +__version__ = '0.28.7' __license__ = 'LGPL' From 86b4e11434d44e2f9a42354109a10f601c44d66a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Dec 2020 13:50:16 -1000 Subject: [PATCH 082/608] Ensure the name cache is rolled back when the packet reaches maximum size If the packet was too large, it would be rolled back at the end of write_record. We need to remove the names that were added to the name cache (self.names) as well to avoid a case were we would create a pointer to a name that was rolled back. The size of the packet was incorrect at the end after the inserts because insert_short would increase self.size even though it was already accounted before. To resolve this insert_short_at_start was added which does not increase self.size. This did not cause an actual bug, however it sure made debugging this problem far more difficult. Additionally the size now inserted and then replaced when the actual size is known because it made debugging quite difficult since the size did not previously agree with the data. --- zeroconf/__init__.py | 43 ++++++------ zeroconf/test.py | 153 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+), 19 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 50359458..0fb9aec5 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1013,10 +1013,13 @@ def write_byte(self, value: int) -> None: """Writes a single byte to the packet""" self.pack(b'!c', int2byte(value)) - def insert_short(self, index: int, value: int) -> None: - """Inserts an unsigned short in a certain position in the packet""" - self.data.insert(index, struct.pack(b'!H', value)) - self.size += 2 + def insert_short_at_start(self, value: int) -> None: + """Inserts an unsigned short at the start of the packet""" + self.data.insert(0, struct.pack(b'!H', value)) + + def replace_short(self, index: int, value: int) -> None: + """Replaces an unsigned short in a certain position in the packet""" + self.data[index] = struct.pack(b'!H', value) def write_short(self, value: int) -> None: """Writes an unsigned short to the packet""" @@ -1123,15 +1126,13 @@ def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) self.write_int(record.get_remaining_ttl(now)) index = len(self.data) - # Adjust size for the short we will write before this record - self.size += 2 + self.write_short(0) # Will get replaced with the actual size record.write(self) - self.size -= 2 - - length = sum((len(d) for d in self.data[index:])) - # Here is the short we adjusted for - self.insert_short(index, length) - + # Adjust size for the short we will write before this record + length = sum((len(d) for d in self.data[index + 1 :])) + # Here we replace the 0 length short we wrote + # before with the actual length + self.replace_short(index, length) len_limit = _MAX_MSG_ABSOLUTE if allow_long else _MAX_MSG_TYPICAL # if we go over, then rollback and quit @@ -1139,6 +1140,10 @@ def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) while len(self.data) > start_data_length: self.data.pop() self.size = start_size + + rollback_names = [name for name, idx in self.names.items() if idx >= start_size] + for name in rollback_names: + del self.names[name] return False return True @@ -1207,15 +1212,15 @@ def packets(self) -> List[bytes]: if self.write_record(additional, 0): additionals_written += 1 - self.insert_short(0, additionals_written) - self.insert_short(0, authorities_written) - self.insert_short(0, answers_written) - self.insert_short(0, questions_written) - self.insert_short(0, self.flags) + self.insert_short_at_start(additionals_written) + self.insert_short_at_start(authorities_written) + self.insert_short_at_start(answers_written) + self.insert_short_at_start(questions_written) + self.insert_short_at_start(self.flags) if self.multicast: - self.insert_short(0, 0) + self.insert_short_at_start(0) else: - self.insert_short(0, self.id) + self.insert_short_at_start(self.id) self.packets_data.append(b''.join(self.data)) self.reset_for_next_packet() diff --git a/zeroconf/test.py b/zeroconf/test.py index cfbd2a79..f558e711 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1973,3 +1973,156 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): # unregister zc.unregister_service(info) + + +def test_dns_compression_rollback_for_corruption(): + """Verify rolling back does not lead to dns compression corruption.""" + out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + address = socket.inet_pton(socket.AF_INET, "192.168.208.5") + + additionals = [ + { + "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", + "address": address, + "port": 51832, + "text": b"\x13md=HASS Bridge" + b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=L0m/aQ==", + }, + { + "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", + "address": address, + "port": 51834, + "text": b"\x13md=HASS Bridge" + b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b2CnzQ==", + }, + { + "name": "Master Bed TV CEDB27._hap._tcp.local.", + "address": address, + "port": 51830, + "text": b"\x10md=Master Bed" + b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=CVj1kw==", + }, + { + "name": "Living Room TV 921B77._hap._tcp.local.", + "address": address, + "port": 51833, + "text": b"\x11md=Living Room" + b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=qU77SQ==", + }, + { + "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", + "address": address, + "port": 51829, + "text": b"\x13md=HASS Bridge" + b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b0QZlg==", + }, + { + "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", + "address": address, + "port": 51837, + "text": b"\x13md=HASS Bridge" + b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=ahAISA==", + }, + { + "name": "FrontdoorCamera 8941D1._hap._tcp.local.", + "address": address, + "port": 54898, + "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" + b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", + }, + { + "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + "address": address, + "port": 51836, + "text": b"\x13md=HASS Bridge" + b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=6fLM5A==", + }, + { + "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", + "address": address, + "port": 51838, + "text": b"\x13md=HASS Bridge" + b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=u3bdfw==", + }, + { + "name": "Snooze Room TV 6B89B0._hap._tcp.local.", + "address": address, + "port": 51835, + "text": b"\x11md=Snooze Room" + b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=xNTqsg==", + }, + { + "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", + "address": address, + "port": 54811, + "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" + b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", + }, + { + "name": "HASS Bridge OS95 39C053._hap._tcp.local.", + "address": address, + "port": 51831, + "text": b"\x13md=HASS Bridge" + b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" + b"\x04sf=0\x0bsh=Xfe5LQ==", + }, + ] + + out.add_answer_at_time( + DNSText( + "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for record in additionals: + out.add_additional_answer( + r.DNSService( + record["name"], # type: ignore + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + record["port"], # type: ignore + record["name"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSText( + record["name"], # type: ignore + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + record["text"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSAddress( + record["name"], # type: ignore + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + record["address"], # type: ignore + ) + ) + + for packet in out.packets(): + # Verify we can process the packets we created to + # ensure there is no corruption with the dns compression + incoming = r.DNSIncoming(packet) + assert incoming.valid is True From 1d726b551a49e945b134df6e29b352697030c5a9 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 4 Jan 2021 19:47:33 +0100 Subject: [PATCH 083/608] Release version 0.28.8 --- README.rst | 6 ++++++ zeroconf/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 0bf8fcdc..0fe2164d 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,12 @@ See examples directory for more. Changelog ========= +0.28.8 +====== + +* Fixed the packet generation when multiple packets are necessary, previously invalid + packets were generated sometimes. Patch thanks to J. Nick Koston. + 0.28.7 ====== diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0fb9aec5..2535e8ed 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.7' +__version__ = '0.28.8' __license__ = 'LGPL' From c5a675d22788aa905a4e47feb1d4c30f30416356 Mon Sep 17 00:00:00 2001 From: Pack3tL0ss Date: Wed, 27 Jan 2021 01:52:09 -0600 Subject: [PATCH 084/608] Fix link to readme md --> rst (#324) --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index c4fa6143..59952189 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,4 +28,4 @@ Contents api -See `the project's README `_ for more information. +See `the project's README `_ for more information. From 5e268faeaa99f0a513c7bbeda8f447f4eb36a747 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 4 Feb 2021 09:25:25 -1000 Subject: [PATCH 085/608] Simplify read_name (venv) root@ha-dev:~/python-zeroconf# python3 -m timeit -s 'result=""' -u usec 'result = "".join((result, "thisisaname" + "."))' 20000 loops, best of 5: 16.4 usec per loop (venv) root@ha-dev:~/python-zeroconf# python3 -m timeit -s 'result=""' -u usec 'result += "thisisaname" + "."' 2000000 loops, best of 5: 0.105 usec per loop --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 2535e8ed..dbeef01b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -880,7 +880,7 @@ def read_name(self) -> str: break t = length & 0xC0 if t == 0x00: - result = ''.join((result, self.read_utf(off, length) + '.')) + result += self.read_utf(off, length) + '.' off += length elif t == 0xC0: if next_ < 0: From 6beefbbe76a0e261394b308c8cc68545be653019 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 21 Mar 2021 09:20:16 -1000 Subject: [PATCH 086/608] Use a single socket for InterfaceChoice.Default When using multiple sockets with multi-cast, the outgoing socket's responses could be read back on the incoming socket, which leads to duplicate processing and could fill up the incoming buffer before it could be processed. This behavior manifested with error similar to `OSError: [Errno 105] No buffer space available` By using a single socket with InterfaceChoice.Default we avoid this case. --- zeroconf/__init__.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dbeef01b..e21c3e80 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2255,8 +2255,7 @@ def new_socket( def add_multicast_member( listen_socket: socket.socket, interface: Union[str, Tuple[Tuple[str, int, int], int]], - apple_p2p: bool = False, -) -> Optional[socket.socket]: +) -> None: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, tuple) err_einval = {errno.EINVAL} @@ -2293,6 +2292,12 @@ def add_multicast_member( else: raise + +def new_respond_socket( + interface: Union[str, Tuple[Tuple[str, int, int], int]], + apple_p2p: bool = False, +) -> Optional[socket.socket]: + is_v6 = isinstance(interface, tuple) respond_socket = new_socket( ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), apple_p2p=apple_p2p, @@ -2300,6 +2305,7 @@ def add_multicast_member( ) log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) if is_v6: + iface_bin = struct.pack('@I', cast(int, interface[1])) respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) else: respond_socket.setsockopt( @@ -2321,11 +2327,19 @@ def create_sockets( normalized_interfaces = normalize_interface_choice(interfaces, ip_version) + # If we are using InterfaceChoice.Default we can use + # a single socket to listen and respond. + if not unicast and interfaces is InterfaceChoice.Default: + for i in normalized_interfaces: + add_multicast_member(cast(socket.socket, listen_socket), i) + return listen_socket, [listen_socket] + respond_sockets = [] for i in normalized_interfaces: if not unicast: - respond_socket = add_multicast_member(cast(socket.socket, listen_socket), i, apple_p2p=apple_p2p) + add_multicast_member(cast(socket.socket, listen_socket), i) + respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) else: respond_socket = new_socket( port=0, @@ -2494,6 +2508,7 @@ def __init__( interfaces, unicast, ip_version, apple_p2p=apple_p2p ) log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) + self.multi_socket = unicast or interfaces is not InterfaceChoice.Default self.listeners = [] # type: List[RecordUpdateListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] @@ -2513,8 +2528,9 @@ def __init__( self.listener = Listener(self) if not unicast: self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) - for s in self._respond_sockets: - self.engine.add_reader(self.listener, s) + if self.multi_socket: + for s in self._respond_sockets: + self.engine.add_reader(self.listener, s) self.reaper = Reaper(self) self.debug = None # type: Optional[DNSOutgoing] @@ -2978,8 +2994,9 @@ def close(self) -> None: if not self.unicast: self.engine.del_reader(cast(socket.socket, self._listen_socket)) cast(socket.socket, self._listen_socket).close() - for s in self._respond_sockets: - self.engine.del_reader(s) + if self.multi_socket: + for s in self._respond_sockets: + self.engine.del_reader(s) self.engine.join() # shutdown the rest From 3e6f24a5fd562d3ee3985cc3cb83bcb276abe9d3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 24 Mar 2021 11:33:59 -1000 Subject: [PATCH 087/608] cast listen_socket to socket.socket in create_sockets Resolves typing error --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e21c3e80..86750651 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2332,7 +2332,7 @@ def create_sockets( if not unicast and interfaces is InterfaceChoice.Default: for i in normalized_interfaces: add_multicast_member(cast(socket.socket, listen_socket), i) - return listen_socket, [listen_socket] + return listen_socket, [cast(socket.socket, listen_socket)] respond_sockets = [] From ab67a7aecd63042178061f0d1a76f9a7f6e1559a Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 24 Mar 2021 23:23:23 +0100 Subject: [PATCH 088/608] Drop Python 3.5 compatibilty, it reached its end of life --- .travis.yml | 6 +----- Makefile | 5 +---- README.rst | 9 ++++++++- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6937b7cd..c9d32a7f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,9 @@ language: python python: - - "3.5" - "3.6" - "3.7" - "3.8" - "3.9-dev" - - "pypy3.5" - "pypy3" matrix: allow_failures: @@ -13,9 +11,7 @@ matrix: install: - pip install --upgrade -r requirements-dev.txt # mypy can't be installed on pypy - - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install mypy ; fi - - if [[ "${TRAVIS_PYTHON_VERSION}" != *"3.5"* && "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then - pip install black ; fi + - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install black mypy ; fi script: # no IPv6 support in Travis :( - SKIP_IPV6=1 make ci diff --git a/Makefile b/Makefile index af951f26..25fdbb2c 100644 --- a/Makefile +++ b/Makefile @@ -6,10 +6,7 @@ PYTHON_VERSION:=$(shell python -c "import sys;sys.stdout.write('%d.%d' % sys.ver LINT_TARGETS:=flake8 ifneq ($(findstring PyPy,$(PYTHON_IMPLEMENTATION)),PyPy) - LINT_TARGETS:=$(LINT_TARGETS) mypy -endif -ifeq ($(or $(findstring 3.5,$(PYTHON_VERSION)),$(findstring PyPy,$(PYTHON_IMPLEMENTATION))),) - LINT_TARGETS:=$(LINT_TARGETS) black_check + LINT_TARGETS:=$(LINT_TARGETS) mypy black_check endif diff --git a/README.rst b/README.rst index 0fe2164d..4ec6e73f 100644 --- a/README.rst +++ b/README.rst @@ -44,7 +44,7 @@ Compared to some other Zeroconf/Bonjour/Avahi Python packages, python-zeroconf: Python compatibility -------------------- -* CPython 3.5+ +* CPython 3.6+ * PyPy3 5.8+ Versioning @@ -134,6 +134,13 @@ See examples directory for more. Changelog ========= +0.29.0 (not released yet) +========================= + +Backwards incompatible: + +* Dropped Python 3.5 support + 0.28.8 ====== From bd80d20682c0af5e15a4b7102dcfe814cdba3a01 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 24 Mar 2021 23:27:25 +0100 Subject: [PATCH 089/608] Switch from Travis CI/Coveralls to GH Actions/Codecov Travis CI free tier is going away and Codecov is my go-to code coverage service now. Closes GH-332. --- .github/workflows/ci.yml | 33 +++++++++++++++++++++++++++++++++ .travis.yml | 19 ------------------- README.rst | 10 +++++----- docs/index.rst | 12 +++++------- requirements-dev.txt | 2 ++ 5 files changed, 45 insertions(+), 31 deletions(-) create mode 100644 .github/workflows/ci.yml delete mode 100644 .travis.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..86cc95a7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,33 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + branches: + - "**" + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: [3.6, 3.7, 3.8, 3.9, pypy3] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install --upgrade -r requirements-dev.txt + pip install . + - name: Run tests + run: make ci + env: + SKIP_IPV6: 1 + - name: Report coverage to Codecov + uses: codecov/codecov-action@v1 diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index c9d32a7f..00000000 --- a/.travis.yml +++ /dev/null @@ -1,19 +0,0 @@ -language: python -python: - - "3.6" - - "3.7" - - "3.8" - - "3.9-dev" - - "pypy3" -matrix: - allow_failures: - - python: "3.9-dev" -install: - - pip install --upgrade -r requirements-dev.txt - # mypy can't be installed on pypy - - if [[ "${TRAVIS_PYTHON_VERSION}" != "pypy"* ]] ; then pip install black mypy ; fi -script: - # no IPv6 support in Travis :( - - SKIP_IPV6=1 make ci -after_success: - - coveralls diff --git a/README.rst b/README.rst index 4ec6e73f..887fd8e4 100644 --- a/README.rst +++ b/README.rst @@ -1,14 +1,14 @@ python-zeroconf =============== -.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master - :target: https://travis-ci.org/jstasiak/python-zeroconf - +.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg + :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster + .. image:: https://img.shields.io/pypi/v/zeroconf.svg :target: https://pypi.python.org/pypi/zeroconf -.. image:: https://img.shields.io/coveralls/jstasiak/python-zeroconf.svg - :target: https://coveralls.io/r/jstasiak/python-zeroconf +.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jstasiak/python-zeroconf `Documentation `_. diff --git a/docs/index.rst b/docs/index.rst index 59952189..de5ba41a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,16 +1,14 @@ Welcome to python-zeroconf documentation! ========================================= -.. image:: https://travis-ci.org/jstasiak/python-zeroconf.svg?branch=master - :alt: Build status - :target: https://travis-ci.org/jstasiak/python-zeroconf +.. image:: https://github.com/jstasiak/python-zeroconf/workflows/CI/badge.svg + :target: https://github.com/jstasiak/python-zeroconf?query=workflow%3ACI+branch%3Amaster .. image:: https://img.shields.io/pypi/v/zeroconf.svg :target: https://pypi.python.org/pypi/zeroconf - -.. image:: https://coveralls.io/repos/github/jstasiak/python-zeroconf/badge.svg?branch=master - :alt: Covergage status - :target: https://coveralls.io/github/jstasiak/python-zeroconf?branch=master + +.. image:: https://codecov.io/gh/jstasiak/python-zeroconf/branch/master/graph/badge.svg + :target: https://codecov.io/gh/jstasiak/python-zeroconf GitHub (code repository, issues): https://github.com/jstasiak/python-zeroconf diff --git a/requirements-dev.txt b/requirements-dev.txt index 2d0490ae..8c1527b4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,10 +1,12 @@ autopep8 +black;implementation_name=="cpython" coveralls coverage # Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 flake8>=3.6.0 flake8-import-order ifaddr +mypy;implementation_name=="cpython" # 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152 pep8-naming!=0.6.0,!=0.11.0 pytest From 6482da05344e6ae8c4da440da4a704a20c344bb6 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 24 Mar 2021 23:44:46 +0100 Subject: [PATCH 090/608] Silence a mypy false-positive --- zeroconf/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 86750651..53d49ff4 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2260,7 +2260,8 @@ def add_multicast_member( is_v6 = isinstance(interface, tuple) err_einval = {errno.EINVAL} if sys.platform == 'win32': - err_einval |= {errno.WSAEINVAL} + # No WSAEINVAL definition in typeshed + err_einval |= {cast(Any, errno).WSAEINVAL} log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: From bc6ef8c65b22d982798104d5bdf11b78746a8ddd Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 24 Mar 2021 23:57:58 +0100 Subject: [PATCH 091/608] Silence a flaky test on PyPy --- zeroconf/test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index f558e711..fce2512f 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -7,6 +7,7 @@ import copy import logging import os +import platform import socket import struct import threading @@ -17,6 +18,8 @@ from typing import Dict, Optional # noqa # used in type hints from typing import cast +import pytest + import zeroconf as r from zeroconf import ( DNSHinfo, @@ -1126,6 +1129,7 @@ def test_integration_with_subtype_and_listener(self): class ListenerTest(unittest.TestCase): + @pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="Flaky on PyPy") def test_integration_with_listener_class(self): service_added = Event() From f871b90d25c0f788590ceb14237b08a6b5e6eeeb Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 25 Mar 2021 00:10:19 +0100 Subject: [PATCH 092/608] Make mypy configuration more lenient We want to be able to call untyped modules. --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5610cf68..d4354ef4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,7 +6,7 @@ ignore=E203,W503 [mypy] ignore_missing_imports = true -follow_imports = error +follow_imports = ignore check_untyped_defs = true no_implicit_optional = true warn_incomplete_stub = true From 53cb8044bfb4256f570d438817fd37acc8b78511 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 25 Mar 2021 00:00:35 +0100 Subject: [PATCH 093/608] Fill a missing changelog entry --- README.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.rst b/README.rst index 887fd8e4..9b8a4b98 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,9 @@ Changelog 0.29.0 (not released yet) ========================= +* A single socket is used for listening on responding when `InterfaceChoice.Default` is chosen. + Thanks to J. Nick Koston. + Backwards incompatible: * Dropped Python 3.5 support From 203ec2e26e6f0f676e7d88b4a1b0c80ad74659f1 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Thu, 25 Mar 2021 00:01:25 +0100 Subject: [PATCH 094/608] Release version 0.29.0 --- README.rst | 4 ++-- zeroconf/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 9b8a4b98..956eab05 100644 --- a/README.rst +++ b/README.rst @@ -134,8 +134,8 @@ See examples directory for more. Changelog ========= -0.29.0 (not released yet) -========================= +0.29.0 +====== * A single socket is used for listening on responding when `InterfaceChoice.Default` is chosen. Thanks to J. Nick Koston. diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 53d49ff4..d8890565 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -42,7 +42,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.28.8' +__version__ = '0.29.0' __license__ = 'LGPL' From fe948105cc0923336ffa6d93cbe7d45470612a36 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 2 May 2021 20:00:57 -1000 Subject: [PATCH 095/608] Simplify cache iteration (#340) - Remove the need to trap runtime error - Only copy the names of the keys when iterating the cache - Fixes RuntimeError: list changed size during iterating entries_from_name - Cache services - The Repear thread is no longer aware of the cache internals --- zeroconf/__init__.py | 87 +++++++++++++++++--------------------------- zeroconf/test.py | 41 ++++----------------- 2 files changed, 42 insertions(+), 86 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d8890565..ca930386 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -23,7 +23,6 @@ import enum import errno import ipaddress -import itertools import logging import platform import re @@ -1243,23 +1242,29 @@ class DNSCache: def __init__(self) -> None: self.cache = {} # type: Dict[str, List[DNSRecord]] + self.service_cache = {} # type: Dict[str, List[DNSRecord]] def add(self, entry: DNSRecord) -> None: """Adds an entry""" # Insert last in list, get will return newest entry # iteration will result in last update winning self.cache.setdefault(entry.key, []).append(entry) + if isinstance(entry, DNSService): + self.service_cache.setdefault(entry.server, []).append(entry) def remove(self, entry: DNSRecord) -> None: - """Removes an entry""" + """Removes an entry.""" + if isinstance(entry, DNSService): + DNSCache.remove_key(self.service_cache, entry.server, entry) + DNSCache.remove_key(self.cache, entry.key, entry) + + @staticmethod + def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: + """Forgiving remove of a cache key.""" try: - list_ = self.cache[entry.key] - list_.remove(entry) - # If we remove the last entry in the list - # we remove the key from the dict in order - # to avoid leaking memory - if not list_: - del self.cache[entry.key] + cache[key].remove(entry) + if not cache[key]: + del cache[key] except (KeyError, ValueError): pass @@ -1281,12 +1286,13 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco entry = DNSEntry(name, type_, class_) return self.get(entry) + def entries_with_server(self, server: str) -> List[DNSRecord]: + """Returns a list of entries whose server matches the name.""" + return self.service_cache.get(server, [])[:] + def entries_with_name(self, name: str) -> List[DNSRecord]: """Returns a list of entries whose key matches the name.""" - try: - return self.cache[name.lower()] - except KeyError: - return [] + return self.cache.get(name.lower(), [])[:] def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() @@ -1299,25 +1305,17 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D return record return None - def entries(self) -> List[DNSRecord]: - """Returns a list of all entries""" - if not self.cache: - return [] - - # avoid size change during iteration by copying the cache - return list(itertools.chain.from_iterable(list(self.cache.values()))) + def names(self) -> List[str]: + """Return a copy of the list of current cache names.""" + return list(self.cache) - def iterable_entries(self) -> Iterable[DNSRecord]: - """Returns an iterable of all entries. - - This function is provided to avoid copying - the entries but is not threadsafe as the - contents of the cache can change during iteration. - - Callers should trap RuntimeError and fallback - to calling entries. - """ - return itertools.chain.from_iterable(self.cache.values()) + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache.""" + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record class Engine(threading.Thread): @@ -1486,22 +1484,10 @@ def run(self) -> None: if self.zc.done: return - try: - # We try to iterate the cache without copying the whole - # cache as this can be quite an expensive operation. - self._cleanup_cache(self.zc.cache.iterable_entries()) - except RuntimeError: - # If the cache changes during iteration, we fallback - # to making a copy before iteraiton. - self._cleanup_cache(self.zc.cache.entries()) - - def _cleanup_cache(self, entries: Iterable[DNSRecord]) -> None: - """Remove expired entries from the cache.""" - now = current_time_millis() - for record in entries: - if record.is_expired(now): + + now = current_time_millis() + for record in self.zc.cache.expire(now): self.zc.update_record(now, record) - self.zc.cache.remove(record) class Signal: @@ -1706,9 +1692,7 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> return # Iterate through the DNSCache and callback any services that use this address - for service in zc.cache.entries(): - if not isinstance(service, DNSService) or not service.server == record.name: - continue + for service in self.zc.cache.entries_with_server(record.name): for type_ in self.types: if service.name.endswith(type_): enqueue_callback(ServiceStateChange.Updated, type_, service.name) @@ -2772,10 +2756,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # we can avoid iterating everything in the cache and # only look though entries for the specific name. # entries_with_name will take care of converting to lowercase - # - # We make a copy of the list that entries_with_name returns - # since we cannot iterate over something we might remove - for entry in self.cache.entries_with_name(record.name).copy(): + for entry in self.cache.entries_with_name(record.name): if entry == record: updated = False diff --git a/zeroconf/test.py b/zeroconf/test.py index fce2512f..7349b42a 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -5,6 +5,7 @@ """ Unit tests for zeroconf.py """ import copy +import itertools import logging import os import platform @@ -958,45 +959,18 @@ def test_cache_empty_multiple_calls_does_not_throw(self): class TestReaper(unittest.TestCase): def test_reaper(self): zeroconf = Zeroconf(interfaces=['127.0.0.1']) - original_entries = zeroconf.cache.entries() + cache = zeroconf.cache + original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') zeroconf.cache.add(record_with_10s_ttl) zeroconf.cache.add(record_with_1s_ttl) - entries_with_cache = zeroconf.cache.entries() + entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) time.sleep(1.05) zeroconf.notify_reaper() time.sleep(0.05) - entries = zeroconf.cache.entries() - - try: - iterable_entries = list(zeroconf.cache.iterable_entries()) - finally: - zeroconf.close() - - assert entries != original_entries - assert entries_with_cache != original_entries - assert record_with_10s_ttl in entries - assert record_with_1s_ttl not in entries - assert record_with_10s_ttl in iterable_entries - assert record_with_1s_ttl not in iterable_entries - - def test_reaper_with_dict_change_during_iteration(self): - zeroconf = Zeroconf(interfaces=['127.0.0.1']) - original_entries = zeroconf.cache.entries() - record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') - record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - zeroconf.cache.add(record_with_10s_ttl) - zeroconf.cache.add(record_with_1s_ttl) - entries_with_cache = zeroconf.cache.entries() - with unittest.mock.patch("zeroconf.DNSCache.iterable_entries", side_effect=RuntimeError): - time.sleep(1.05) - zeroconf.notify_reaper() - time.sleep(0.05) - - entries = zeroconf.cache.entries() + entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) zeroconf.close() - assert entries != original_entries assert entries_with_cache != original_entries assert record_with_10s_ttl in entries @@ -1196,8 +1170,9 @@ def update_service(self, zeroconf, type, name): time.sleep(3) # clear the answer cache to force query - for record in zeroconf_browser.cache.entries(): - zeroconf_browser.cache.remove(record) + for name in zeroconf_browser.cache.names(): + for record in zeroconf_browser.cache.entries_with_name(name): + zeroconf_browser.cache.remove(record) # get service info without answer cache info = zeroconf_browser.get_service_info(type_, registration_name) From beccad1f0b41730f541b2e90ea2eaa2496de5044 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 2 May 2021 20:16:26 -1000 Subject: [PATCH 096/608] Skip socket creation if add_multicast_member fails (windows) (#341) Co-authored-by: Timothee 'TTimo' Besset --- zeroconf/__init__.py | 15 +++++++++------ zeroconf/test.py | 19 +++++++++++++++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ca930386..76f32174 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2239,7 +2239,7 @@ def new_socket( def add_multicast_member( listen_socket: socket.socket, interface: Union[str, Tuple[Tuple[str, int, int], int]], -) -> None: +) -> bool: # This is based on assumptions in normalize_interface_choice is_v6 = isinstance(interface, tuple) err_einval = {errno.EINVAL} @@ -2263,19 +2263,20 @@ def add_multicast_member( 'it is expected to happen on some systems', interface, ) - return None + return False elif _errno == errno.EADDRNOTAVAIL: log.info( 'Address not available when adding %s to multicast ' 'group, it is expected to happen on some systems', interface, ) - return None + return False elif _errno in err_einval: log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) - return None + return False else: raise + return True def new_respond_socket( @@ -2323,8 +2324,10 @@ def create_sockets( for i in normalized_interfaces: if not unicast: - add_multicast_member(cast(socket.socket, listen_socket), i) - respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + if add_multicast_member(cast(socket.socket, listen_socket), i): + respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + else: + respond_socket = None else: respond_socket = new_socket( port=0, diff --git a/zeroconf/test.py b/zeroconf/test.py index 7349b42a..60bfb4d5 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -5,6 +5,7 @@ """ Unit tests for zeroconf.py """ import copy +import errno import itertools import logging import os @@ -16,8 +17,7 @@ import unittest import unittest.mock from threading import Event -from typing import Dict, Optional # noqa # used in type hints -from typing import cast +from typing import Dict, Optional, cast # noqa # used in type hints import pytest @@ -2105,3 +2105,18 @@ def test_dns_compression_rollback_for_corruption(): # ensure there is no corruption with the dns compression incoming = r.DNSIncoming(packet) assert incoming.valid is True + + +@pytest.mark.parametrize( + "errno,expected_result", + [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], +) +def test_add_multicast_member_socket_errors(errno, expected_result): + """Test we handle socket errors when adding multicast members.""" + if errno: + setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno))) + else: + setsockopt_mock = unittest.mock.Mock() + fileno_mock = unittest.mock.PropertyMock(return_value=10) + socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) + assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result From 523aefb0b0c477489e4e1e4ab763ce56c57295b7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 3 May 2021 16:10:41 -1000 Subject: [PATCH 097/608] Return early when already closed (#350) - Reduce indentation with a return early guard in close --- zeroconf/__init__.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 76f32174..0284b51d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2969,24 +2969,25 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P def close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" - if not self._GLOBAL_DONE: - # remove service listeners - self.remove_all_service_listeners() - self.unregister_all_services() - self._GLOBAL_DONE = True - - # shutdown recv socket and thread - if not self.unicast: - self.engine.del_reader(cast(socket.socket, self._listen_socket)) - cast(socket.socket, self._listen_socket).close() - if self.multi_socket: - for s in self._respond_sockets: - self.engine.del_reader(s) - self.engine.join() - - # shutdown the rest - self.notify_all() - self.notify_reaper() - self.reaper.join() + if self._GLOBAL_DONE: + return + # remove service listeners + self.remove_all_service_listeners() + self.unregister_all_services() + self._GLOBAL_DONE = True + + # shutdown recv socket and thread + if not self.unicast: + self.engine.del_reader(cast(socket.socket, self._listen_socket)) + cast(socket.socket, self._listen_socket).close() + if self.multi_socket: for s in self._respond_sockets: - s.close() + self.engine.del_reader(s) + self.engine.join() + + # shutdown the rest + self.notify_all() + self.notify_reaper() + self.reaper.join() + for s in self._respond_sockets: + s.close() From 781627864efbb3c8285e1b75144d688083414cf3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 3 May 2021 21:47:36 -1000 Subject: [PATCH 098/608] Eliminate the reaper thread (#349) - Cache is now purged between reads when the interval is reached - Reduce locking since we are already making a copy of the readers and not reading under the lock - Simplify shutdown process --- zeroconf/__init__.py | 97 +++++++++++++++++--------------------------- zeroconf/test.py | 19 +++++++-- 2 files changed, 53 insertions(+), 63 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0284b51d..ba588644 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -34,7 +34,7 @@ import time import warnings from collections import OrderedDict -from typing import Dict, Iterable, List, Optional, Sequence, Union, cast +from typing import Dict, Iterable, List, Optional, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints import ifaddr @@ -1337,39 +1337,51 @@ def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.readers = {} # type: Dict[socket.socket, Listener] self.timeout = 5 + self.cache_cleanup_interval_ms = 10000.0 self.condition = threading.Condition() self.socketpair = socket.socketpair() + self._last_cache_cleanup = 0.0 self.start() self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) def run(self) -> None: while not self.zc.done: - with self.condition: - rs = list(self.readers.keys()) - if len(rs) == 0: - # No sockets to manage, but we wait for the timeout - # or addition of a socket + rs = list(self.readers.keys()) + if not rs: + # No sockets to manage, but we wait for the timeout + # or addition of a socket + with self.condition: self.condition.wait(self.timeout) + continue + + try: + rs.append(self.socketpair[0]) + rr, wr, er = select.select(rs, [], [], self.timeout) + + if self.zc.done: + return + + for socket_ in rr: + reader = self.readers.get(socket_) + if reader: + reader.handle_read(socket_) + + if self.socketpair[0] in rr: + # Clear the socket's buffer + self.socketpair[0].recv(128) + + except (select.error, socket.error) as e: + # If the socket was closed by another thread, during + # shutdown, ignore it and exit + if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: + raise + + now = current_time_millis() + if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms: + self._last_cache_cleanup = now + for record in self.zc.cache.expire(now): + self.zc.update_record(now, record) - if len(rs) != 0: - try: - rs = rs + [self.socketpair[0]] - rr, wr, er = select.select(cast(Sequence[Any], rs), [], [], self.timeout) - if not self.zc.done: - for socket_ in rr: - reader = self.readers.get(socket_) - if reader: - reader.handle_read(socket_) - - if self.socketpair[0] in rr: - # Clear the socket's buffer - self.socketpair[0].recv(128) - - except (select.error, socket.error) as e: - # If the socket was closed by another thread, during - # shutdown, ignore it and exit - if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: - raise self.socketpair[0].close() self.socketpair[1].close() @@ -1464,32 +1476,6 @@ def handle_read(self, socket_: socket.socket) -> None: self.zc.handle_response(msg) -class Reaper(threading.Thread): - - """A Reaper is used by this module to remove cache entries that - have expired.""" - - def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self) - self.daemon = True - self.zc = zc - self.start() - self.name = "zeroconf-Reaper_%s" % (getattr(self, 'native_id', self.ident),) - - def run(self) -> None: - """Perodic removal of expired entries from the cache.""" - while True: - with self.zc.reaper_condition: - self.zc.reaper_condition.wait(10) - - if self.zc.done: - return - - now = current_time_millis() - for record in self.zc.cache.expire(now): - self.zc.update_record(now, record) - - class Signal: def __init__(self) -> None: self._handlers = [] # type: List[Callable[..., None]] @@ -2505,7 +2491,6 @@ def __init__( self.cache = DNSCache() self.condition = threading.Condition() - self.reaper_condition = threading.Condition() # Ensure we create the lock before # we add the listener as we could get @@ -2519,7 +2504,6 @@ def __init__( if self.multi_socket: for s in self._respond_sockets: self.engine.add_reader(self.listener, s) - self.reaper = Reaper(self) self.debug = None # type: Optional[DNSOutgoing] @@ -2538,11 +2522,6 @@ def notify_all(self) -> None: with self.condition: self.condition.notify_all() - def notify_reaper(self) -> None: - """Notifies reaper""" - with self.reaper_condition: - self.reaper_condition.notify_all() - def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, @@ -2987,7 +2966,5 @@ def close(self) -> None: # shutdown the rest self.notify_all() - self.notify_reaper() - self.reaper.join() for s in self._respond_sockets: s.close() diff --git a/zeroconf/test.py b/zeroconf/test.py index 60bfb4d5..044909a3 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -529,6 +529,17 @@ def test_launch_and_close(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) rv.close() + def test_launch_and_close_unicast(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True) + rv.close() + + def test_close_multiple_times(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) + rv.close() + rv.close() + @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_launch_and_close_v4_v6(self): @@ -966,9 +977,11 @@ def test_reaper(self): zeroconf.cache.add(record_with_10s_ttl) zeroconf.cache.add(record_with_1s_ttl) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - time.sleep(1.05) - zeroconf.notify_reaper() - time.sleep(0.05) + zeroconf.engine.cache_cleanup_interval_ms = 10 + time.sleep(1) + with zeroconf.engine.condition: + zeroconf.engine._notify() + time.sleep(0.1) entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) zeroconf.close() assert entries != original_entries From a41d7b8aa5572f3faf29eb087cc18a1343bbcdfa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 5 May 2021 15:09:56 -0500 Subject: [PATCH 099/608] Provide an asyncio class for service registration (#347) * Provide an AIO wrapper for service registration - When using zeroconf with async code, service registration can cause the executor to overload when registering multiple services since each one will have to wait a bit between sending the broadcast. An aio subclass is now available as aio.AsyncZeroconf that implements the following - async_register_service - async_unregister_service - async_update_service - async_close I/O is currently run in the executor to provide backwards compat with existing use cases. These functions avoid overloading the executor by waiting in the event loop instead of the executor threads. --- Makefile | 4 +- requirements-dev.txt | 1 + zeroconf/__init__.py | 74 +++++++++++--------- zeroconf/asyncio.py | 137 ++++++++++++++++++++++++++++++++++++ zeroconf/test_asyncio.py | 145 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 326 insertions(+), 35 deletions(-) create mode 100644 zeroconf/asyncio.py create mode 100644 zeroconf/test_asyncio.py diff --git a/Makefile b/Makefile index 25fdbb2c..fed4d9b1 100644 --- a/Makefile +++ b/Makefile @@ -36,10 +36,10 @@ mypy: mypy examples/*.py zeroconf/*.py test: - pytest -v zeroconf/test.py + pytest -v zeroconf/test.py zeroconf/test_asyncio.py test_coverage: - pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py + pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py zeroconf/test_asyncio.py autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf diff --git a/requirements-dev.txt b/requirements-dev.txt index 8c1527b4..30b906e4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,4 +10,5 @@ mypy;implementation_name=="cpython" # 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152 pep8-naming!=0.6.0,!=0.11.0 pytest +pytest-asyncio pytest-cov diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ba588644..484a7ee2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -353,6 +353,16 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: return service_name + trailer +def instance_name_from_service_info(info: "ServiceInfo") -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + # Exceptions @@ -2505,8 +2515,6 @@ def __init__( for s in self._respond_sockets: self.engine.add_reader(self.listener, s) - self.debug = None # type: Optional[DNSOutgoing] - @property def done(self) -> bool: return self._GLOBAL_DONE @@ -2580,6 +2588,7 @@ def update_service(self, info: ServiceInfo) -> None: self._broadcast_service(info, _REGISTER_TIME, None) def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + """Send a broadcasts to announce a service at intervals.""" now = current_time_millis() next_time = now i = 0 @@ -2589,12 +2598,23 @@ def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int now = current_time_millis() continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - self._add_broadcast_answer(out, info, ttl) - self.send(out) + self.send_service_broadcast(info, ttl) i += 1 next_time += interval + def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: + """Send a broadcast to announce a service.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + self._add_broadcast_answer(out, info, ttl) + self.send(out) + + def send_service_query(self, info: ServiceInfo) -> None: + """Send a query to lookup a service.""" + out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) + out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) + out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name)) + self.send(out) + def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]) -> None: """Add answers to broadcast a service.""" other_ttl = info.other_ttl if override_ttl is None else override_ttl @@ -2653,43 +2673,31 @@ def check_service( ) -> None: """Checks the network for a unique service name, modifying the ServiceInfo passed in if it is not unique.""" - - # This is kind of funky because of the subtype based tests - # need to make subtypes a first class citizen - service_name = service_type_name(info.name) - if not info.type.endswith(service_name): - raise BadTypeInNameException - - instance_name = info.name[: -len(service_name) - 1] + instance_name = instance_name_from_service_info(info) + if cooperating_responders: + return next_instance_number = 2 - - now = current_time_millis() - next_time = now + next_time = now = current_time_millis() i = 0 while i < 3: - if not cooperating_responders: - # check for a name conflict - while self.cache.current_entry_with_name_and_alias(info.type, info.name): - if not allow_name_change: - raise NonUniqueNameException - - # change the name and look for a conflict - info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) - next_instance_number += 1 - service_type_name(info.name) - next_time = now - i = 0 + # check for a name conflict + while self.cache.current_entry_with_name_and_alias(info.type, info.name): + if not allow_name_change: + raise NonUniqueNameException + + # change the name and look for a conflict + info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) + next_instance_number += 1 + service_type_name(info.name) + next_time = now + i = 0 if now < next_time: self.wait(next_time - now) now = current_time_millis() continue - out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) - self.debug = out - out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) - out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name)) - self.send(out) + self.send_service_query(info) i += 1 next_time += _CHECK_TIME diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py new file mode 100644 index 00000000..859d26a0 --- /dev/null +++ b/zeroconf/asyncio.py @@ -0,0 +1,137 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" +import asyncio +from typing import Optional + +from . import ( + IPVersion, + InterfaceChoice, + InterfacesType, + NonUniqueNameException, + ServiceInfo, + Zeroconf, + _CHECK_TIME, + _REGISTER_TIME, + _UNREGISTER_TIME, + instance_name_from_service_info, +) + + +class AsyncZeroconf: + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + + The async version is currently a wrapper around the sync version + with I/O being done in the executor for backwards compatibility. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. + :param apple_p2p: use AWDL interface (only macOS) + """ + self.zeroconf = Zeroconf( + interfaces=interfaces, + unicast=unicast, + ip_version=ip_version, + apple_p2p=apple_p2p, + ) + self.loop = asyncio.get_event_loop() + + async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + """Send a broadcasts to announce a service at intervals.""" + for i in range(3): + if i != 0: + await asyncio.sleep(interval / 1000) + await self.loop.run_in_executor(None, self.zeroconf.send_service_broadcast, info, ttl) + + async def async_register_service( + self, + info: ServiceInfo, + cooperating_responders: bool = False, + ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`). + + The service will be broadcast in a task. + """ + await self.async_check_service(info, cooperating_responders) + await self.loop.run_in_executor(None, self.zeroconf.registry.add, info) + asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: + """Checks the network for a unique service name.""" + instance_name_from_service_info(info) + if cooperating_responders: + return + for i in range(3): + # check for a name conflict + if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): + raise NonUniqueNameException + if i != 0: + await asyncio.sleep(_CHECK_TIME / 1000) + await self.loop.run_in_executor(None, self.zeroconf.send_service_query, info) + + async def async_unregister_service(self, info: ServiceInfo) -> None: + """Unregister a service. + + The service will be broadcast in a task. + """ + await self.loop.run_in_executor(None, self.zeroconf.registry.remove, info) + asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + + async def async_update_service(self, info: ServiceInfo) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. + + The service will be broadcast in a task. + """ + await self.loop.run_in_executor(None, self.zeroconf.registry.update, info) + asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + async def async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries.""" + await self.loop.run_in_executor(None, self.zeroconf.close) diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py new file mode 100644 index 00000000..9ad43ce4 --- /dev/null +++ b/zeroconf/test_asyncio.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for async.py.""" + +import asyncio +import socket + +import pytest + +from . import ( + BadTypeInNameException, + NonUniqueNameException, + ServiceInfo, + ServiceListener, + ServiceNameAlreadyRegistered, + Zeroconf, + _REGISTER_TIME, + _UNREGISTER_TIME, +) +from .asyncio import AsyncZeroconf + + +@pytest.mark.asyncio +async def test_async_basic_usage() -> None: + """Test we can create and close the instance.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_registration() -> None: + """Test registering services broadcasts the registration by default.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + await aiozc.async_register_service(info) + await asyncio.sleep(_REGISTER_TIME / 1000 * 3) + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + await aiozc.async_update_service(new_info) + await asyncio.sleep(_REGISTER_TIME / 1000 * 3) + + await aiozc.async_unregister_service(new_info) + await asyncio.sleep(_UNREGISTER_TIME / 1000 * 3) + await aiozc.async_close() + + assert calls == [ + ('add', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('update', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('remove', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_name_conflict() -> None: + """Test registering services throws on name conflict.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + await aiozc.async_register_service(info) + await asyncio.sleep(_REGISTER_TIME / 1000 * 3) + + with pytest.raises(NonUniqueNameException): + await aiozc.async_register_service(info) + + with pytest.raises(ServiceNameAlreadyRegistered): + await aiozc.async_register_service(info, cooperating_responders=True) + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_registration_name_does_not_match_type() -> None: + """Test registering services throws when the name does not match the type.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info.type = "_wrong._tcp.local." + with pytest.raises(BadTypeInNameException): + await aiozc.async_register_service(info) + await aiozc.async_close() From 87ba2a3960576cfcf4207ea74a711b2c0cc584a7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 16 May 2021 16:50:16 -0400 Subject: [PATCH 100/608] Separate cache loading from I/O in ServiceInfo (#356) Provides a load_from_cache method on ServiceInfo that does no I/O - When a ServiceBrowser is running for a type there is no need to make queries on the network since the entries will already be in the cache. When discovering many devices making queries that will almost certainly fail for offline devices delays the startup of online devices. - The DNSEntry and ServiceInfo classes were matching on the name instead of the key (lowercase name). These classes now treat dns names the same reguardless of case. https://datatracker.ietf.org/doc/html/rfc6762#section-16 > The simple rules for case-insensitivity in Unicast DNS [RFC1034] > [RFC1035] also apply in Multicast DNS; that is to say, in name > comparisons, the lowercase letters "a" to "z" (0x61 to 0x7A) match > their uppercase equivalents "A" to "Z" (0x41 to 0x5A). Hence, if a > querier issues a query for an address record with the name > "myprinter.local.", then a responder having an address record with > the name "MyPrinter.local." should issue a response. --- zeroconf/__init__.py | 38 +++++++----- zeroconf/test.py | 136 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 158 insertions(+), 16 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 484a7ee2..28249e2b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -436,9 +436,9 @@ def __init__(self, name: str, type_: int, class_: int) -> None: self.unique = (class_ & _CLASS_UNIQUE) != 0 def __eq__(self, other: Any) -> bool: - """Equality test on name, type, and class""" + """Equality test on key (lowercase name), type, and class""" return ( - self.name == other.name + self.key == other.key and self.type == other.type and self.class_ == other.class_ and isinstance(other, DNSEntry) @@ -1788,6 +1788,7 @@ def __init__( raise BadTypeInNameException self.type = type_ self.name = name + self.key = name.lower() if addresses is not None: self._addresses = addresses elif parsed_addresses is not None: @@ -1807,6 +1808,7 @@ def __init__( self.server = server else: self.server = name + self.server_key = self.server.lower() self._properties = {} # type: Dict self._set_properties(properties) self.host_ttl = host_ttl @@ -1920,34 +1922,28 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) if record is not None and not record.is_expired(now): if record.type in [_TYPE_A, _TYPE_AAAA]: assert isinstance(record, DNSAddress) - # if record.name == self.name: - if record.name == self.server: + if record.key == self.server_key: if record.address not in self._addresses: self._addresses.append(record.address) elif record.type == _TYPE_SRV: assert isinstance(record, DNSService) - if record.name == self.name: + if record.key == self.key: + self.name = record.name self.server = record.server + self.server_key = record.server.lower() self.port = record.port self.weight = record.weight self.priority = record.priority - # self.address = None self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) elif record.type == _TYPE_TXT: assert isinstance(record, DNSText) - if record.name == self.name: + if record.key == self.key: self._set_text(record.text) - def request(self, zc: 'Zeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ + def load_from_cache(self, zc: 'Zeroconf') -> bool: + """Populate the service info from the cache.""" now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)] if self.server is not None: record_types_for_check_cache.append((_TYPE_A, _CLASS_IN)) @@ -1959,7 +1955,19 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if self.server is not None and self.text is not None and self._addresses: return True + return False + + def request(self, zc: 'Zeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(zc): + return True + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout try: zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) while self.server is None or self.text is None or not self._addresses: diff --git a/zeroconf/test.py b/zeroconf/test.py index 044909a3..bb320522 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -243,6 +243,7 @@ def test_suppress_answer(self): # Should not be suppressed, name is different tmp = copy.copy(answer1) + tmp.key = "testname3.local." tmp.name = "testname3.local." response.add_answer(query, tmp) assert len(response.answers) == 2 @@ -1127,7 +1128,7 @@ def test_integration_with_listener_class(self): subtype_name = "My special Subtype" type_ = "_http._tcp.local." subtype = subtype_name + "._sub." + type_ - name = "xxxyyyæøå" + name = "UPPERxxxyyyæøå" registration_name = "%s.%s" % (name, subtype) class MyListener(r.ServiceListener): @@ -1187,6 +1188,10 @@ def update_service(self, zeroconf, type, name): for record in zeroconf_browser.cache.entries_with_name(name): zeroconf_browser.cache.remove(record) + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties == {} + # get service info without answer cache info = zeroconf_browser.get_service_info(type_, registration_name) assert info is not None @@ -1199,10 +1204,35 @@ def update_service(self, zeroconf, type, name): assert info.addresses == addresses[:1] # no V6 by default assert info.addresses_by_version(r.IPVersion.All) == addresses + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + + # get service info with only the cache + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + + # get service info with only the cache with the lowercase name + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + # Ensure uppercase output is preserved + assert cached_info.name == registration_name + assert cached_info.key == registration_name.lower() + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + info = zeroconf_browser.get_service_info(subtype, registration_name) assert info is not None + assert info.properties is not None assert info.properties[b'prop_none'] is None + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_none'] is None + # test TXT record update sublistener = MySubListener() zeroconf_browser.add_service_listener(registration_name, sublistener) @@ -1226,6 +1256,11 @@ def update_service(self, zeroconf, type, name): assert info is not None assert info.properties[b'prop_blank'] == properties['prop_blank'] + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_blank'] == properties['prop_blank'] + zeroconf_registrar.unregister_service(info_service) service_removed.wait(1) assert service_removed.is_set() @@ -1376,6 +1411,105 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi class TestServiceInfo(unittest.TestCase): + def test_service_info_rejects_non_matching_updates(self): + """Verify records with the wrong name are rejected.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.3") + info.update_record( + zc, + now, + r.DNSAddress( + 'ASH-2.local.', + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address in info.addresses + # Non-matching updates + info.update_record( + zc, + now, + r.DNSText( + "incorrect.name.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + "incorrect.name.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.4") + info.update_record( + zc, + now, + r.DNSAddress( + "incorrect.name.", + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address not in info.addresses + def test_get_info_partial(self): zc = r.Zeroconf(interfaces=['127.0.0.1']) From 8c1c394e9b4aa01e08a2c3e240396b533792be55 Mon Sep 17 00:00:00 2001 From: nocarryr Date: Sat, 22 May 2021 11:27:01 -0500 Subject: [PATCH 101/608] Return task objects created by AsyncZeroconf (#360) --- zeroconf/asyncio.py | 23 +++++++------- zeroconf/test_asyncio.py | 65 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 859d26a0..4ba7f312 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -20,7 +20,7 @@ USA """ import asyncio -from typing import Optional +from typing import Awaitable, Optional from . import ( IPVersion, @@ -86,7 +86,7 @@ async def async_register_service( self, info: ServiceInfo, cooperating_responders: bool = False, - ) -> None: + ) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service. The name of the service may be changed if needed to make @@ -94,11 +94,12 @@ async def async_register_service( can register the same service on the network for resilience (if you want this behavior set `cooperating_responders` to `True`). - The service will be broadcast in a task. + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. """ await self.async_check_service(info, cooperating_responders) await self.loop.run_in_executor(None, self.zeroconf.registry.add, info) - asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: """Checks the network for a unique service name.""" @@ -113,23 +114,25 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b await asyncio.sleep(_CHECK_TIME / 1000) await self.loop.run_in_executor(None, self.zeroconf.send_service_query, info) - async def async_unregister_service(self, info: ServiceInfo) -> None: + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service. - The service will be broadcast in a task. + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. """ await self.loop.run_in_executor(None, self.zeroconf.registry.remove, info) - asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) - async def async_update_service(self, info: ServiceInfo) -> None: + async def async_update_service(self, info: ServiceInfo) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service. - The service will be broadcast in a task. + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. """ await self.loop.run_in_executor(None, self.zeroconf.registry.update, info) - asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) async def async_close(self) -> None: """Ends the background threads, and prevent this instance from diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index 9ad43ce4..ed43d2c0 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -143,3 +143,68 @@ async def test_async_service_registration_name_does_not_match_type() -> None: with pytest.raises(BadTypeInNameException): await aiozc.async_register_service(info) await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_tasks() -> None: + """Test awaiting broadcast tasks""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + assert isinstance(task, asyncio.Task) + await task + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + assert isinstance(task, asyncio.Task) + await task + + task = await aiozc.async_unregister_service(new_info) + assert isinstance(task, asyncio.Task) + await task + await aiozc.async_close() + + assert calls == [ + ('add', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('update', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('remove', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ] From c0674e97aee4f61212389337340fc8ff4472eb25 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 22 May 2021 11:28:51 -0500 Subject: [PATCH 102/608] Improve test coverage for name conflicts (#357) --- zeroconf/test.py | 28 ++++++++++++++++++++++++++++ zeroconf/test_asyncio.py | 13 +++++++++++++ 2 files changed, 41 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index bb320522..64005b9b 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -893,6 +893,34 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 + def test_name_conflicts(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_homeassistant._tcp.local." + name = "Home" + registration_name = "%s.%s" % (name, type_) + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + + conflicting_info = ServiceInfo( + type_, + name=registration_name, + server="random456.local.", + addresses=[socket.inet_pton(socket.AF_INET, "4.5.6.7")], + port=80, + properties={"version": "1.0"}, + ) + with pytest.raises(r.NonUniqueNameException): + zc.register_service(conflicting_info) + class TestServiceRegistry(unittest.TestCase): def test_only_register_once(self): diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index ed43d2c0..eb398a1f 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -117,6 +117,19 @@ async def test_async_service_registration_name_conflict() -> None: with pytest.raises(ServiceNameAlreadyRegistered): await aiozc.async_register_service(info, cooperating_responders=True) + conflicting_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-3.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + with pytest.raises(NonUniqueNameException): + await aiozc.async_register_service(conflicting_info) + await aiozc.async_close() From 7e960b78cac8008beca9c5451c6d465e2674a050 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 26 May 2021 22:04:20 -0500 Subject: [PATCH 103/608] Small cleanups to asyncio tests (#362) --- zeroconf/test_asyncio.py | 53 +++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index eb398a1f..e048a603 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -16,8 +16,6 @@ ServiceListener, ServiceNameAlreadyRegistered, Zeroconf, - _REGISTER_TIME, - _UNREGISTER_TIME, ) from .asyncio import AsyncZeroconf @@ -33,7 +31,7 @@ async def test_async_basic_usage() -> None: async def test_async_service_registration() -> None: """Test registering services broadcasts the registration by default.""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc-type._tcp.local." + type_ = "_test1-srvc-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -63,8 +61,8 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - await aiozc.async_register_service(info) - await asyncio.sleep(_REGISTER_TIME / 1000 * 3) + task = await aiozc.async_register_service(info) + await task new_info = ServiceInfo( type_, registration_name, @@ -75,17 +73,16 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")], ) - await aiozc.async_update_service(new_info) - await asyncio.sleep(_REGISTER_TIME / 1000 * 3) - - await aiozc.async_unregister_service(new_info) - await asyncio.sleep(_UNREGISTER_TIME / 1000 * 3) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task await aiozc.async_close() assert calls == [ - ('add', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), - ('update', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), - ('remove', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), ] @@ -93,7 +90,7 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: async def test_async_service_registration_name_conflict() -> None: """Test registering services throws on name conflict.""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-srvc2-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -108,14 +105,16 @@ async def test_async_service_registration_name_conflict() -> None: "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - await aiozc.async_register_service(info) - await asyncio.sleep(_REGISTER_TIME / 1000 * 3) + task = await aiozc.async_register_service(info) + await task with pytest.raises(NonUniqueNameException): - await aiozc.async_register_service(info) + task = await aiozc.async_register_service(info) + await task with pytest.raises(ServiceNameAlreadyRegistered): - await aiozc.async_register_service(info, cooperating_responders=True) + task = await aiozc.async_register_service(info, cooperating_responders=True) + await task conflicting_info = ServiceInfo( type_, @@ -127,8 +126,10 @@ async def test_async_service_registration_name_conflict() -> None: "ash-3.local.", addresses=[socket.inet_aton("10.0.1.3")], ) + with pytest.raises(NonUniqueNameException): - await aiozc.async_register_service(conflicting_info) + task = await aiozc.async_register_service(conflicting_info) + await task await aiozc.async_close() @@ -137,7 +138,7 @@ async def test_async_service_registration_name_conflict() -> None: async def test_async_service_registration_name_does_not_match_type() -> None: """Test registering services throws when the name does not match the type.""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-srvc3-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -154,7 +155,8 @@ async def test_async_service_registration_name_does_not_match_type() -> None: ) info.type = "_wrong._tcp.local." with pytest.raises(BadTypeInNameException): - await aiozc.async_register_service(info) + task = await aiozc.async_register_service(info) + await task await aiozc.async_close() @@ -163,7 +165,7 @@ async def test_async_tasks() -> None: """Test awaiting broadcast tasks""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-srvc4-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -214,10 +216,11 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: task = await aiozc.async_unregister_service(new_info) assert isinstance(task, asyncio.Task) await task + await aiozc.async_close() assert calls == [ - ('add', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), - ('update', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), - ('remove', '_test-srvc-type._tcp.local.', 'xxxyyy._test-srvc-type._tcp.local.'), + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), ] From d8c32401ada4f430cd75617324b6d8ecd1dbe1f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 26 May 2021 22:26:54 -0500 Subject: [PATCH 104/608] Add new cache function get_all_by_details (#363) - When working with IPv6, multiple AAAA records can exist for a given host. get_by_details would only return the latest record in the cache. - Fix a case where the cache list can change during iteration --- zeroconf/__init__.py | 23 +++++++++++------------ zeroconf/test.py | 2 ++ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 28249e2b..98b16d93 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1281,20 +1281,19 @@ def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no matching entry.""" - try: - list_ = self.cache[entry.key] - for cached_entry in reversed(list_): - if entry.__eq__(cached_entry): - return cached_entry - return None - except (KeyError, ValueError): - return None + for cached_entry in reversed(self.entries_with_name(entry.key)): + if entry.__eq__(cached_entry): + return cached_entry + return None def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets an entry by details. Will return None if there is - no matching entry.""" - entry = DNSEntry(name, type_, class_) - return self.get(entry) + """Gets the first matching entry by details. Returns None if no entries match.""" + return self.get(DNSEntry(name, type_, class_)) + + def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details.""" + match_entry = DNSEntry(name, type_, class_) + return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" diff --git a/zeroconf/test.py b/zeroconf/test.py index 64005b9b..edc6f55f 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -610,6 +610,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) assert dns_text is not None assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' + all_dns_text = zeroconf.cache.get_all_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert [dns_text] == all_dns_text # https://tools.ietf.org/html/rfc6762#section-10.2 # Instead of merging this new record additively into the cache in addition From 1b8b2917e7e70e3996e9a96204dd5df3dfb39072 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 26 May 2021 23:09:12 -0500 Subject: [PATCH 105/608] Small cleanup of ServiceInfo.update_record (#364) - Return as record is not viable (None or expired) - Switch checks to isinstance since its needed by mypy anyways - Prepares for supporting multiple AAAA records (via https://github.com/jstasiak/python-zeroconf/pull/361) --- zeroconf/__init__.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 98b16d93..ae4344a7 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1918,27 +1918,25 @@ def get_name(self) -> str: def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: """Updates service information from a DNS record""" - if record is not None and not record.is_expired(now): - if record.type in [_TYPE_A, _TYPE_AAAA]: - assert isinstance(record, DNSAddress) - if record.key == self.server_key: - if record.address not in self._addresses: - self._addresses.append(record.address) - elif record.type == _TYPE_SRV: - assert isinstance(record, DNSService) - if record.key == self.key: - self.name = record.name - self.server = record.server - self.server_key = record.server.lower() - self.port = record.port - self.weight = record.weight - self.priority = record.priority - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) - elif record.type == _TYPE_TXT: - assert isinstance(record, DNSText) - if record.key == self.key: - self._set_text(record.text) + if record is None or record.is_expired(now): + return + if isinstance(record, DNSAddress): + if record.key == self.server_key and record.address not in self._addresses: + self._addresses.append(record.address) + elif isinstance(record, DNSService): + if record.key != self.key: + return + self.name = record.name + self.server = record.server + self.server_key = record.server.lower() + self.port = record.port + self.weight = record.weight + self.priority = record.priority + self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) + self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) + elif isinstance(record, DNSText): + if record.key == self.key: + self._set_text(record.text) def load_from_cache(self, zc: 'Zeroconf') -> bool: """Populate the service info from the cache.""" From 6d29e6c93bdcf6cf31fcfa133258257704945dfc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 26 May 2021 23:28:02 -0500 Subject: [PATCH 106/608] Remove black python 3.5 exception block (#365) --- pyproject.toml | 2 +- zeroconf/__init__.py | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a5d30b54..b48e90ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [tool.black] line-length = 110 -target_version = ['py35', 'py36', 'py37'] +target_version = ['py35', 'py36', 'py37', 'py38'] skip_string_normalization = true diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ae4344a7..223d1569 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1762,9 +1762,6 @@ class ServiceInfo(RecordUpdateListener): text = b'' - # FIXME(dtantsur): black 19.3b0 produces code that is not valid syntax on - # Python 3.5: https://github.com/python/black/issues/759 - # fmt: off def __init__( self, type_: str, @@ -1795,11 +1792,12 @@ def __init__( else: self._addresses = [] # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses - if not isinstance(a, bytes) or len(a) not in (4, 16)] + invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] if invalid: - raise TypeError('Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid) + raise TypeError( + 'Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid + ) self.port = port self.weight = weight self.priority = priority @@ -1812,7 +1810,6 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl - # fmt: on @property def addresses(self) -> List[bytes]: From bae3a9b97672581e77255c4937b815173c8547b4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 26 May 2021 23:54:42 -0500 Subject: [PATCH 107/608] Ensure ServiceInfo populates all AAAA records (#366) - Use get_all_by_details to ensure all records are loaded into addresses. - Only load A/AAAA records from cache once in load_from_cache if there is a SRV record present - Move duplicate code that checked if the ServiceInfo was complete into its own function --- zeroconf/__init__.py | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 223d1569..fac15b0f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1929,27 +1929,38 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) self.port = record.port self.weight = record.weight self.priority = record.priority - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN)) - self.update_record(zc, now, zc.cache.get_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) + self._update_addresses_from_cache(zc, now) elif isinstance(record, DNSText): if record.key == self.key: self._set_text(record.text) + def _update_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: + """Update the address records from the cache.""" + cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) + if cached_a_record: + self.update_record(zc, now, cached_a_record) + for cached_aaaa_record in zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN): + self.update_record(zc, now, cached_aaaa_record) + def load_from_cache(self, zc: 'Zeroconf') -> bool: """Populate the service info from the cache.""" now = current_time_millis() - record_types_for_check_cache = [(_TYPE_SRV, _CLASS_IN), (_TYPE_TXT, _CLASS_IN)] - if self.server is not None: - record_types_for_check_cache.append((_TYPE_A, _CLASS_IN)) - record_types_for_check_cache.append((_TYPE_AAAA, _CLASS_IN)) - for record_type in record_types_for_check_cache: - cached = zc.cache.get_by_details(self.name, *record_type) - if cached: - self.update_record(zc, now, cached) - - if self.server is not None and self.text is not None and self._addresses: - return True - return False + cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + # If there is a srv record, A and AAAA will already + # be called and we do not want to do it twice + self.update_record(zc, now, cached_srv_record) + elif self.server is not None: + self._update_addresses_from_cache(zc, now) + cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + self.update_record(zc, now, cached_txt_record) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return not (self.server is None or self.text is None or not self._addresses) def request(self, zc: 'Zeroconf', timeout: float) -> bool: """Returns true if the service could be discovered on the @@ -1964,7 +1975,7 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: last = now + timeout try: zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) - while self.server is None or self.text is None or not self._addresses: + while not self._is_complete: if last <= now: return False if next_ <= now: From 5a4c1e46510956276de117d86bee9d2ccb602802 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 08:45:17 -0500 Subject: [PATCH 108/608] Fix empty answers being added in ServiceInfo.request (#367) --- zeroconf/__init__.py | 48 +++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index fac15b0f..1e195532 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1014,6 +1014,29 @@ def add_additional_answer(self, record: DNSRecord) -> None: """ self.additionals.append(record) + def add_question_or_one_cache( + self, zc: "Zeroconf", now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached.""" + cached_entry = zc.cache.get_by_details(name, type_, class_) + if not cached_entry: + self.add_question(DNSQuestion(name, type_, class_)) + else: + self.add_answer_at_time(cached_entry, now) + + def add_question_or_all_cache( + self, zc: "Zeroconf", now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached. + This is currently only used for IPv6 addresses. + """ + cached_entries = zc.cache.get_all_by_details(name, type_, class_) + if not cached_entries: + self.add_question(DNSQuestion(name, type_, class_)) + return + for cached_entry in cached_entries: + self.add_answer_at_time(cached_entry, now) + def pack(self, format_: Union[bytes, str], value: Any) -> None: self.data.append(struct.pack(format_, value)) self.size += struct.calcsize(format_) @@ -1974,30 +1997,19 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: next_ = now last = now + timeout try: - zc.add_listener(self, DNSQuestion(self.name, _TYPE_ANY, _CLASS_IN)) + # Do not set a question on the listener to preload from cache + # since we just checked it above in load_from_cache + zc.add_listener(self, None) while not self._is_complete: if last <= now: return False if next_ <= now: out = DNSOutgoing(_FLAGS_QR_QUERY) - cached_entry = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) - if not cached_entry: - out.add_question(DNSQuestion(self.name, _TYPE_SRV, _CLASS_IN)) - out.add_answer_at_time(cached_entry, now) - cached_entry = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) - if not cached_entry: - out.add_question(DNSQuestion(self.name, _TYPE_TXT, _CLASS_IN)) - out.add_answer_at_time(cached_entry, now) - + out.add_question_or_one_cache(zc, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc, now, self.name, _TYPE_TXT, _CLASS_IN) if self.server is not None: - cached_entry = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) - if not cached_entry: - out.add_question(DNSQuestion(self.server, _TYPE_A, _CLASS_IN)) - out.add_answer_at_time(cached_entry, now) - cached_entry = zc.cache.get_by_details(self.name, _TYPE_AAAA, _CLASS_IN) - if not cached_entry: - out.add_question(DNSQuestion(self.server, _TYPE_AAAA, _CLASS_IN)) - out.add_answer_at_time(cached_entry, now) + out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) zc.send(out) next_ = now + delay delay *= 2 From 4657a773690a34c897c80894a10ac33b6edadf8b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 09:03:32 -0500 Subject: [PATCH 109/608] Reduce complexity of ServiceBrowser enqueue_callback (#368) - The handler key was by name, however ServiceBrowser can have multiple types which meant the check to see if a state change was an add remove, or update was overly complex. Reduce the complexity by making the key (name, type_) --- zeroconf/__init__.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1e195532..b24ca332 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1574,7 +1574,7 @@ def __init__( ) -> None: """Creates a browser for a specific type""" assert handlers or listener, 'You need to specify at least one handler' - self.types = set(type_ if isinstance(type_, list) else [type_]) + self.types = set(type_ if isinstance(type_, list) else [type_]) # type: Set[str] for check_type_ in self.types: if not check_type_.endswith(service_type_name(check_type_, strict=False)): raise BadTypeInNameException @@ -1590,7 +1590,7 @@ def __init__( current_time = current_time_millis() self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} - self._handlers_to_call = OrderedDict() # type: OrderedDict[str, Tuple[str, ServiceStateChange]] + self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] self._service_state_changed = Signal() @@ -1655,20 +1655,16 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> # Code to ensure we only do a single update message # Precedence is; Added, Remove, Update - + key = (name, type_) if ( state_change is ServiceStateChange.Added or ( state_change is ServiceStateChange.Removed - and ( - self._handlers_to_call.get(name) is ServiceStateChange.Updated - or self._handlers_to_call.get(name) is ServiceStateChange.Added - or self._handlers_to_call.get(name) is None - ) + and self._handlers_to_call.get(key) != ServiceStateChange.Added ) - or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call) + or (state_change is ServiceStateChange.Updated and key not in self._handlers_to_call) ): - self._handlers_to_call[name] = (type_, state_change) + self._handlers_to_call[key] = state_change if record.type == _TYPE_PTR and record.name in self.types: assert isinstance(record, DNSPointer) @@ -1753,12 +1749,12 @@ def run(self) -> None: if len(self._handlers_to_call) > 0 and not self.zc.done: with self.zc._handlers_lock: - (name, service_type_state_change) = self._handlers_to_call.popitem(False) + (name_type, state_change) = self._handlers_to_call.popitem(False) self._service_state_changed.fire( zeroconf=self.zc, - service_type=service_type_state_change[0], - name=name, - state_change=service_type_state_change[1], + service_type=name_type[1], + name=name_type[0], + state_change=state_change, ) From 4819ef8c97ddbbadcd6e7cf1b5fee36f573bde45 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 09:22:15 -0500 Subject: [PATCH 110/608] Abstract check to see if a record matches a type the ServiceBrowser wants (#369) --- zeroconf/__init__.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b24ca332..5559f9db 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1642,6 +1642,10 @@ def on_change( def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface + def _record_matching_type(self, record: DNSRecord) -> Optional[str]: + """Return the type if the record matches one of the types we are browsing.""" + return next((type_ for type_ in self.types if record.name.endswith(type_)), None) + def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: """Callback invoked by Zeroconf when new information arrives. @@ -1707,14 +1711,14 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> # Iterate through the DNSCache and callback any services that use this address for service in self.zc.cache.entries_with_server(record.name): - for type_ in self.types: - if service.name.endswith(type_): - enqueue_callback(ServiceStateChange.Updated, type_, service.name) + type_ = self._record_matching_type(service) + if type_: + enqueue_callback(ServiceStateChange.Updated, type_, service.name) elif not record.is_expired(now): - for type_ in self.types: - if record.name.endswith(type_): - enqueue_callback(ServiceStateChange.Updated, type_, record.name) + type_ = self._record_matching_type(record) + if type_: + enqueue_callback(ServiceStateChange.Updated, type_, record.name) def cancel(self) -> None: self.done = True From 7f45bef8db444b0436c5f80b4f4b31b2f1d7ec2f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 09:34:29 -0500 Subject: [PATCH 111/608] Remove Callable quoting (#371) - The current minimum supported cpython is 3.6+ which does not need the quoting --- zeroconf/__init__.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5559f9db..83431319 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1521,17 +1521,15 @@ def registration_interface(self) -> 'SignalRegistrationInterface': return SignalRegistrationInterface(self._handlers) -# NOTE: Callable quoting needed on Python 3.5.2, see -# https://github.com/jstasiak/python-zeroconf/issues/208 for details. class SignalRegistrationInterface: - def __init__(self, handlers: List['Callable[..., None]']) -> None: + def __init__(self, handlers: List[Callable[..., None]]) -> None: self._handlers = handlers - def register_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface': + def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': self._handlers.append(handler) return self - def unregister_handler(self, handler: 'Callable[..., None]') -> 'SignalRegistrationInterface': + def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': self._handlers.remove(handler) return self @@ -1564,9 +1562,7 @@ def __init__( self, zc: 'Zeroconf', type_: Union[str, list], - # NOTE: Callable quoting needed on Python 3.5.2, see - # https://github.com/jstasiak/python-zeroconf/issues/208 for details. - handlers: Optional[Union[ServiceListener, List['Callable[..., None]']]] = None, + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, listener: Optional[ServiceListener] = None, addr: Optional[str] = None, port: int = _MDNS_PORT, @@ -1600,9 +1596,7 @@ def __init__( listener = cast(ServiceListener, handlers) handlers = None - # NOTE: Callable quoting needed on Python 3.5.2, see - # https://github.com/jstasiak/python-zeroconf/issues/208 for details. - handlers = cast(List['Callable[..., None]'], handlers or []) + handlers = cast(List[Callable[..., None]], handlers or []) if listener: From 82fb26f14518a8e59f886b8d7b0708a68725bf48 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 09:39:59 -0500 Subject: [PATCH 112/608] Update changelog for 0.32.0 (unreleased) (#372) --- README.rst | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/README.rst b/README.rst index 956eab05..56a5aff9 100644 --- a/README.rst +++ b/README.rst @@ -134,6 +134,68 @@ See examples directory for more. Changelog ========= +0.32.0 (Unreleased) +=================== + +* Remove Callable quoting (#371) @bdraco + +* Abstract check to see if a record matches a type the ServiceBrowser wants (#369) @bdraco + +* Reduce complexity of ServiceBrowser enqueue_callback (#368) @bdraco + +* Fix empty answers being added in ServiceInfo.request (#367) @bdraco + +* Ensure ServiceInfo populates all AAAA records (#366) @bdraco + + Use get_all_by_details to ensure all records are loaded + into addresses. + + Only load A/AAAA records from cache once in load_from_cache + if there is a SRV record present + + Move duplicate code that checked if the ServiceInfo was complete + into its own function + +* Remove black python 3.5 exception block (#365) @bdraco + +* Small cleanup of ServiceInfo.update_record (#364) @bdraco + +* Add new cache function get_all_by_details (#363) @bdraco + When working with IPv6, multiple AAAA records can exist + for a given host. get_by_details would only return the + latest record in the cache. + + Fix a case where the cache list can change during + iteration + +* Small cleanups to asyncio tests (#362) @bdraco + +* Improve test coverage for name conflicts (#357) @bdraco + +* Return task objects created by AsyncZeroconf (#360) @nocarryr + +0.31.0 +====== + +* Separated cache loading from I/O in ServiceInfo and fixed cache lookup (#356), + thanks to J. Nick Koston. + + The ServiceInfo class gained a load_from_cache() method to only fetch information + from Zeroconf cache (if it exists) with no IO performed. Additionally this should + reduce IO in cases where cache lookups were previously incorrectly failing. + +0.30.0 +====== + +* Some nice refactoring work including removal of the Reaper thread, + thanks to J. Nick Koston. + +* Fixed a Windows-specific The requested address is not valid in its context regression, + thanks to Timothee ‘TTimo’ Besset and J. Nick Koston. + +* Provided an asyncio-compatible service registration layer (in the zeroconf.asyncio module), + thanks to J. Nick Koston. + 0.29.0 ====== From 5d4aa2800d1196274cfdd0bf3e631f49ab5b78bd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 09:57:34 -0500 Subject: [PATCH 113/608] Reduce length of ServiceBrowser thread name with many types (#373) - Before "zeroconf-ServiceBrowser__ssh._tcp.local.-_enphase-envoy._tcp.local.-_hap._udp.local." "-_nut._tcp.local.-_Volumio._tcp.local.-_kizbox._tcp.local.-_home-assistant._tcp.local." "-_viziocast._tcp.local.-_dvl-deviceapi._tcp.local.-_ipp._tcp.local.-_touch-able._tcp.local." "-_hap._tcp.local.-_system-bridge._udp.local.-_dkapi._tcp.local.-_airplay._tcp.local." "-_elg._tcp.local.-_miio._udp.local.-_wled._tcp.local.-_esphomelib._tcp.local." "-_ipps._tcp.local.-_fbx-api._tcp.local.-_xbmc-jsonrpc-h._tcp.local.-_powerview._tcp.local." "-_spotify-connect._tcp.local.-_leap._tcp.local.-_api._udp.local.-_plugwise._tcp.local." "-_googlecast._tcp.local.-_printer._tcp.local.-_axis-video._tcp.local.-_http._tcp.local." "-_mediaremotetv._tcp.local.-_homekit._tcp.local.-_bond._tcp.local.-_daap._tcp.local._243" - After "zeroconf-ServiceBrowser-_miio._udp-_mediaremotetv._tcp-_dvl-deviceapi._tcp-_ipp._tcp" "-_dkapi._tcp-_hap._udp-_xbmc-jsonrpc-h._tcp-_hap._tcp-_googlecast._tcp-_airplay._tcp" "-_viziocast._tcp-_api._udp-_kizbox._tcp-_spotify-connect._tcp-_home-assistant._tcp" "-_bond._tcp-_powerview._tcp-_daap._tcp-_http._tcp-_leap._tcp-_elg._tcp-_homekit._tcp" "-_ipps._tcp-_plugwise._tcp-_ssh._tcp-_esphomelib._tcp-_Volumio._tcp-_fbx-api._tcp" "-_wled._tcp-_touch-able._tcp-_enphase-envoy._tcp-_axis-video._tcp-_printer._tcp" "-_system-bridge._udp-_nut._tcp-244" --- zeroconf/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 83431319..815300b9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1627,8 +1627,8 @@ def on_change( self.service_state_changed.register_handler(h) self.start() - self.name = "zeroconf-ServiceBrowser_%s_%s" % ( - '-'.join(self.types), + self.name = "zeroconf-ServiceBrowser-%s-%s" % ( + '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), ) From 03f2eb688859a78807305771d04b216e20e72064 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 10:16:19 -0500 Subject: [PATCH 114/608] Fix RFC6762 Section 10.2 paragraph 2 compliance (#374) --- zeroconf/__init__.py | 16 ++++++---------- zeroconf/test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 815300b9..3d3fefc1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2761,18 +2761,14 @@ def handle_response(self, msg: DNSIncoming) -> None: updated = True if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # Since the cache format is keyed on the lower case record name - # we can avoid iterating everything in the cache and - # only look though entries for the specific name. - # entries_with_name will take care of converting to lowercase - for entry in self.cache.entries_with_name(record.name): - + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): if entry == record: updated = False - - # Check the time first because it is far cheaper - # than the __eq__ - if (record.created - entry.created > 1000) and DNSEntry.__eq__(entry, record): + if record.created - entry.created > 1000 and entry not in msg.answers: self.cache.remove(entry) expired = record.is_expired(now) diff --git a/zeroconf/test.py b/zeroconf/test.py index edc6f55f..28c7482b 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -596,6 +596,28 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi return r.DNSIncoming(generated.packet()) + def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + """Mock an incoming message for the case where the packet is split.""" + ttl = 120 + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSService( + service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + ), + 0, + ) + return r.DNSIncoming(generated.packet()) + service_name = 'name._type._tcp.local.' service_type = '_type._tcp.local.' service_server = 'ash-2.local.' @@ -630,6 +652,14 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi time.sleep(1.1) + # The split message only has a SRV and A record. + # This should not evict TXT records from the cache + zeroconf.handle_response(mock_split_incoming_msg(r.ServiceStateChange.Updated)) + time.sleep(1.1) + dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert dns_text is not None + assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' + # service removed zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) From 51337425c9be08d59d496c6783d07d5e4e2382d4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 11:08:21 -0500 Subject: [PATCH 115/608] Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) --- zeroconf/__init__.py | 66 ++++++++++++++++++++++++++++++-------------- zeroconf/test.py | 22 ++++++++++++++- 2 files changed, 66 insertions(+), 22 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3d3fefc1..25d62814 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1691,16 +1691,12 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> if record.is_expired(now): return - address_changed = False - for service in zc.cache.entries_with_name(record.name): - if isinstance(service, DNSAddress) and service.address != record.address: - address_changed = True - break - - # Avoid iterating the entire DNSCache if the address has not changed - # as this is an expensive operation when there many hosts - # generating zeroconf traffic. - if not address_changed: + # Only trigger an updated event if the address is new + if record.address in set( + service.address + for service in zc.cache.entries_with_name(record.name) + if isinstance(service, DNSAddress) + ): return # Iterate through the DNSCache and callback any services that use this address @@ -2754,7 +2750,10 @@ def update_record(self, now: float, rec: DNSRecord) -> None: def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" - updates = [] # type: List[Tuple[float, DNSRecord, Optional[DNSRecord]]] + updates = [] # type: List[DNSRecord] + address_adds = [] # type: List[DNSAddress] + other_adds = [] # type: List[DNSRecord] + removes = [] # type: List[DNSRecord] now = current_time_millis() for record in msg.answers: @@ -2769,7 +2768,7 @@ def handle_response(self, msg: DNSIncoming) -> None: if entry == record: updated = False if record.created - entry.created > 1000 and entry not in msg.answers: - self.cache.remove(entry) + removes.append(entry) expired = record.is_expired(now) maybe_entry = self.cache.get(record) @@ -2777,22 +2776,47 @@ def handle_response(self, msg: DNSIncoming) -> None: if maybe_entry is not None: maybe_entry.reset_ttl(record) else: - self.cache.add(record) + if isinstance(record, DNSAddress): + address_adds.append(record) + else: + other_adds.append(record) if updated: - updates.append((now, record, None)) + updates.append(record) elif maybe_entry is not None: - updates.append((now, record, maybe_entry)) + updates.append(record) + removes.append(record) - if not updates: + if not updates and not address_adds and not other_adds and not removes: return # Only hold the lock if we have updates with self._handlers_lock: - for update in updates: - now, record, entry_to_remove = update - self.update_record(update[0], update[1]) - if entry_to_remove: - self.cache.remove(entry_to_remove) + for record in updates: + self.update_record(now, record) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the lock to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + for record in address_adds: + self.cache.add(record) + for record in other_adds: + self.cache.add(record) + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. + for record in removes: + self.cache.remove(record) def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if diff --git a/zeroconf/test.py b/zeroconf/test.py index 28c7482b..a3a3e7ad 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1333,12 +1333,14 @@ def update_service(self, zeroconf, type, name): class TestServiceBrowser(unittest.TestCase): def test_update_record(self): + enable_ipv6 = socket.has_ipv6 and not os.environ.get('SKIP_IPV6') service_name = 'name._type._tcp.local.' service_type = '_type._tcp.local.' service_server = 'ash-1.local.' service_text = b'path=/~matt1/' service_address = '10.0.1.2' + service_v6_address = "2001:db8::1" service_added_count = 0 service_removed_count = 0 @@ -1362,7 +1364,11 @@ def update_service(self, zc, type_, name) -> None: nonlocal service_updated_count service_updated_count += 1 service_info = zc.get_service_info(type_, name) - assert service_info.addresses[0] == socket.inet_aton(service_address) + assert socket.inet_aton(service_address) in service_info.addresses + if enable_ipv6: + assert socket.inet_pton( + socket.AF_INET6, service_v6_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) assert service_info.text == service_text assert service_info.server == service_server service_updated_event.set() @@ -1387,6 +1393,20 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi 0, ) + # Send the IPv6 address first since we previously + # had a bug where the IPv4 would be missing if the + # IPv6 was seen first + if enable_ipv6: + generated.add_answer_at_time( + r.DNSAddress( + service_server, + r._TYPE_AAAA, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_address), + ), + 0, + ) generated.add_answer_at_time( r.DNSAddress( service_server, From b158b1cff31620d5cf27969e475d788332f4b38c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 11:28:44 -0500 Subject: [PATCH 116/608] Ensure duplicate packets do not trigger duplicate updates (#376) - If TXT or SRV records update was already processed and then recieved again, it was possible for a second update to be called back in the ServiceBrowser --- zeroconf/__init__.py | 27 ++++++++++++++++----------- zeroconf/test.py | 9 +++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 25d62814..e2c291e9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1664,9 +1664,11 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> ): self._handlers_to_call[key] = state_change - if record.type == _TYPE_PTR and record.name in self.types: - assert isinstance(record, DNSPointer) - expired = record.is_expired(now) + expired = record.is_expired(now) + + if isinstance(record, DNSPointer): + if record.name not in self.types: + return service_key = record.alias.lower() try: old_record = self._services[record.name][service_key] @@ -1685,12 +1687,13 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) if expires < self._next_time[record.name]: self._next_time[record.name] = expires + return - elif record.type == _TYPE_A or record.type == _TYPE_AAAA: - assert isinstance(record, DNSAddress) - if record.is_expired(now): - return + # If its expired or already exists in the cache it cannot be updated. + if expired or self.zc.cache.get(record): + return + if isinstance(record, DNSAddress): # Only trigger an updated event if the address is new if record.address in set( service.address @@ -1704,11 +1707,13 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> type_ = self._record_matching_type(service) if type_: enqueue_callback(ServiceStateChange.Updated, type_, service.name) + break + + return - elif not record.is_expired(now): - type_ = self._record_matching_type(record) - if type_: - enqueue_callback(ServiceStateChange.Updated, type_, record.name) + type_ = self._record_matching_type(record) + if type_: + enqueue_callback(ServiceStateChange.Updated, type_, record.name) def cancel(self) -> None: self.done = True diff --git a/zeroconf/test.py b/zeroconf/test.py index a3a3e7ad..66ee453a 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1455,6 +1455,15 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi assert service_updated_count == 2 assert service_removed_count == 0 + # service TXT updated - duplicate update should not trigger another service_updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + # service A updated service_updated_event.clear() service_address = '10.0.1.3' From 5535ea8c365557681721fdafdcabfc342c75daf5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 11:35:33 -0500 Subject: [PATCH 117/608] Update changelog with latest merges (#377) --- README.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.rst b/README.rst index 56a5aff9..3d6615e8 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,18 @@ Changelog 0.32.0 (Unreleased) =================== +* Ensure duplicate packets do not trigger duplicate updates (#376) + + If TXT or SRV records update was already processed and then + recieved again, it was possible for a second update to be + called back in the ServiceBrowser + +* Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) + +* Fix RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco + +* Reduce length of ServiceBrowser thread name with many types (#373) @bdraco + * Remove Callable quoting (#371) @bdraco * Abstract check to see if a record matches a type the ServiceBrowser wants (#369) @bdraco From 23442d2e5a0336a64646cb70f2ce389746744ce0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 11:36:40 -0500 Subject: [PATCH 118/608] Bump version to 0.31.0 to match released version (#378) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e2c291e9..1484ec66 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -41,7 +41,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.29.0' +__version__ = '0.31.0' __license__ = 'LGPL' From 60c1895e67a6147ab8c6ba7d21d4fe5adec3e590 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 11:58:40 -0500 Subject: [PATCH 119/608] Coalesce browser questions scheduled at the same time (#379) - With multiple types, the ServiceBrowser questions can be chatty because it would generate a question packet for each type. If multiple types are due to be requested, try to combine the questions into a single outgoing packet(s) --- zeroconf/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1484ec66..ed3cfd01 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1733,19 +1733,22 @@ def run(self) -> None: if self.zc.done or self.done: return now = current_time_millis() + out = None for type_ in self.types: if self._next_time[type_] > now: continue - out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) + if not out: + out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) for record in self._services[type_].values(): if not record.is_stale(now): out.add_answer_at_time(record, now) - - self.zc.send(out, addr=self.addr, port=self.port) self._next_time[type_] = now + self._delay[type_] self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + if out: + self.zc.send(out, addr=self.addr, port=self.port) + if len(self._handlers_to_call) > 0 and not self.zc.done: with self.zc._handlers_lock: (name_type, state_change) = self._handlers_to_call.popitem(False) From 3afa5c13f2be956505428c5b01f6ce507845131a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 12:16:09 -0500 Subject: [PATCH 120/608] Complete ServiceInfo request as soon as all questions are answered (#380) - Closes a small race condition where there were no questions to ask because the cache was populated in between checks --- zeroconf/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ed3cfd01..840ed423 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2008,6 +2008,8 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if self.server is not None: out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) + if not out.questions: + return True zc.send(out) next_ = now + delay delay *= 2 From 2b502bc2e21efa2f840c42ed79f850b276a8c103 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 12:18:07 -0500 Subject: [PATCH 121/608] Update changelog with latest merges (#381) --- README.rst | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 3d6615e8..8a8e54b1 100644 --- a/README.rst +++ b/README.rst @@ -137,13 +137,20 @@ Changelog 0.32.0 (Unreleased) =================== -* Ensure duplicate packets do not trigger duplicate updates (#376) +* Complete ServiceInfo request as soon as all questions are answered (#380) @bdraco + + Closes a small race condition where there were no questions + to ask because the cache was populated in between checks + +* Coalesce browser questions scheduled at the same time (#379) @bdraco + +* Ensure duplicate packets do not trigger duplicate updates (#376) @bdraco If TXT or SRV records update was already processed and then recieved again, it was possible for a second update to be called back in the ServiceBrowser -* Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) +* Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) @bdraco * Fix RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco From 69a79b9fd48a24d311520e228c78b2aae52d1dd5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 13:27:42 -0500 Subject: [PATCH 122/608] Fix multiple unclosed instances in tests (#383) --- zeroconf/test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index 66ee453a..899cf24f 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -37,6 +37,15 @@ original_logging_level = logging.NOTSET +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads + + def setup_module(): global original_logging_level original_logging_level = log.level @@ -924,6 +933,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): zc.unregister_service(info) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 + zc.close() def test_name_conflicts(self): # instantiate a zeroconf instance @@ -952,6 +962,7 @@ def test_name_conflicts(self): ) with pytest.raises(r.NonUniqueNameException): zc.register_service(conflicting_info) + zc.close() class TestServiceRegistry(unittest.TestCase): @@ -1598,6 +1609,7 @@ def test_service_info_rejects_non_matching_updates(self): ), ) assert new_address not in info.addresses + zc.close() def test_get_info_partial(self): @@ -2188,6 +2200,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): # unregister zc.unregister_service(info) + zc.close() def test_dns_compression_rollback_for_corruption(): From 5057f97b9b724c041d2bee65972fe3637bf04f0b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 14:06:52 -0500 Subject: [PATCH 123/608] Ensure the cache is checked for name conflict after final service query with asyncio (#382) - The check was not happening after the last query --- zeroconf/asyncio.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 4ba7f312..a07acb9b 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -106,13 +106,17 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b instance_name_from_service_info(info) if cooperating_responders: return + self._raise_on_name_conflict(info) for i in range(3): - # check for a name conflict - if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): - raise NonUniqueNameException if i != 0: await asyncio.sleep(_CHECK_TIME / 1000) await self.loop.run_in_executor(None, self.zeroconf.send_service_query, info) + self._raise_on_name_conflict(info) + + def _raise_on_name_conflict(self, info: ServiceInfo) -> None: + """Raise NonUniqueNameException if the ServiceInfo has a conflict.""" + if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): + raise NonUniqueNameException async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service. From 69d9357b3dae7a99d302bf4ad71d4ed45cbe3e42 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 14:19:35 -0500 Subject: [PATCH 124/608] Update changelog with latest commits (#384) --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index 8a8e54b1..a2c9b1f8 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,8 @@ Changelog 0.32.0 (Unreleased) =================== +* Ensure the cache is checked for name conflict after final service query with asyncio (#382) @bdraco + * Complete ServiceInfo request as soon as all questions are answered (#380) @bdraco Closes a small race condition where there were no questions From 62a02d774fd874340fa3043bd3bf260a77ffe3d8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 22:17:36 -0500 Subject: [PATCH 125/608] Ensure listeners do not miss initial packets if Engine starts too quickly (#387) --- zeroconf/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 840ed423..beca9f86 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1373,7 +1373,6 @@ def __init__(self, zc: 'Zeroconf') -> None: self.condition = threading.Condition() self.socketpair = socket.socketpair() self._last_cache_cleanup = 0.0 - self.start() self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) def run(self) -> None: @@ -2539,6 +2538,10 @@ def __init__( if self.multi_socket: for s in self._respond_sockets: self.engine.add_reader(self.listener, s) + # Start the engine only after all + # the readers have been added to avoid + # missing any packets that are on the wire + self.engine.start() @property def done(self) -> bool: From 709bd9abae63cf566220693501cd37cf74391ccf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 22:28:51 -0500 Subject: [PATCH 126/608] Simplify DNSPointer processing in ServiceBrowser (#386) --- zeroconf/__init__.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index beca9f86..dbb3896c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1669,23 +1669,19 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> if record.name not in self.types: return service_key = record.alias.lower() - try: - old_record = self._services[record.name][service_key] - except KeyError: - if not expired: - self._services[record.name][service_key] = record - enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + services_by_type = self._services[record.name] + old_record = services_by_type.get(service_key) + if old_record is None: + services_by_type[service_key] = record + enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + elif expired: + del services_by_type[service_key] + enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: - if not expired: - old_record.reset_ttl(record) - else: - del self._services[record.name][service_key] - enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) - return - - expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) - if expires < self._next_time[record.name]: - self._next_time[record.name] = expires + old_record.reset_ttl(record) + expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) + if expires < self._next_time[record.name]: + self._next_time[record.name] = expires return # If its expired or already exists in the cache it cannot be updated. From ba8d8e3e658c71e0d603db3f4c5bdfe8e508710a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 22:47:09 -0500 Subject: [PATCH 127/608] Fix flapping test: test_update_record (#388) --- zeroconf/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 899cf24f..de95a0cf 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1898,7 +1898,7 @@ def _mock_get_expiration_time(self, percent): zeroconf.handle_response( mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) ) - service_add_event.wait(wait_time) + service_add_event.wait(wait_time) assert called_with_refresh_time_check is True assert service_added_count == 3 assert service_removed_count == 0 From 8f4d2e858a5efadeb33120322c1169f3ce7d6e0c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 22:56:14 -0500 Subject: [PATCH 128/608] Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) --- zeroconf/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dbb3896c..4bb565bc 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2083,11 +2083,11 @@ def find( # wait for responses time.sleep(timeout) + browser.cancel() + # close down anything we opened if zc is None: local_zc.close() - else: - browser.cancel() return tuple(sorted(listener.found_services)) From 33a3a6ae42ef8c4ea0f606ad2a02df3f6bc13752 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 27 May 2021 22:59:31 -0500 Subject: [PATCH 129/608] Update changelog for 0.32.0 (Unreleased) (#390) --- README.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.rst b/README.rst index a2c9b1f8..ffdeb0c0 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,23 @@ Changelog 0.32.0 (Unreleased) =================== +* Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco + + There was a short window where the ServiceBrowser thread + could be left running after Zeroconf is closed because + the .join() was never waited for when a new Zeroconf + object was created + +* Simplify DNSPointer processing in ServiceBrowser (#386) @bdraco + +* Breaking change: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco + + When manually creating a zeroconf.Engine object, it is no longer started automatically. + It must manually be started by calling .start() on the created object. + + The Engine thread is now started after all the listeners have been added to avoid a + race condition where packets could be missed at startup. + * Ensure the cache is checked for name conflict after final service query with asyncio (#382) @bdraco * Complete ServiceInfo request as soon as all questions are answered (#380) @bdraco From d67d5f41effff4c01735de0ae64ed25a5dbe7567 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 28 May 2021 10:27:00 -0500 Subject: [PATCH 130/608] Fix IPv6 setup under MacOS when binding to "" (#392) - Setting IP_MULTICAST_TTL and IP_MULTICAST_LOOP does not work under MacOS when the bind address is "" --- zeroconf/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 4bb565bc..e90e90fa 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2235,8 +2235,12 @@ def new_socket( if ip_version != IPVersion.V6Only: # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and # IP_MULTICAST_LOOP socket options as an unsigned char. - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + try: + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + except socket.error as e: + if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS + raise if ip_version != IPVersion.V4Only: # However, char doesn't work here (at least on Linux) s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) From ec2fafd904cd2d341a3815fcf6d34508dcddda5a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 28 May 2021 10:40:59 -0500 Subject: [PATCH 131/608] Enable IPv6 in the CI (#393) --- .github/workflows/ci.yml | 2 -- zeroconf/test.py | 39 ++++++++++++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86cc95a7..9c6f98f1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,5 @@ jobs: pip install . - name: Run tests run: make ci - env: - SKIP_IPV6: 1 - name: Report coverage to Codecov uses: codecov/codecov-action@v1 diff --git a/zeroconf/test.py b/zeroconf/test.py index de95a0cf..b5e96500 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -16,9 +16,12 @@ import time import unittest import unittest.mock +from functools import lru_cache from threading import Event from typing import Dict, Optional, cast # noqa # used in type hints +import ifaddr + import pytest import zeroconf as r @@ -57,6 +60,28 @@ def teardown_module(): log.setLevel(original_logging_level) +@lru_cache(maxsize=None) +def has_working_ipv6(): + """Return True if if the system can bind an IPv6 address.""" + if not socket.has_ipv6: + return False + + try: + sock = socket.socket(socket.AF_INET6) + sock.bind(('::1', 0)) + except Exception: + return False + finally: + if sock: + sock.close() + + for iface in ifaddr.get_adapters(): + for addr in iface.ips: + if addr.is_IPv6 and iface.index is not None: + return True + return False + + class TestDunder(unittest.TestCase): def test_dns_text_repr(self): # There was an issue on Python 3 that prevented DNSText's repr @@ -550,7 +575,7 @@ def test_close_multiple_times(self): rv.close() rv.close() - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_launch_and_close_v4_v6(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) @@ -558,7 +583,7 @@ def test_launch_and_close_v4_v6(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All) rv.close() - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_launch_and_close_v6_only(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) @@ -1092,7 +1117,7 @@ def test_integration_with_listener(self): finally: zeroconf_registrar.close() - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_v6_records(self): @@ -1124,7 +1149,7 @@ def test_integration_with_listener_v6_records(self): finally: zeroconf_registrar.close() - @unittest.skipIf(not socket.has_ipv6, 'Requires IPv6') + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_ipv6(self): @@ -1240,7 +1265,7 @@ def update_service(self, zeroconf, type, name): desc = {'path': '/~paulsm/'} # type: Dict desc.update(properties) addresses = [socket.inet_aton("10.0.1.2")] - if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) info_service = ServiceInfo( subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses @@ -1344,7 +1369,7 @@ def update_service(self, zeroconf, type, name): class TestServiceBrowser(unittest.TestCase): def test_update_record(self): - enable_ipv6 = socket.has_ipv6 and not os.environ.get('SKIP_IPV6') + enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') service_name = 'name._type._tcp.local.' service_type = '_type._tcp.local.' @@ -2110,7 +2135,7 @@ def test_multiple_addresses(): ) assert info.addresses == [address, address] - if socket.has_ipv6 and not os.environ.get('SKIP_IPV6'): + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) infos = [ From acf174db93ee60f1a80d501eb691d9cb434a90b7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 28 May 2021 10:50:08 -0500 Subject: [PATCH 132/608] Add test coverage for multiple AAAA records (#391) --- zeroconf/test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index b5e96500..c1c0b775 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1266,6 +1266,7 @@ def update_service(self, zeroconf, type, name): desc.update(properties) addresses = [socket.inet_aton("10.0.1.2")] if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): + addresses.append(socket.inet_pton(socket.AF_INET6, "6001:db8::1")) addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) info_service = ServiceInfo( subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses @@ -1377,6 +1378,7 @@ def test_update_record(self): service_text = b'path=/~matt1/' service_address = '10.0.1.2' service_v6_address = "2001:db8::1" + service_v6_second_address = "6001:db8::1" service_added_count = 0 service_removed_count = 0 @@ -1405,6 +1407,9 @@ def update_service(self, zc, type_, name) -> None: assert socket.inet_pton( socket.AF_INET6, service_v6_address ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert socket.inet_pton( + socket.AF_INET6, service_v6_second_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) assert service_info.text == service_text assert service_info.server == service_server service_updated_event.set() @@ -1443,6 +1448,16 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ), 0, ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + r._TYPE_AAAA, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_second_address), + ), + 0, + ) generated.add_answer_at_time( r.DNSAddress( service_server, From a6010a94b626a9a1585cc47417c08516020729d7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 28 May 2021 10:52:12 -0500 Subject: [PATCH 133/608] Update changelog with latest changes (#394) --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index ffdeb0c0..40b9ef4d 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,8 @@ Changelog 0.32.0 (Unreleased) =================== +* Fix IPv6 setup under MacOS when binding to "" (#392) @bdraco + * Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco There was a short window where the ServiceBrowser thread From dd6383589b161e828def0ed029519a645e434512 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 07:46:15 -0500 Subject: [PATCH 134/608] Remove unreachable code in ServiceInfo (#400) - self.server is never None --- zeroconf/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e90e90fa..3d4e1856 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1966,7 +1966,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: # If there is a srv record, A and AAAA will already # be called and we do not want to do it twice self.update_record(zc, now, cached_srv_record) - elif self.server is not None: + else: self._update_addresses_from_cache(zc, now) cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: @@ -1976,7 +1976,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: @property def _is_complete(self) -> bool: """The ServiceInfo has all expected properties.""" - return not (self.server is None or self.text is None or not self._addresses) + return not (self.text is None or not self._addresses) def request(self, zc: 'Zeroconf', timeout: float) -> bool: """Returns true if the service could be discovered on the From 4ae27beba29c6e9ac1782f40eadda584b4722af7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 08:08:15 -0500 Subject: [PATCH 135/608] Remove unreachable code in ServiceInfo (part 2) (#402) - self.server is never None --- zeroconf/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3d4e1856..266c5f68 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2000,9 +2000,8 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: out = DNSOutgoing(_FLAGS_QR_QUERY) out.add_question_or_one_cache(zc, now, self.name, _TYPE_SRV, _CLASS_IN) out.add_question_or_one_cache(zc, now, self.name, _TYPE_TXT, _CLASS_IN) - if self.server is not None: - out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) + out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) if not out.questions: return True zc.send(out) From bddf69c0839eda966376987a8c4a1fbe3d865529 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 08:21:18 -0500 Subject: [PATCH 136/608] Seperate query generation in ServiceInfo (#401) --- zeroconf/__init__.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 266c5f68..60b3e333 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1997,11 +1997,7 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if last <= now: return False if next_ <= now: - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) + out = self.generate_request_query(zc, now) if not out.questions: return True zc.send(out) @@ -2015,6 +2011,15 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: return True + def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question_or_one_cache(zc, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) + return out + def __eq__(self, other: object) -> bool: """Tests equality of service name""" return isinstance(other, ServiceInfo) and other.name == self.name From e753078f0345fa28ffceb8de69542c8549d2994c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 08:41:31 -0500 Subject: [PATCH 137/608] Seperate query generation for Zeroconf (#403) - Will be used to send the query in asyncio --- zeroconf/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 60b3e333..ae487be6 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2636,16 +2636,24 @@ def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: """Send a broadcast to announce a service.""" + self.send(self.generate_service_broadcast(info, ttl)) + + def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: + """Generate a broadcast to announce a service.""" out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) self._add_broadcast_answer(out, info, ttl) - self.send(out) + return out def send_service_query(self, info: ServiceInfo) -> None: """Send a query to lookup a service.""" + self.send(self.generate_service_query(info)) + + def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: + """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name)) - self.send(out) + return out def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]) -> None: """Add answers to broadcast a service.""" From 1e7b46c36f6e0735b44d3edd9740891a2dc0c761 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 10:34:48 -0500 Subject: [PATCH 138/608] Use a dedicated thread for sending outgoing packets with asyncio (#404) - Sends now go into a queue and are processed by the thread FIFO - Avoids overwhelming the executor when registering multiple services in parallel --- zeroconf/asyncio.py | 57 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index a07acb9b..ac8de3c2 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -20,9 +20,12 @@ USA """ import asyncio +import queue +import threading from typing import Awaitable, Optional from . import ( + DNSOutgoing, IPVersion, InterfaceChoice, InterfacesType, @@ -30,12 +33,48 @@ ServiceInfo, Zeroconf, _CHECK_TIME, + _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME, instance_name_from_service_info, ) +class _AsyncSender(threading.Thread): + """A thread to handle sending DNSOutgoing for asyncio.""" + + def __init__(self, zc: 'Zeroconf'): + """Create the sender thread.""" + super().__init__() + self.zc = zc + self.queue = self._get_queue() + self.start() + self.name = "AsyncZeroconfSender" + + def _get_queue(self) -> queue.Queue: + """Create the best available queue type.""" + if hasattr(queue, "SimpleQueue"): + return queue.SimpleQueue() # type: ignore + return queue.Queue() + + def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + """Queue a send to be processed by the thread.""" + self.queue.put((out, addr, port)) + + def close(self) -> None: + """Close the instance.""" + self.queue.put(None) + self.join() + + def run(self) -> None: + """Runner that processes sends FIFO.""" + while True: + event = self.queue.get() + if event is None: + return + self.zc.send(*event) + + class AsyncZeroconf: """Implementation of Zeroconf Multicast DNS Service Discovery @@ -73,6 +112,7 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) + self.sender = _AsyncSender(self.zeroconf) self.loop = asyncio.get_event_loop() async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: @@ -80,7 +120,7 @@ async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: for i in range(3): if i != 0: await asyncio.sleep(interval / 1000) - await self.loop.run_in_executor(None, self.zeroconf.send_service_broadcast, info, ttl) + self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl)) async def async_register_service( self, @@ -98,7 +138,7 @@ async def async_register_service( and therefore can be awaited if necessary. """ await self.async_check_service(info, cooperating_responders) - await self.loop.run_in_executor(None, self.zeroconf.registry.add, info) + self.zeroconf.registry.add(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: @@ -110,7 +150,7 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b for i in range(3): if i != 0: await asyncio.sleep(_CHECK_TIME / 1000) - await self.loop.run_in_executor(None, self.zeroconf.send_service_query, info) + self.sender.send(self.zeroconf.generate_service_query(info)) self._raise_on_name_conflict(info) def _raise_on_name_conflict(self, info: ServiceInfo) -> None: @@ -124,7 +164,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - await self.loop.run_in_executor(None, self.zeroconf.registry.remove, info) + self.zeroconf.registry.remove(info) return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) async def async_update_service(self, info: ServiceInfo) -> Awaitable: @@ -135,10 +175,15 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - await self.loop.run_in_executor(None, self.zeroconf.registry.update, info) + self.zeroconf.registry.update(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + def _close(self) -> None: + """Shutdown zeroconf and the sender.""" + self.sender.close() + self.zeroconf.close() + async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" - await self.loop.run_in_executor(None, self.zeroconf.close) + await self.loop.run_in_executor(None, self._close) From 2da6198b2e60a598580637e80b3bd579c1f845a5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 10:43:20 -0500 Subject: [PATCH 139/608] Allow passing in a sync Zeroconf instance to AsyncZeroconf (#406) - Uses the same pattern as ZeroconfServiceTypes.find --- zeroconf/asyncio.py | 3 ++- zeroconf/test_asyncio.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index ac8de3c2..460bde3a 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -90,6 +90,7 @@ def __init__( unicast: bool = False, ip_version: Optional[IPVersion] = None, apple_p2p: bool = False, + zc: Optional['Zeroconf'] = None, ) -> None: """Creates an instance of the Zeroconf class, establishing multicast communications, listening and reaping threads. @@ -106,7 +107,7 @@ def __init__( from it. Otherwise defaults to V4 only for backward compatibility. :param apple_p2p: use AWDL interface (only macOS) """ - self.zeroconf = Zeroconf( + self.zeroconf = zc or Zeroconf( interfaces=interfaces, unicast=unicast, ip_version=ip_version, diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index e048a603..a7bc2037 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -27,6 +27,15 @@ async def test_async_basic_usage() -> None: await aiozc.async_close() +@pytest.mark.asyncio +async def test_async_with_sync_passed_in() -> None: + """Test we can create and close the instance when passing in a sync Zeroconf.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(zc=zc) + assert aiozc.zeroconf is zc + await aiozc.async_close() + + @pytest.mark.asyncio async def test_async_service_registration() -> None: """Test registering services broadcasts the registration by default.""" From ff31f386273fbe9fd0b466bbe5f724c815745215 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 11:00:39 -0500 Subject: [PATCH 140/608] Remove unreachable code in ServiceInfo.get_name (#407) --- zeroconf/__init__.py | 4 +--- zeroconf/test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ae487be6..5121503c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1925,9 +1925,7 @@ def _set_text(self, text: bytes) -> None: def get_name(self) -> str: """Name accessor""" - if self.type is not None and self.name.endswith("." + self.type): - return self.name[: len(self.name) - len(self.type) - 1] - return self.name + return self.name[: len(self.name) - len(self.type) - 1] def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: """Updates service information from a DNS record""" diff --git a/zeroconf/test.py b/zeroconf/test.py index c1c0b775..bd246888 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1551,6 +1551,18 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi class TestServiceInfo(unittest.TestCase): + def test_get_name(self): + """Verify the name accessor can strip the type.""" + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + assert info.get_name() == "name" + def test_service_info_rejects_non_matching_updates(self): """Verify records with the wrong name are rejected.""" From 745087b234dd5ff65b4b041a7221d58030a69cdd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 14:15:07 -0500 Subject: [PATCH 141/608] Add support for registering notify listeners (#409) - Notify listeners will be used by AsyncZeroconf to set asyncio.Event objects when new data is received - Registering a notify listener: notify_listener = YourNotifyListener() Use zeroconf.add_notify_listener(notify_listener) - Unregistering a notify listener: Use zeroconf.remove_notify_listener(notify_listener) - Notify listeners must inherit from the NotifyListener class --- zeroconf/__init__.py | 19 +++++++++++++++++++ zeroconf/test.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5121503c..249a4709 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1549,6 +1549,14 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() +class NotifyListener: + """Receive notifications Zeroconf.notify_all is called.""" + + def notify_all(self) -> None: + """Called when Zeroconf.notify_all is called.""" + raise NotImplementedError() + + class ServiceBrowser(RecordUpdateListener, threading.Thread): """Used to browse for a service of a specific type. @@ -2521,6 +2529,7 @@ def __init__( self.multi_socket = unicast or interfaces is not InterfaceChoice.Default self.listeners = [] # type: List[RecordUpdateListener] + self._notify_listeners = [] # type: List[NotifyListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] self.registry = ServiceRegistry() @@ -2559,6 +2568,8 @@ def notify_all(self) -> None: """Notifies all waiting threads""" with self.condition: self.condition.notify_all() + for listener in self._notify_listeners: + listener.notify_all() def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular @@ -2569,6 +2580,14 @@ def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Option return info return None + def add_notify_listener(self, listener: NotifyListener) -> None: + """Adds a listener to receive notify_all events.""" + self._notify_listeners.append(listener) + + def remove_notify_listener(self, listener: NotifyListener) -> None: + """Removes a listener from the set that is currently listening.""" + self._notify_listeners.remove(listener) + def add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object will then have its add_service and remove_service methods called when diff --git a/zeroconf/test.py b/zeroconf/test.py index bd246888..dcfa7c28 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2421,3 +2421,41 @@ def test_add_multicast_member_socket_errors(errno, expected_result): fileno_mock = unittest.mock.PropertyMock(return_value=10) socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result + + +def test_notify_listeners(): + """Test adding and removing notify listeners.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + notify_called = 0 + + class TestNotifyListener(r.NotifyListener): + def notify_all(self): + nonlocal notify_called + notify_called += 1 + + with pytest.raises(NotImplementedError): + r.NotifyListener().notify_all() + + notify_listener = TestNotifyListener() + + zc.add_notify_listener(notify_listener) + + def on_service_state_change(zeroconf, service_type, state_change, name): + """Dummy service callback.""" + + # start a browser + browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) + browser.cancel() + + assert notify_called + zc.remove_notify_listener(notify_listener) + + notify_called = 0 + # start a browser + browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) + browser.cancel() + + assert not notify_called + + zc.close() From 53306e1b99d9133590d47081994ee77cef468828 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 14:33:43 -0500 Subject: [PATCH 142/608] Add async_wait function to AsyncZeroconf (#410) --- zeroconf/asyncio.py | 32 ++++++++++++++++++++++++++++++-- zeroconf/test_asyncio.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 460bde3a..164668eb 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -20,6 +20,7 @@ USA """ import asyncio +import contextlib import queue import threading from typing import Awaitable, Optional @@ -30,6 +31,7 @@ InterfaceChoice, InterfacesType, NonUniqueNameException, + NotifyListener, ServiceInfo, Zeroconf, _CHECK_TIME, @@ -75,6 +77,24 @@ def run(self) -> None: self.zc.send(*event) +class AsyncNotifyListener(NotifyListener): + """A NotifyListener that async code can use to wait for events.""" + + def __init__(self) -> None: + """Create an event for async listeners to wait for.""" + self.event = asyncio.Event() + self.loop = asyncio.get_event_loop() + + def notify_all(self) -> None: + """Schedule an async_notify_all.""" + self.loop.call_soon_threadsafe(self.async_notify_all) + + def async_notify_all(self) -> None: + """Notify all async listeners.""" + self.event.set() + self.event.clear() + + class AsyncZeroconf: """Implementation of Zeroconf Multicast DNS Service Discovery @@ -90,7 +110,7 @@ def __init__( unicast: bool = False, ip_version: Optional[IPVersion] = None, apple_p2p: bool = False, - zc: Optional['Zeroconf'] = None, + zc: Optional[Zeroconf] = None, ) -> None: """Creates an instance of the Zeroconf class, establishing multicast communications, listening and reaping threads. @@ -113,8 +133,10 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) - self.sender = _AsyncSender(self.zeroconf) self.loop = asyncio.get_event_loop() + self.async_notify = AsyncNotifyListener() + self.zeroconf.add_notify_listener(self.async_notify) + self.sender = _AsyncSender(self.zeroconf) async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -182,9 +204,15 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: def _close(self) -> None: """Shutdown zeroconf and the sender.""" self.sender.close() + self.zeroconf.remove_notify_listener(self.async_notify) self.zeroconf.close() async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" await self.loop.run_in_executor(None, self._close) + + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.async_notify.event.wait(), timeout / 1000) diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index a7bc2037..b3d661a3 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -16,6 +16,7 @@ ServiceListener, ServiceNameAlreadyRegistered, Zeroconf, + current_time_millis, ) from .asyncio import AsyncZeroconf @@ -233,3 +234,39 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: ('update', type_, registration_name), ('remove', type_, registration_name), ] + + +@pytest.mark.asyncio +async def test_async_wait_unblocks_on_update() -> None: + """Test async_wait will unblock on update.""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc4-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + + # Should unblock due to update from the + # registration + now = current_time_millis() + await aiozc.async_wait(50000) + assert current_time_millis() - now < 3000 + await task + + now = current_time_millis() + await aiozc.async_wait(50) + assert current_time_millis() - now < 1000 + + await aiozc.async_close() From 0fa049c2e0f5e9f18830583a8df2736630c891e2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 19:24:35 -0500 Subject: [PATCH 143/608] Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) --- zeroconf/asyncio.py | 48 ++++++++++++++++++ zeroconf/test_asyncio.py | 106 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 164668eb..faddb67a 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -35,9 +35,11 @@ ServiceInfo, Zeroconf, _CHECK_TIME, + _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME, + current_time_millis, instance_name_from_service_info, ) @@ -95,6 +97,41 @@ def async_notify_all(self) -> None: self.event.clear() +class AsyncServiceInfo(ServiceInfo): + """An async version of ServiceInfo.""" + + async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(aiozc.zeroconf): + return True + + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + aiozc.zeroconf.add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query(aiozc.zeroconf, now) + if not out.questions: + return self.load_from_cache(aiozc.zeroconf) + aiozc.sender.send(out) + next_ = now + delay + delay *= 2 + + await aiozc.async_wait((min(next_, last) - now) / 1000) + now = current_time_millis() + finally: + aiozc.zeroconf.remove_listener(self) + + return True + + class AsyncZeroconf: """Implementation of Zeroconf Multicast DNS Service Discovery @@ -212,6 +249,17 @@ async def async_close(self) -> None: servicing further queries.""" await self.loop.run_in_executor(None, self._close) + async def async_get_service_info( + self, type_: str, name: str, timeout: int = 3000 + ) -> Optional[AsyncServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = AsyncServiceInfo(type_, name) + if await info.async_request(self, timeout): + return info + return None + async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" with contextlib.suppress(asyncio.TimeoutError): diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index b3d661a3..eebed600 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -6,6 +6,7 @@ import asyncio import socket +import unittest.mock import pytest @@ -16,9 +17,10 @@ ServiceListener, ServiceNameAlreadyRegistered, Zeroconf, + _LISTENER_TIME, current_time_millis, ) -from .asyncio import AsyncZeroconf +from .asyncio import AsyncServiceInfo, AsyncZeroconf @pytest.mark.asyncio @@ -270,3 +272,105 @@ async def test_async_wait_unblocks_on_update() -> None: assert current_time_millis() - now < 1000 await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_info_async_request() -> None: + """Test registering services broadcasts and query with AsyncServceInfo.async_request.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "abc" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name2, type_) + + # Start a tasks BEFORE the registration that will keep trying + # and see the registration a bit later + get_service_info_task1 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + await asyncio.sleep(_LISTENER_TIME / 1000 / 2) + get_service_info_task2 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-1.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-5.local.", + addresses=[socket.inet_aton("10.0.1.5")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + aiosinfo = await get_service_info_task1 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await get_service_info_task2 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3"), socket.inet_pton(socket.AF_INET6, "6001:db8::1")], + ) + + task = await aiozc.async_update_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + aiosinfos = await asyncio.gather( + aiozc.async_get_service_info(type_, registration_name), + aiozc.async_get_service_info(type_, registration_name2), + ) + assert aiosinfos[0] is not None + assert aiosinfos[0].addresses == [socket.inet_aton("10.0.1.3")] + assert aiosinfos[1] is not None + assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")] + + aiosinfo = AsyncServiceInfo(type_, registration_name) + zc_cache = aiozc.zeroconf.cache + for name in zc_cache.names(): + for record in zc_cache.entries_with_name(name): + zc_cache.remove(record) + # Generating the race condition is almost impossible + # without patching since its a TOCTOU race + with unittest.mock.patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): + await aiosinfo.async_request(aiozc, 3000) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + task = await aiozc.async_unregister_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is None + + await aiozc.async_close() From bb83edfbca339fb6ec20b821d79b171220f5e675 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 29 May 2021 19:27:05 -0500 Subject: [PATCH 144/608] Update changelog for 0.32.0 (#411) --- README.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.rst b/README.rst index 40b9ef4d..6a7e04f5 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,14 @@ Changelog 0.32.0 (Unreleased) =================== +* Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco + +* Add support for registering notify listeners (#409) @bdraco + +* Allow passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco + +* Use a dedicated thread for sending outgoing packets with asyncio (#404) @bdraco + * Fix IPv6 setup under MacOS when binding to "" (#392) @bdraco * Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco From 71cfbcb85bdd5948f1b96a871b10e9e35ab76c3b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 31 May 2021 14:54:53 -0500 Subject: [PATCH 145/608] Add async_register_service/async_unregister_service example (#414) --- examples/async_registration.py | 77 ++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 examples/async_registration.py diff --git a/examples/async_registration.py b/examples/async_registration.py new file mode 100644 index 00000000..53d14ce1 --- /dev/null +++ b/examples/async_registration.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +"""Example of announcing 250 services (in this case, a fake HTTP server).""" + +import argparse +import asyncio +import logging +import socket +import time +from typing import List + +from zeroconf import IPVersion +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf + + +async def register_services(infos: List[AsyncServiceInfo]) -> None: + tasks = [aiozc.async_register_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + + +async def unregister_services(infos: List[AsyncServiceInfo]) -> None: + tasks = [aiozc.async_unregister_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + + +async def close_aiozc(aiozc: AsyncZeroconf) -> None: + await aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + infos = [] + for i in range(250): + infos.append( + AsyncServiceInfo( + "_http._tcp.local.", + f"Paul's Test Web Site {i}._http._tcp.local.", + addresses=[socket.inet_aton("127.0.0.1")], + port=80, + properties={'path': '/~paulsm/'}, + server=f"zcdemohost-{i}.local.", + ) + ) + + print("Registration of 250 services, press Ctrl-C to exit...") + aiozc = AsyncZeroconf(ip_version=ip_version) + loop = asyncio.get_event_loop() + loop.run_until_complete(register_services(infos)) + print("Registration complete.") + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + pass + finally: + print("Unregistering...") + loop.run_until_complete(unregister_services(infos)) + print("Unregistration complete.") + loop.run_until_complete(close_aiozc(aiozc)) From 7f08826c03b7997758ff0236834bf6f1a091c558 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 31 May 2021 15:17:14 -0500 Subject: [PATCH 146/608] Add async_request example with browse (#415) --- examples/async_service_info_request.py | 89 ++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/async_service_info_request.py diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py new file mode 100644 index 00000000..c0f953c2 --- /dev/null +++ b/examples/async_service_info_request.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +"""Example of perodic dump of homekit services. + +This example is useful when a user wants an ondemand +list of HomeKit devices on the network. + +""" + +import argparse +import asyncio +import logging +from typing import cast + + +from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf + + +HAP_TYPE = "_hap._tcp.local." + + +async def async_watch_services(aiozc: AsyncZeroconf) -> None: + zeroconf = aiozc.zeroconf + while True: + await asyncio.sleep(5) + infos = [] + for name in zeroconf.cache.names(): + if not name.endswith(HAP_TYPE): + continue + infos.append(AsyncServiceInfo(HAP_TYPE, name)) + tasks = [info.async_request(aiozc, 3000) for info in infos] + await asyncio.gather(*tasks) + for info in infos: + print("Info for %s" % (info.name)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(" Server: %s" % (info.server,)) + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(" %s: %s" % (key, value)) + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + aiozc = AsyncZeroconf(ip_version=ip_version) + + def on_service_state_change( + zeroconf: Zeroconf, service_type: str, state_change: ServiceStateChange, name: str + ) -> None: + """Dummy handler.""" + + print(f"Services with {HAP_TYPE} will be shown every 5s, press Ctrl-C to exit...") + # ServiceBrowser currently is only offered in sync context. + # ServiceInfo has an AsyncServiceInfo counterpart that can be used + # to fetch service info in parallel + browser = ServiceBrowser(aiozc.zeroconf, [HAP_TYPE], handlers=[on_service_state_change]) + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(async_watch_services(aiozc)) + except KeyboardInterrupt: + pass + finally: + browser.cancel() + loop.run_until_complete(aiozc.async_close()) From 58cfcf0c902b5e27937f118bf4f7a855db635301 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 3 Jun 2021 13:51:54 -1000 Subject: [PATCH 147/608] Seperate query generation for ServiceBrowser (#420) --- zeroconf/__init__.py | 71 +++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 27 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 249a4709..73529413 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1723,44 +1723,61 @@ def cancel(self) -> None: self.zc.remove_listener(self) self.join() + def generate_ready_queries(self) -> Optional[DNSOutgoing]: + """Generate the service browser query for any type that is due.""" + out = None + now = current_time_millis() + + if min(self._next_time.values()) > now: + return out + + for type_, due in self._next_time.items(): + if due > now: + continue + + if out is None: + out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) + out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) + + for record in self._services[type_].values(): + if not record.is_stale(now): + out.add_answer_at_time(record, now) + + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + return out + def run(self) -> None: questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] self.zc.add_listener(self, questions) while True: - now = current_time_millis() - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - if len(self._handlers_to_call) == 0 and next_time > now: - self.zc.wait(next_time - now) + if not self._handlers_to_call: + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + if next_time > now: + self.zc.wait(next_time - now) + if self.zc.done or self.done: return - now = current_time_millis() - out = None - for type_ in self.types: - if self._next_time[type_] > now: - continue - if not out: - out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) - out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) - for record in self._services[type_].values(): - if not record.is_stale(now): - out.add_answer_at_time(record, now) - self._next_time[type_] = now + self._delay[type_] - self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + out = self.generate_ready_queries() if out: self.zc.send(out, addr=self.addr, port=self.port) - if len(self._handlers_to_call) > 0 and not self.zc.done: - with self.zc._handlers_lock: - (name_type, state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) + if not self._handlers_to_call: + continue + + with self.zc._handlers_lock: + (name_type, state_change) = self._handlers_to_call.popitem(False) + + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) class ServiceInfo(RecordUpdateListener): From 8bca0305deae0db8ced7e213be3aaee975985c56 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 3 Jun 2021 14:13:46 -1000 Subject: [PATCH 148/608] Seperate logic for consuming records in ServiceInfo (#421) --- zeroconf/__init__.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 73529413..09521937 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1953,13 +1953,26 @@ def get_name(self) -> str: return self.name[: len(self.name) - len(self.type) - 1] def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record""" + """Updates service information from a DNS record.""" if record is None or record.is_expired(now): return + + self._process_record(record, now) + + # Only update addresses if the DNSService (.server) has changed + if not isinstance(record, DNSService): + return + + for record in self._get_address_records_from_cache(zc): + self._process_record(record, now) + + def _process_record(self, record: DNSRecord, now: float) -> None: if isinstance(record, DNSAddress): if record.key == self.server_key and record.address not in self._addresses: self._addresses.append(record.address) - elif isinstance(record, DNSService): + return + + if isinstance(record, DNSService): if record.key != self.key: return self.name = record.name @@ -1968,32 +1981,37 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) self.port = record.port self.weight = record.weight self.priority = record.priority - self._update_addresses_from_cache(zc, now) - elif isinstance(record, DNSText): + return + + if isinstance(record, DNSText): if record.key == self.key: self._set_text(record.text) - def _update_addresses_from_cache(self, zc: 'Zeroconf', now: float) -> None: - """Update the address records from the cache.""" + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: + """Get the address records from the cache.""" + address_records = [] cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) if cached_a_record: - self.update_record(zc, now, cached_a_record) - for cached_aaaa_record in zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN): - self.update_record(zc, now, cached_aaaa_record) + address_records.append(cached_a_record) + address_records.extend(zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) + return address_records def load_from_cache(self, zc: 'Zeroconf') -> bool: """Populate the service info from the cache.""" now = current_time_millis() + record_updates = [] cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) if cached_srv_record: # If there is a srv record, A and AAAA will already # be called and we do not want to do it twice - self.update_record(zc, now, cached_srv_record) + record_updates.append(cached_srv_record) else: - self._update_addresses_from_cache(zc, now) + record_updates.extend(self._get_address_records_from_cache(zc)) cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: - self.update_record(zc, now, cached_txt_record) + record_updates.append(cached_txt_record) + for record in record_updates: + self.update_record(zc, now, record) return self._is_complete @property From 41de419453c0679c5a04ec248339783afbeb0e4f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 4 Jun 2021 12:32:16 -1000 Subject: [PATCH 149/608] A methods to generate DNSRecords from ServiceInfo (#422) --- zeroconf/__init__.py | 174 ++++++++++++++++--------------------------- 1 file changed, 66 insertions(+), 108 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 09521937..0c06ab90 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1987,6 +1987,54 @@ def _process_record(self, record: DNSRecord, now: float) -> None: if record.key == self.key: self._set_text(record.text) + def dns_addresses( + self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return [ + DNSAddress( + self.server, + _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + address, + ) + for address in self.addresses_by_version(version) + ] + + def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self.name, + ) + + def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return DNSService( + self.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + cast(int, self.port), + self.server, + ) + + def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return DNSText( + self.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + ) + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: """Get the address records from the cache.""" address_records = [] @@ -2704,36 +2752,18 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) - out.add_authorative_answer(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, info.other_ttl, info.name)) + out.add_authorative_answer(info.dns_pointer()) return out def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]) -> None: """Add answers to broadcast a service.""" other_ttl = info.other_ttl if override_ttl is None else override_ttl host_ttl = info.host_ttl if override_ttl is None else override_ttl - out.add_answer_at_time(DNSPointer(info.type, _TYPE_PTR, _CLASS_IN, other_ttl, info.name), 0) - out.add_answer_at_time( - DNSService( - info.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - host_ttl, - info.priority, - info.weight, - cast(int, info.port), - info.server, - ), - 0, - ) - - out.add_answer_at_time( - DNSText(info.name, _TYPE_TXT, _CLASS_IN | _CLASS_UNIQUE, other_ttl, info.text), 0 - ) - for address in info.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer_at_time( - DNSAddress(info.server, type_, _CLASS_IN | _CLASS_UNIQUE, host_ttl, address), 0 - ) + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) + for dns_address in info.dns_addresses(override_ttl=host_ttl, version=IPVersion.All): + out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: """Unregister a service.""" @@ -2926,108 +2956,36 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None for service in self.registry.get_infos_type(question.name): if out is None: out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - out.add_answer( - msg, - DNSPointer(service.type, _TYPE_PTR, _CLASS_IN, service.other_ttl, service.name), - ) - + out.add_answer(msg, service.dns_pointer()) # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer( - DNSService( - service.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, - ) - ) - out.add_additional_answer( - DNSText( - service.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ) - ) - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( - DNSAddress( - service.server, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ) - ) + out.add_additional_answer(service.dns_service()) + out.add_additional_answer(service.dns_text()) + for dns_address in service.dns_addresses(version=IPVersion.All): + out.add_additional_answer(dns_address) + else: if out is None: out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know if question.type in (_TYPE_A, _TYPE_ANY): for service in self.registry.get_infos_server(name_to_find): - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_answer( - msg, - DNSAddress( - question.name, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ), - ) + for dns_address in service.dns_addresses(version=IPVersion.All): + out.add_answer(msg, dns_address) service = self.registry.get_info_name(name_to_find) # type: ignore if service is None: continue if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer( - msg, - DNSService( - question.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - service.priority, - service.weight, - cast(int, service.port), - service.server, - ), - ) + out.add_answer(msg, service.dns_service()) if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer( - msg, - DNSText( - question.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - service.other_ttl, - service.text, - ), - ) + out.add_answer(msg, service.dns_text()) if question.type == _TYPE_SRV: - for address in service.addresses_by_version(IPVersion.All): - type_ = _TYPE_AAAA if _is_v6_address(address) else _TYPE_A - out.add_additional_answer( - DNSAddress( - service.server, - type_, - _CLASS_IN | _CLASS_UNIQUE, - service.host_ttl, - address, - ) - ) + for dns_address in service.dns_addresses(version=IPVersion.All): + out.add_additional_answer(dns_address) if out is not None and out.answers: out.id = msg.id From fc97e5c3ad35da789373a1898c00efe0f13a3b5f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 4 Jun 2021 14:19:58 -1000 Subject: [PATCH 150/608] Remove unused argument from ServiceInfo.dns_addresses (#423) - This should always return all addresses since its _CLASS_UNIQUE --- zeroconf/__init__.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0c06ab90..b0f9bedc 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1987,9 +1987,7 @@ def _process_record(self, record: DNSRecord, now: float) -> None: if record.key == self.key: self._set_text(record.text) - def dns_addresses( - self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All - ) -> List[DNSAddress]: + def dns_addresses(self, override_ttl: Optional[int] = None) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return [ DNSAddress( @@ -1999,7 +1997,7 @@ def dns_addresses( override_ttl if override_ttl is not None else self.host_ttl, address, ) - for address in self.addresses_by_version(version) + for address in self._addresses ] def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: @@ -2762,7 +2760,7 @@ def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_tt out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl, version=IPVersion.All): + for dns_address in info.dns_addresses(override_ttl=host_ttl): out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: @@ -2961,7 +2959,7 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None # https://tools.ietf.org/html/rfc6763#section-12.1. out.add_additional_answer(service.dns_service()) out.add_additional_answer(service.dns_text()) - for dns_address in service.dns_addresses(version=IPVersion.All): + for dns_address in service.dns_addresses(): out.add_additional_answer(dns_address) else: @@ -2972,7 +2970,7 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None # Answer A record queries for any service addresses we know if question.type in (_TYPE_A, _TYPE_ANY): for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(version=IPVersion.All): + for dns_address in service.dns_addresses(): out.add_answer(msg, dns_address) service = self.registry.get_info_name(name_to_find) # type: ignore @@ -2984,7 +2982,7 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None if question.type in (_TYPE_TXT, _TYPE_ANY): out.add_answer(msg, service.dns_text()) if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(version=IPVersion.All): + for dns_address in service.dns_addresses(): out.add_additional_answer(dns_address) if out is not None and out.answers: From 47e266eb66be36b355f1738cd4d2f7369712b7b3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 4 Jun 2021 14:38:10 -1000 Subject: [PATCH 151/608] Avoid checking the registry when answering requests for _services._dns-sd._udp.local. (#425) - _services._dns-sd._udp.local. is a special case and should never be in the registry --- zeroconf/__init__.py | 43 +++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b0f9bedc..37761b76 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2951,6 +2951,8 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None stype, ), ) + continue + for service in self.registry.get_infos_type(question.name): if out is None: out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) @@ -2962,28 +2964,29 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None for dns_address in service.dns_addresses(): out.add_additional_answer(dns_address) - else: - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(name_to_find) # type: ignore - if service is None: - continue + continue - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: + if out is None: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) + out.add_answer(msg, dns_address) + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + continue + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer(msg, service.dns_service()) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer(msg, service.dns_text()) + if question.type == _TYPE_SRV: + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) if out is not None and out.answers: out.id = msg.id From e68e337cd482e06a422b2d2e2e6ae12ce1673ce5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 4 Jun 2021 14:56:50 -1000 Subject: [PATCH 152/608] Remove is_type_unique as it is unused (#426) --- zeroconf/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 37761b76..dc07b936 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -953,10 +953,6 @@ class State(enum.Enum): init = 0 finished = 1 - @staticmethod - def is_type_unique(type_: int) -> bool: - return type_ == _TYPE_TXT or type_ == _TYPE_SRV or type_ == _TYPE_A or type_ == _TYPE_AAAA - def add_question(self, record: DNSQuestion) -> None: """Adds a question""" self.questions.append(record) From e7b2bb5e351f04f4f1e14ef5a20ed2111f8097c4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 08:44:34 -1000 Subject: [PATCH 153/608] Seperate non-thread specific code from ServiceBrowser into _ServiceBrowserBase (#428) --- zeroconf/__init__.py | 59 ++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dc07b936..234106fb 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1553,13 +1553,8 @@ def notify_all(self) -> None: raise NotImplementedError() -class ServiceBrowser(RecordUpdateListener, threading.Thread): - - """Used to browse for a service of a specific type. - - The listener object will have its add_service() and - remove_service() methods called when this browser - discovers changes in the services availability.""" +class _ServiceBrowserBase(RecordUpdateListener): + """Base class for ServiceBrowser.""" def __init__( self, @@ -1577,7 +1572,6 @@ def __init__( for check_type_ in self.types: if not check_type_.endswith(service_type_name(check_type_, strict=False)): raise BadTypeInNameException - threading.Thread.__init__(self) self.daemon = True self.zc = zc self.addr = addr @@ -1629,12 +1623,6 @@ def on_change( for h in handlers: self.service_state_changed.register_handler(h) - self.start() - self.name = "zeroconf-ServiceBrowser-%s-%s" % ( - '-'.join([type_[:-7] for type_ in self.types]), - getattr(self, 'native_id', self.ident), - ) - @property def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface @@ -1715,9 +1703,14 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> enqueue_callback(ServiceStateChange.Updated, type_, record.name) def cancel(self) -> None: + """Cancel the browser.""" self.done = True self.zc.remove_listener(self) - self.join() + + def run(self) -> None: + """Run the browser.""" + questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] + self.zc.add_listener(self, questions) def generate_ready_queries(self) -> Optional[DNSOutgoing]: """Generate the service browser query for any type that is due.""" @@ -1743,10 +1736,40 @@ def generate_ready_queries(self) -> Optional[DNSOutgoing]: self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) return out - def run(self) -> None: - questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] - self.zc.add_listener(self, questions) +class ServiceBrowser(_ServiceBrowserBase, threading.Thread): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + threading.Thread.__init__(self) + super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + self.start() + self.name = "zeroconf-ServiceBrowser-%s-%s" % ( + '-'.join([type_[:-7] for type_ in self.types]), + getattr(self, 'native_id', self.ident), + ) + + def cancel(self) -> None: + """Cancel the browser.""" + super().cancel() + self.join() + + def run(self) -> None: + """Run the browser thread.""" + super().run() while True: if not self._handlers_to_call: # Wait for the type has the smallest next time From 415a7b762030e9d236bef71f39156686a0b277f9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 09:51:41 -1000 Subject: [PATCH 154/608] Implement an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) --- zeroconf/asyncio.py | 89 +++++++++++++++++++++++++++++++++++++++- zeroconf/test_asyncio.py | 66 ++++++++++++++++++++++++++++- 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index faddb67a..200b652d 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -23,7 +23,7 @@ import contextlib import queue import threading -from typing import Awaitable, Optional +from typing import Awaitable, Callable, Dict, List, Optional, Union from . import ( DNSOutgoing, @@ -34,10 +34,12 @@ NotifyListener, ServiceInfo, Zeroconf, + _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, + _ServiceBrowserBase, _UNREGISTER_TIME, current_time_millis, instance_name_from_service_info, @@ -97,6 +99,17 @@ def async_notify_all(self) -> None: self.event.clear() +class AsyncServiceListener: + def add_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def remove_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def update_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + class AsyncServiceInfo(ServiceInfo): """An async version of ServiceInfo.""" @@ -132,6 +145,59 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: return True +class AsyncServiceBrowser(_ServiceBrowserBase): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + aiozc: 'AsyncZeroconf', + type_: Union[str, list], + handlers: Optional[Union[AsyncServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[AsyncServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + self.aiozc = aiozc + super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore + self._browser_task = asyncio.ensure_future(self.async_run()) + + def cancel(self) -> None: + """Cancel the browser.""" + super().cancel() + self._browser_task.cancel() + + async def async_run(self) -> None: + """Run the browser task.""" + self.run() + while True: + if not self._handlers_to_call: + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + if next_time > now: + await self.aiozc.async_wait(next_time - now) + + out = self.generate_ready_queries() + if out: + self.aiozc.sender.send(out, addr=self.addr, port=self.port) + + if not self._handlers_to_call: + continue + + (name_type, state_change) = self._handlers_to_call.popitem(False) + self._service_state_changed.fire( + zeroconf=self.aiozc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + class AsyncZeroconf: """Implementation of Zeroconf Multicast DNS Service Discovery @@ -173,6 +239,7 @@ def __init__( self.loop = asyncio.get_event_loop() self.async_notify = AsyncNotifyListener() self.zeroconf.add_notify_listener(self.async_notify) + self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} self.sender = _AsyncSender(self.zeroconf) async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: @@ -247,6 +314,7 @@ def _close(self) -> None: async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" + await self.async_remove_all_service_listeners() await self.loop.run_in_executor(None, self._close) async def async_get_service_info( @@ -264,3 +332,22 @@ async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" with contextlib.suppress(asyncio.TimeoutError): await asyncio.wait_for(self.async_notify.event.wait(), timeout / 1000) + + async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + await self.async_remove_service_listener(listener) + self.async_browsers[listener] = AsyncServiceBrowser(self, type_, listener) + + async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.async_browsers: + self.async_browsers[listener].cancel() + del self.async_browsers[listener] + + async def async_remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + await asyncio.gather( + *[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)] + ) diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index eebed600..8126b69c 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -20,7 +20,7 @@ _LISTENER_TIME, current_time_millis, ) -from .asyncio import AsyncServiceInfo, AsyncZeroconf +from .asyncio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf @pytest.mark.asyncio @@ -374,3 +374,67 @@ async def test_service_info_async_request() -> None: assert aiosinfo is None await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_browser() -> None: + """Test AsyncServiceBrowser.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + with pytest.raises(NotImplementedError): + AsyncServiceListener().add_service(aiozc, "_type", "name._type") + + with pytest.raises(NotImplementedError): + AsyncServiceListener().remove_service(aiozc, "_type", "name._type") + + with pytest.raises(NotImplementedError): + AsyncServiceListener().update_service(aiozc, "_type", "name._type") + + class MyListener(AsyncServiceListener): + def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + await aiozc.async_add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task + await aiozc.async_close() + + assert calls[0] == ('add', type_, registration_name) From e5a0c9a45df93a668f3611ddf5c41a1800cb4556 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 10:18:00 -1000 Subject: [PATCH 155/608] Fix warning when generating sphinx docs (#432) - `docstring of zeroconf.ServiceInfo:5: WARNING: Unknown target name: "type".` --- zeroconf/__init__.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 234106fb..fc90a3c2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1804,19 +1804,19 @@ class ServiceInfo(RecordUpdateListener): Constructor parameters are as follows: - * type_: fully qualified service type name - * name: fully qualified service name - * port: port that the service runs on - * weight: weight of the service - * priority: priority of the service - * properties: dictionary of properties (or a bytes object holding the contents of the `text` field). + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to value-less attributes. - * server: fully qualified name for service host (defaults to name) - * host_ttl: ttl used for A/SRV records - * other_ttl: ttl used for PTR/TXT records - * addresses and parsed_addresses: List of IP addresses (either as bytes, network byte order, or in parsed - form as text; at most one of those parameters can be provided) + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) """ From 5460caef83b5cdb9c5d637741ed95dea6b328f08 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 10:20:58 -1000 Subject: [PATCH 156/608] Add zeroconf.asyncio to the docs (#434) --- docs/api.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index 5bd2508f..1704db5a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -5,3 +5,8 @@ python-zeroconf API reference :members: :undoc-members: :show-inheritance: + +.. automodule:: zeroconf.asyncio + :members: + :undoc-members: + :show-inheritance: From 6737e13d8e6227b96d5cc0e776c62889b7dc4fd3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 10:25:24 -1000 Subject: [PATCH 157/608] Update changelog for latest changes (#435) --- README.rst | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/README.rst b/README.rst index 6a7e04f5..10472ba6 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,33 @@ Changelog 0.32.0 (Unreleased) =================== +* Add zeroconf.asyncio to the docs (#434) @bdraco + +* Fix warning when generating sphinx docs (#432) @bdraco + +* Implement an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco + +* Seperate non-thread specific code from ServiceBrowser into _ServiceBrowserBase (#428) @bdraco + +* Remove is_type_unique as it is unused (#426) + +* Avoid checking the registry when answering requests for _services._dns-sd._udp.local. (#425) @bdraco + + _services._dns-sd._udp.local. is a special case and should never + be in the registry + +* Remove unused argument from ServiceInfo.dns_addresses (#423) @bdraco + +* Add methods to generate DNSRecords from ServiceInfo (#422) @bdraco + +* Seperate logic for consuming records in ServiceInfo (#421) @bdraco + +* Seperate query generation for ServiceBrowser (#420) @bdraco + +* Add async_request example with browse (#415) @bdraco + +* Add async_register_service/async_unregister_service example (#414) @bdraco + * Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco * Add support for registering notify listeners (#409) @bdraco From 1d3f986e00e18682c209cecbdea2481f4ca987b5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 10:52:48 -1000 Subject: [PATCH 158/608] Cleanup unnecessary else after returns (#436) --- pyproject.toml | 31 +++++++++++++++++++++++++++++++ zeroconf/__init__.py | 28 ++++++++++++---------------- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b48e90ee..55b56091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,3 +2,34 @@ line-length = 110 target_version = ['py35', 'py36', 'py37', 'py38'] skip_string_normalization = true + +[tool.pylint.BASIC] +class-const-naming-style = "any" +good-names = [ + "e", + "er", + "h", + "i", + "id", + "ip", + "os", + "n", + "rr", + "rs", + "s", + "t", + "wr", + "zc", + "_GLOBAL_DONE", +] + +[tool.pylint."MESSAGES CONTROL"] +disable = [ + "fixme", + "format", + "missing-class-docstring", + "missing-function-docstring", + "too-few-public-methods", + "too-many-arguments", + "too-many-instance-attributes" +] diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index fc90a3c2..41c424cd 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -678,8 +678,7 @@ def __repr__(self) -> str: """String representation""" if len(self.text) > 10: return self.to_string(self.text[:7]) + "..." - else: - return self.to_string(self.text) + return self.to_string(self.text) class DNSService(DNSRecord): @@ -1182,14 +1181,13 @@ def packet(self) -> bytes: does not fit in a single packet, but this exists for backward compatibility.""" packets = self.packets() - if len(packets) > 0: - if len(packets[0]) > _MAX_MSG_ABSOLUTE: - QuietLogger.log_warning_once( - "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] - ) - return packets[0] - else: + if len(packets) == 0: return b'' + if len(packets[0]) > _MAX_MSG_ABSOLUTE: + QuietLogger.log_warning_once( + "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] + ) + return packets[0] def packets(self) -> List[bytes]: """Returns a list of bytestrings containing the packets' bytes @@ -1904,10 +1902,9 @@ def addresses_by_version(self, version: IPVersion) -> List[bytes]: """List addresses matching IP version.""" if version == IPVersion.V4Only: return [addr for addr in self._addresses if not _is_v6_address(addr)] - elif version == IPVersion.V6Only: + if version == IPVersion.V6Only: return list(filter(_is_v6_address, self._addresses)) - else: - return self._addresses + return self._addresses def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: """List addresses in their parsed string form.""" @@ -2394,18 +2391,17 @@ def add_multicast_member( interface, ) return False - elif _errno == errno.EADDRNOTAVAIL: + if _errno == errno.EADDRNOTAVAIL: log.info( 'Address not available when adding %s to multicast ' 'group, it is expected to happen on some systems', interface, ) return False - elif _errno in err_einval: + if _errno in err_einval: log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) return False - else: - raise + raise return True From 8412eb791dd5ad1c287c1d7cc24c5db75a5291b7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 10:56:10 -1000 Subject: [PATCH 159/608] Cleanup unused variables (#437) --- zeroconf/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 41c424cd..398f019f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -789,7 +789,7 @@ def read_header(self) -> None: def read_questions(self) -> None: """Reads questions section of packet""" - for i in range(self.num_questions): + for _ in range(self.num_questions): name = self.read_name() type_, class_ = self.unpack(b'!HH') @@ -820,7 +820,7 @@ def read_others(self) -> None: """Reads the answers, authorities and additionals section of the packet""" n = self.num_answers + self.num_authorities + self.num_additionals - for i in range(n): + for _ in range(n): domain = self.read_name() type_, class_, ttl, length = self.unpack(b'!HHiH') @@ -1381,7 +1381,7 @@ def run(self) -> None: try: rs.append(self.socketpair[0]) - rr, wr, er = select.select(rs, [], [], self.timeout) + rr, _wr, _er = select.select(rs, [], [], self.timeout) if self.zc.done: return From 4bcb698bda0ec7266d5e454b5e81a07eb64be32a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 11:13:05 -1000 Subject: [PATCH 160/608] Disable pylint too-many-branches for functions that need refactoring (#439) --- zeroconf/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 398f019f..c9d6b1c5 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -232,7 +232,7 @@ def _encode_address(address: str) -> bytes: return socket.inet_pton(address_family, address) -def service_type_name(type_: str, *, strict: bool = True) -> str: +def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -2867,7 +2867,7 @@ def update_record(self, now: float, rec: DNSRecord) -> None: listener.update_record(self, now, rec) self.notify_all() - def handle_response(self, msg: DNSIncoming) -> None: + def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many-branches """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" updates = [] # type: List[DNSRecord] @@ -2938,7 +2938,9 @@ def handle_response(self, msg: DNSIncoming) -> None: for record in removes: self.cache.remove(record) - def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: + def handle_query( + self, msg: DNSIncoming, addr: Optional[str], port: int + ) -> None: # pylint: disable=too-many-branches """Deal with incoming query packets. Provides a response if possible.""" out = None From 594da709273c2e0a53fee2f9ad7fcec607ad0868 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 11:16:59 -1000 Subject: [PATCH 161/608] Remove unused now argument from ServiceInfo._process_record (#440) --- zeroconf/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index c9d6b1c5..0779d878 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1973,16 +1973,16 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) if record is None or record.is_expired(now): return - self._process_record(record, now) + self._process_record(record) # Only update addresses if the DNSService (.server) has changed if not isinstance(record, DNSService): return for record in self._get_address_records_from_cache(zc): - self._process_record(record, now) + self._process_record(record) - def _process_record(self, record: DNSRecord, now: float) -> None: + def _process_record(self, record: DNSRecord) -> None: if isinstance(record, DNSAddress): if record.key == self.server_key and record.address not in self._addresses: self._addresses.append(record.address) From a70370a0f653df911cc6f641522cec0fcc8471a3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 11:19:54 -1000 Subject: [PATCH 162/608] Convert unnecessary use of a comprehension to a list (#441) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0779d878..33129863 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2699,7 +2699,7 @@ def remove_service_listener(self, listener: ServiceListener) -> None: def remove_all_service_listeners(self) -> None: """Removes a listener from the set that is currently listening.""" - for listener in [k for k in self.browsers]: + for listener in list(self.browsers): self.remove_service_listener(listener) def register_service( From 41be4f4db0501adb9fbaa6b353fbcb36a45e6e21 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 11:26:06 -1000 Subject: [PATCH 163/608] Merge _TYPE_CNAME and _TYPE_PTR comparison in DNSIncoming.read_others (#442) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 33129863..a22d84b1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -827,7 +827,7 @@ def read_others(self) -> None: rec = None # type: Optional[DNSRecord] if type_ == _TYPE_A: rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) - elif type_ == _TYPE_CNAME or type_ == _TYPE_PTR: + elif type_ in (_TYPE_CNAME, _TYPE_PTR): rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) elif type_ == _TYPE_TXT: rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) From 6002c9c88a9a49814f86070c07925f798a61461a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 11:38:05 -1000 Subject: [PATCH 164/608] Disable broad except checks in places we still catch broad exceptions (#443) --- zeroconf/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a22d84b1..a0301f4f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -592,7 +592,7 @@ def __repr__(self) -> str: socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address ) ) - except Exception: # TODO stop catching all Exceptions + except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions return self.to_string(str(self.address)) @@ -1446,7 +1446,7 @@ def __init__(self, zc: 'Zeroconf') -> None: def handle_read(self, socket_: socket.socket) -> None: try: data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) - except Exception: + except Exception: # pylint: disable=broad-except self.log_exception_warning('Error reading from socket %d', socket_.fileno()) return @@ -2857,7 +2857,7 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: try: self.listeners.remove(listener) self.notify_all() - except Exception as e: # TODO stop catching all Exceptions + except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions log.exception('Unknown error, possibly benign: %r', e) def update_record(self, now: float, rec: DNSRecord) -> None: @@ -3030,7 +3030,7 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P else: real_addr = addr bytes_sent = s.sendto(packet, 0, (real_addr, port)) - except Exception as exc: # TODO stop catching all Exceptions + except Exception as exc: # pylint: disable=broad-except # TODO stop catching all Exceptions if ( isinstance(exc, OSError) and exc.errno == errno.ENETUNREACH From 424c00257083f1d091a52ff0c966b306eea70efb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:11:05 -1000 Subject: [PATCH 165/608] Remove unneeded-not in new_socket (#445) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a0301f4f..9c154a89 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2333,7 +2333,7 @@ def new_socket( try: s.setsockopt(socket.SOL_SOCKET, reuseport, 1) except OSError as err: - if not err.errno == errno.ENOPROTOOPT: + if err.errno != errno.ENOPROTOOPT: raise if port == _MDNS_PORT: From 929ba12d046496782491d96160e6cb8d0d04cfe5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:14:09 -1000 Subject: [PATCH 166/608] Fix redefining argument with the local name 'record' in ServiceInfo.update_record (#448) --- zeroconf/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9c154a89..4dabdf56 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1979,8 +1979,8 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) if not isinstance(record, DNSService): return - for record in self._get_address_records_from_cache(zc): - self._process_record(record) + for cached_record in self._get_address_records_from_cache(zc): + self._process_record(cached_record) def _process_record(self, record: DNSRecord) -> None: if isinstance(record, DNSAddress): From ffc6cbb94d7401a70ebd6f747ed6c5e56e528bb0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:18:10 -1000 Subject: [PATCH 167/608] Add missing update_service method to ZeroconfServiceTypes (#449) --- zeroconf/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 4dabdf56..29aca513 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2157,13 +2157,18 @@ class ZeroconfServiceTypes(ServiceListener): """ def __init__(self) -> None: + """Keep track of found services in a set.""" self.found_services = set() # type: Set[str] def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + """Service added.""" self.found_services.add(name) + def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + """Service updated.""" + def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - pass + """Service removed.""" @classmethod def find( From 18851ed4c0f605996798472e1a68dded16d41ff6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:26:23 -1000 Subject: [PATCH 168/608] Extract _get_queue from _AsyncSender (#444) --- zeroconf/asyncio.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 200b652d..89dff47e 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -46,6 +46,13 @@ ) +def _get_best_available_queue() -> queue.Queue: + """Create the best available queue type.""" + if hasattr(queue, "SimpleQueue"): + return queue.SimpleQueue() # type: ignore + return queue.Queue() + + class _AsyncSender(threading.Thread): """A thread to handle sending DNSOutgoing for asyncio.""" @@ -53,16 +60,10 @@ def __init__(self, zc: 'Zeroconf'): """Create the sender thread.""" super().__init__() self.zc = zc - self.queue = self._get_queue() + self.queue = _get_best_available_queue() self.start() self.name = "AsyncZeroconfSender" - def _get_queue(self) -> queue.Queue: - """Create the best available queue type.""" - if hasattr(queue, "SimpleQueue"): - return queue.SimpleQueue() # type: ignore - return queue.Queue() - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: """Queue a send to be processed by the thread.""" self.queue.put((out, addr, port)) From 7e03f836dd7a4ee938bfff21cd150e863f608b5e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:35:15 -1000 Subject: [PATCH 169/608] Mark methods used by asyncio without self use (#447) --- zeroconf/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 29aca513..27e19941 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2766,14 +2766,16 @@ def send_service_query(self, info: ServiceInfo) -> None: """Send a query to lookup a service.""" self.send(self.generate_service_query(info)) - def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: + def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) out.add_authorative_answer(info.dns_pointer()) return out - def _add_broadcast_answer(self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int]) -> None: + def _add_broadcast_answer( # pylint: disable=no-self-use + self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] + ) -> None: """Add answers to broadcast a service.""" other_ttl = info.other_ttl if override_ttl is None else override_ttl host_ttl = info.host_ttl if override_ttl is None else override_ttl From ef0cf8e393a8ffdccb3cd2094a8764f707f518c1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:47:12 -1000 Subject: [PATCH 170/608] Disable no-member check for WSAEINVAL false positive (#454) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 27e19941..988a6ca3 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2377,7 +2377,7 @@ def add_multicast_member( err_einval = {errno.EINVAL} if sys.platform == 'win32': # No WSAEINVAL definition in typeshed - err_einval |= {cast(Any, errno).WSAEINVAL} + err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: From f26a92bc2abe61f5a2b5acd76991f81d07452201 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:52:30 -1000 Subject: [PATCH 171/608] Use unique name in test_async_service_browser test (#450) --- zeroconf/test_asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index 8126b69c..b436932a 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -380,7 +380,7 @@ async def test_service_info_async_request() -> None: async def test_async_service_browser() -> None: """Test AsyncServiceBrowser.""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test1-srvc-type._tcp.local." + type_ = "_test9-srvc-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) From 7544cdf956c4eeb4b688729432ba87278f606b7c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 12:55:27 -1000 Subject: [PATCH 172/608] Disable pylint no-self-use check on abstract methods (#451) --- zeroconf/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 988a6ca3..a1a9905a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -505,7 +505,7 @@ def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException @@ -552,7 +552,7 @@ def reset_ttl(self, other: 'DNSRecord') -> None: self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - def write(self, out: 'DNSOutgoing') -> None: + def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException From 5fce89db2707b163231aec216e4c4fc310527e4c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 13:02:31 -1000 Subject: [PATCH 173/608] Mark functions with too many branches in need of refactoring (#455) --- zeroconf/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a1a9905a..00dd4666 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2295,7 +2295,7 @@ def normalize_interface_choice( return result -def new_socket( +def new_socket( # pylint: disable=too-many-branches bind_addr: Union[Tuple[str], Tuple[str, int, int]], port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, @@ -2945,9 +2945,9 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many for record in removes: self.cache.remove(record) - def handle_query( + def handle_query( # pylint: disable=too-many-branches self, msg: DNSIncoming, addr: Optional[str], port: int - ) -> None: # pylint: disable=too-many-branches + ) -> None: """Deal with incoming query packets. Provides a response if possible.""" out = None From 69c4cf69bbc34474e70eac3ad0fe905be7ab4eb4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 13:21:36 -1000 Subject: [PATCH 174/608] Disable protected-access on the ServiceBrowser usage of _handlers_lock (#452) - This will be fixed in https://github.com/jstasiak/python-zeroconf/pull/419 --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 00dd4666..a3af82f0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1786,7 +1786,7 @@ def run(self) -> None: if not self._handlers_to_call: continue - with self.zc._handlers_lock: + with self.zc._handlers_lock: # pylint: disable=protected-access (name_type, state_change) = self._handlers_to_call.popitem(False) self._service_state_changed.fire( From 9510808cfd334b0b2f6381da8214225c4cfbf6a0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 13:42:35 -1000 Subject: [PATCH 175/608] Trap OSError directly in Zeroconf.send instead of checking isinstance (#453) - Fixes: Instance of 'Exception' has no 'errno' member (no-member) --- zeroconf/__init__.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a3af82f0..0356b547 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -3037,17 +3037,16 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P else: real_addr = addr bytes_sent = s.sendto(packet, 0, (real_addr, port)) - except Exception as exc: # pylint: disable=broad-except # TODO stop catching all Exceptions - if ( - isinstance(exc, OSError) - and exc.errno == errno.ENETUNREACH - and s.family == socket.AF_INET6 - ): + except OSError as exc: + if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: # with IPv6 we don't have a reliable way to determine if an interface actually has # IPV6 support, so we have to try and ignore errors. continue # on send errors, log the exception and keep going self.log_exception_warning('Error sending through socket %d', s.fileno()) + except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions + # on send errors, log the exception and keep going + self.log_exception_warning('Error sending through socket %d', s.fileno()) else: if bytes_sent != len(packet): self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) From 6fafdee241571d68937e29ee0a2b1bd5ef0038d9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 14:30:34 -1000 Subject: [PATCH 176/608] Enable pylint (#438) --- Makefile | 5 ++++- pyproject.toml | 4 +++- requirements-dev.txt | 1 + zeroconf/asyncio.py | 2 +- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index fed4d9b1..37d4bff6 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ PYTHON_VERSION:=$(shell python -c "import sys;sys.stdout.write('%d.%d' % sys.ver LINT_TARGETS:=flake8 ifneq ($(findstring PyPy,$(PYTHON_IMPLEMENTATION)),PyPy) - LINT_TARGETS:=$(LINT_TARGETS) mypy black_check + LINT_TARGETS:=$(LINT_TARGETS) mypy black_check pylint endif @@ -28,6 +28,9 @@ lint: $(LINT_TARGETS) flake8: flake8 --max-line-length=$(MAX_LINE_LENGTH) setup.py examples zeroconf +pylint: + pylint zeroconf/__init__.py zeroconf/asyncio.py + .PHONY: black_check black_check: black --check setup.py examples zeroconf diff --git a/pyproject.toml b/pyproject.toml index 55b56091..cd79b3e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,11 +25,13 @@ good-names = [ [tool.pylint."MESSAGES CONTROL"] disable = [ + "duplicate-code", "fixme", "format", "missing-class-docstring", "missing-function-docstring", "too-few-public-methods", "too-many-arguments", - "too-many-instance-attributes" + "too-many-instance-attributes", + "too-many-public-methods" ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 30b906e4..325ce32a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,6 +9,7 @@ ifaddr mypy;implementation_name=="cpython" # 0.11.0 breaks things https://github.com/PyCQA/pep8-naming/issues/152 pep8-naming!=0.6.0,!=0.11.0 +pylint pytest pytest-asyncio pytest-cov diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 89dff47e..69f7e8e3 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -49,7 +49,7 @@ def _get_best_available_queue() -> queue.Queue: """Create the best available queue type.""" if hasattr(queue, "SimpleQueue"): - return queue.SimpleQueue() # type: ignore + return queue.SimpleQueue() # type: ignore # pylint: disable=all return queue.Queue() From 5e24da08bc463bf79b27eb3768ec01755804f403 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 14:45:31 -1000 Subject: [PATCH 177/608] Reduce branching in Zeroconf.handle_query (#460) --- zeroconf/__init__.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0356b547..60dc6cc7 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2950,21 +2950,19 @@ def handle_query( # pylint: disable=too-many-branches ) -> None: """Deal with incoming query packets. Provides a response if possible.""" - out = None - # Support unicast client responses # if port != _MDNS_PORT: out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) for question in msg.questions: out.add_question(question) + else: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) for question in msg.questions: if question.type == _TYPE_PTR: if question.name == "_services._dns-sd._udp.local.": for stype in self.registry.get_types(): - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) out.add_answer( msg, DNSPointer( @@ -2978,8 +2976,6 @@ def handle_query( # pylint: disable=too-many-branches continue for service in self.registry.get_infos_type(question.name): - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) out.add_answer(msg, service.dns_pointer()) # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. @@ -2990,9 +2986,6 @@ def handle_query( # pylint: disable=too-many-branches continue - if out is None: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - name_to_find = question.name.lower() # Answer A record queries for any service addresses we know if question.type in (_TYPE_A, _TYPE_ANY): From ceb0def1b43f2e55bb17e33d13d4efdaa055221c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 14:57:24 -1000 Subject: [PATCH 178/608] Reduce branching in Zeroconf.handle_response (#459) --- zeroconf/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 60dc6cc7..ca284849 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -23,6 +23,7 @@ import enum import errno import ipaddress +import itertools import logging import platform import re @@ -2935,9 +2936,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many # zc.get_service_info will see the cached value # but ONLY after all the record updates have been # processsed. - for record in address_adds: - self.cache.add(record) - for record in other_adds: + for record in itertools.chain(address_adds, other_adds): self.cache.add(record) # Removes are processed last since # ServiceInfo could generate an un-needed query From 558cec3687ac7e7f494ab7aa4ce574c1e784b81f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 15:27:42 -1000 Subject: [PATCH 179/608] Use constant for service type enumeration (#461) --- zeroconf/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ca284849..378e0be8 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -183,6 +183,9 @@ _TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' _NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' +# https://datatracker.ietf.org/doc/html/rfc6763#section-9 +_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local." + try: _IPPROTO_IPV6 = socket.IPPROTO_IPV6 except AttributeError: @@ -2191,7 +2194,7 @@ def find( """ local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version) listener = cls() - browser = ServiceBrowser(local_zc, '_services._dns-sd._udp.local.', listener=listener) + browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener) # wait for responses time.sleep(timeout) @@ -2960,12 +2963,12 @@ def handle_query( # pylint: disable=too-many-branches for question in msg.questions: if question.type == _TYPE_PTR: - if question.name == "_services._dns-sd._udp.local.": + if question.name == _SERVICE_TYPE_ENUMERATION_NAME: for stype in self.registry.get_types(): out.add_answer( msg, DNSPointer( - "_services._dns-sd._udp.local.", + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, From 4c4b529c841f015108a7489bd8f3b92a5e57e827 Mon Sep 17 00:00:00 2001 From: Stepan Henek Date: Sun, 6 Jun 2021 04:08:06 +0200 Subject: [PATCH 180/608] Support for context managers in Zeroconf and AsyncZeroconf (#284) Co-authored-by: J. Nick Koston --- zeroconf/__init__.py | 16 ++++++++++++++-- zeroconf/asyncio.py | 15 ++++++++++++++- zeroconf/test.py | 9 +++++++++ zeroconf/test_asyncio.py | 24 ++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 378e0be8..3458ae79 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -35,7 +35,8 @@ import time import warnings from collections import OrderedDict -from typing import Dict, Iterable, List, Optional, Union, cast +from types import TracebackType # noqa # used in type hints +from typing import Dict, Iterable, List, Optional, Type, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints import ifaddr @@ -3064,8 +3065,19 @@ def close(self) -> None: for s in self._respond_sockets: self.engine.del_reader(s) self.engine.join() - # shutdown the rest self.notify_all() for s in self._respond_sockets: s.close() + + def __enter__(self) -> 'Zeroconf': + return self + + def __exit__( # pylint: disable=useless-return + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self.close() + return None diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 69f7e8e3..65771e8a 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -23,7 +23,8 @@ import contextlib import queue import threading -from typing import Awaitable, Callable, Dict, List, Optional, Union +from types import TracebackType # noqa # used in type hints +from typing import Awaitable, Callable, Dict, List, Optional, Type, Union from . import ( DNSOutgoing, @@ -352,3 +353,15 @@ async def async_remove_all_service_listeners(self) -> None: await asyncio.gather( *[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)] ) + + async def __aenter__(self) -> 'AsyncZeroconf': + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + await self.async_close() + return None diff --git a/zeroconf/test.py b/zeroconf/test.py index dcfa7c28..37a3a246 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -564,6 +564,15 @@ def test_launch_and_close(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) rv.close() + def test_launch_and_close_context_manager(self): + with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv: + assert rv.done is False + assert rv.done is True + + with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv: + assert rv.done is False + assert rv.done is True + def test_launch_and_close_unicast(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True) rv.close() diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index b436932a..e6bcf5a2 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -438,3 +438,27 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: await aiozc.async_close() assert calls[0] == ('add', type_, registration_name) + + +@pytest.mark.asyncio +async def test_async_context_manager() -> None: + """Test using an async context manager.""" + type_ = "_test10-sr-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc: + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None From c1ed987ede34b0049e6466e673b1629d7cd0cd6a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 16:08:27 -1000 Subject: [PATCH 181/608] Break apart Zeroconf.handle_query to reduce branching (#462) --- zeroconf/__init__.py | 100 ++++++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 43 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3458ae79..cb010008 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2948,9 +2948,59 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many for record in removes: self.cache.remove(record) - def handle_query( # pylint: disable=too-many-branches - self, msg: DNSIncoming, addr: Optional[str], port: int - ) -> None: + def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in self.registry.get_types(): + out.add_answer( + msg, + DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, + ), + ) + + def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a PTR query.""" + for service in self.registry.get_infos_type(question.name): + out.add_answer(msg, service.dns_pointer()) + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + out.add_additional_answer(service.dns_service()) + out.add_additional_answer(service.dns_text()) + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a query any query other then PTR. + + Add answer(s) for A, AAAA, SRV, or TXT queries. + """ + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): + for dns_address in service.dns_addresses(): + out.add_answer(msg, dns_address) + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + return + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer(msg, service.dns_service()) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer(msg, service.dns_text()) + if question.type == _TYPE_SRV: + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if possible.""" # Support unicast client responses @@ -2965,48 +3015,12 @@ def handle_query( # pylint: disable=too-many-branches for question in msg.questions: if question.type == _TYPE_PTR: if question.name == _SERVICE_TYPE_ENUMERATION_NAME: - for stype in self.registry.get_types(): - out.add_answer( - msg, - DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - continue - - for service in self.registry.get_infos_type(question.name): - out.add_answer(msg, service.dns_pointer()) - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer(service.dns_service()) - out.add_additional_answer(service.dns_text()) - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - continue - - name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(name_to_find) # type: ignore - if service is None: + self._answer_service_type_enumeration_query(msg, out) + else: + self._answer_ptr_query(msg, out, question) continue - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) + self._answer_non_ptr_query(msg, out, question) if out is not None and out.answers: out.id = msg.id From c3365e1fd060cebc63cc42443260bd785077c246 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 16:50:32 -1000 Subject: [PATCH 182/608] Clear cache between ServiceTypesQuery tests (#466) - Ensures the test relies on the ZeroconfServiceTypes.find making the correct calls instead of the cache from the previous call --- zeroconf/test.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 37a3a246..11131964 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -82,6 +82,12 @@ def has_working_ipv6(): return False +def _clear_cache(zc): + for name in zc.cache.names(): + for record in zc.cache.entries_with_name(name): + zc.cache.remove(record) + + class TestDunder(unittest.TestCase): def test_dns_text_repr(self): # There was an issue on Python 3 that prevented DNSText's repr @@ -1120,6 +1126,7 @@ def test_integration_with_listener(self): try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types + _clear_cache(zeroconf_registrar) service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) assert type_ in service_types @@ -1152,6 +1159,7 @@ def test_integration_with_listener_v6_records(self): try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types + _clear_cache(zeroconf_registrar) service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) assert type_ in service_types @@ -1183,6 +1191,7 @@ def test_integration_with_listener_ipv6(self): try: service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) assert type_ in service_types + _clear_cache(zeroconf_registrar) service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) assert type_ in service_types @@ -1214,6 +1223,7 @@ def test_integration_with_subtype_and_listener(self): try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert discovery_type in service_types + _clear_cache(zeroconf_registrar) service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) assert discovery_type in service_types @@ -1290,9 +1300,7 @@ def update_service(self, zeroconf, type, name): time.sleep(3) # clear the answer cache to force query - for name in zeroconf_browser.cache.names(): - for record in zeroconf_browser.cache.entries_with_name(name): - zeroconf_browser.cache.remove(record) + _clear_cache(zeroconf_browser) cached_info = ServiceInfo(type_, registration_name) cached_info.load_from_cache(zeroconf_browser) From 7a5040247cbaad6bed3fc1204820dfc31ed9b0ae Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 17:04:15 -1000 Subject: [PATCH 183/608] Ensure PTR questions asked in uppercase are answered (#465) --- zeroconf/__init__.py | 4 ++-- zeroconf/test.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index cb010008..7ef3a69a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2967,7 +2967,7 @@ def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgo def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name): + for service in self.registry.get_infos_type(question.name.lower()): out.add_answer(msg, service.dns_pointer()) # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. @@ -3014,7 +3014,7 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None for question in msg.questions: if question.type == _TYPE_PTR: - if question.name == _SERVICE_TYPE_ENUMERATION_NAME: + if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: self._answer_service_type_enumeration_query(msg, out) else: self._answer_ptr_query(msg, out, question) diff --git a/zeroconf/test.py b/zeroconf/test.py index 11131964..d471db3e 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1004,6 +1004,37 @@ def test_name_conflicts(self): zc.register_service(conflicting_info) zc.close() + def test_register_and_lookup_type_by_uppercase_name(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_mylowertype._tcp.local." + name = "Home" + registration_name = "%s.%s" % (name, type_) + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + _clear_cache(zc) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [] + + out = r.DNSOutgoing(r._FLAGS_QR_QUERY) + out.add_question(r.DNSQuestion(type_.upper(), r._TYPE_PTR, r._CLASS_IN)) + zc.send(out) + time.sleep(0.5) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")] + assert info.properties == {b"version": b"1.0"} + zc.close() + class TestServiceRegistry(unittest.TestCase): def test_only_register_once(self): From 8a9ae29b6f6643f3625938ac44df66dcc556de46 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 22:10:12 -1000 Subject: [PATCH 184/608] Reduce branching in Zeroconf.handle_response (#467) - Adds `add_records` and `remove_records` to `DNSCache` to permit multiple records to be added or removed in one call - This change is not enough to remove the too-many-branches pylint disable, however when combined with #419 it should no longer be needed --- zeroconf/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 7ef3a69a..179e4221 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1284,12 +1284,22 @@ def add(self, entry: DNSRecord) -> None: if isinstance(entry, DNSService): self.service_cache.setdefault(entry.server, []).append(entry) + def add_records(self, entries: Iterable[DNSRecord]) -> None: + """Add multiple records.""" + for entry in entries: + self.add(entry) + def remove(self, entry: DNSRecord) -> None: """Removes an entry.""" if isinstance(entry, DNSService): DNSCache.remove_key(self.service_cache, entry.server, entry) DNSCache.remove_key(self.cache, entry.key, entry) + def remove_records(self, entries: Iterable[DNSRecord]) -> None: + """Remove multiple records.""" + for entry in entries: + self.remove(entry) + @staticmethod def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: """Forgiving remove of a cache key.""" @@ -2940,13 +2950,11 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many # zc.get_service_info will see the cached value # but ONLY after all the record updates have been # processsed. - for record in itertools.chain(address_adds, other_adds): - self.cache.add(record) + self.cache.add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. - for record in removes: - self.cache.remove(record) + self.cache.remove_records(removes) def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: """Provide an answer to a service type enumeration query. From 1eaeef2d6f07efba67e91699529f8361226233ce Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 22:34:36 -1000 Subject: [PATCH 185/608] Fix flakey test_update_record (#470) --- zeroconf/test.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index d471db3e..2eb9c770 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1980,9 +1980,6 @@ def mock_incoming_msg( zeroconf.handle_response( mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120) ) - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) - ) called_with_refresh_time_check = False @@ -1995,6 +1992,12 @@ def _mock_get_expiration_time(self, percent): # Set an expire time that will force a refresh with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + zeroconf.handle_response( + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120) + ) + # Add the last record after updating the first one + # to ensure the service_add_event only gets set + # after the update zeroconf.handle_response( mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) ) From 00af5adc4be76afd23135d37653119f45c57a531 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 23:09:02 -1000 Subject: [PATCH 186/608] Reduce branching in service_type_name (#472) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 179e4221..d2be47e1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -280,7 +280,7 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis :return: fully qualified service name (eg: _http._tcp.local.) """ - if type_.endswith(_TCP_PROTOCOL_LOCAL_TRAILER) or type_.endswith(_NONTCP_PROTOCOL_LOCAL_TRAILER): + if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] has_protocol = True From d0f5a60275ccf810407055c63ca9080fa6654443 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 23:10:04 -1000 Subject: [PATCH 187/608] Add test coverage to ensure ServiceInfo rejects expired records (#468) --- zeroconf/test.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/zeroconf/test.py b/zeroconf/test.py index 2eb9c770..11979af7 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1711,6 +1711,46 @@ def test_service_info_rejects_non_matching_updates(self): assert new_address not in info.addresses zc.close() + def test_service_info_rejects_expired_records(self): + """Verify records that are expired are rejected.""" + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + # Expired record + expired_record = r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ) + expired_record.created = 1000 + expired_record._expiration_time = 1000 + info.update_record(zc, now, expired_record) + assert info.properties[b"ci"] == b"2" + zc.close() + def test_get_info_partial(self): zc = r.Zeroconf(interfaces=['127.0.0.1']) From b8534130ec31a6be191fcc60615ab2fd02fd8d7a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 23:18:34 -1000 Subject: [PATCH 188/608] Narrow exception catch in DNSAddress.__repr__ to only expected exceptions (#473) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d2be47e1..dcd477fa 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -597,7 +597,7 @@ def __repr__(self) -> str: socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address ) ) - except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions + except (ValueError, OSError): return self.to_string(str(self.address)) From ed53f6283265eb8fb506d4af8fb31bd4eaa7292b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 5 Jun 2021 23:55:49 -1000 Subject: [PATCH 189/608] Add support for updating multiple records at once to ServiceInfo (#474) - Adds `update_records` method to `ServiceInfo` --- zeroconf/__init__.py | 32 ++++++++++++++++++++++---------- zeroconf/test.py | 2 ++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dcd477fa..9a5e125e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1984,20 +1984,33 @@ def get_name(self) -> str: return self.name[: len(self.name) - len(self.type) - 1] def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record.""" - if record is None or record.is_expired(now): - return + """Updates service information from a DNS record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + if record is not None: + self.update_records(zc, now, [record]) - self._process_record(record) + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Updates service information from a DNS record.""" + update_addresses = False + for record in records: + if isinstance(record, DNSService): + update_addresses = True + self._process_record(record, now) # Only update addresses if the DNSService (.server) has changed - if not isinstance(record, DNSService): + if not update_addresses: return - for cached_record in self._get_address_records_from_cache(zc): - self._process_record(cached_record) + for record in self._get_address_records_from_cache(zc): + self._process_record(record, now) + + def _process_record(self, record: DNSRecord, now: float) -> None: + if record.is_expired(now): + return - def _process_record(self, record: DNSRecord) -> None: if isinstance(record, DNSAddress): if record.key == self.server_key and record.address not in self._addresses: self._addresses.append(record.address) @@ -2087,8 +2100,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: record_updates.append(cached_txt_record) - for record in record_updates: - self.update_record(zc, now, record) + self.update_records(zc, now, record_updates) return self._is_complete @property diff --git a/zeroconf/test.py b/zeroconf/test.py index 11979af7..02195240 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1625,6 +1625,8 @@ def test_service_info_rejects_non_matching_updates(self): info = ServiceInfo( service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] ) + # Verify backwards compatiblity with calling with None + info.update_record(zc, now, None) # Matching updates info.update_record( zc, From b0c0cdc6779dc095cf03ebd92652af69800b7bca Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 12:59:34 -1000 Subject: [PATCH 190/608] Fix AsyncServiceInfo.async_request not waiting long enough (#480) - The call to async_wait should have been in milliseconds, but the time was being passed in seconds which resulted in waiting 1000x shorter --- zeroconf/asyncio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 65771e8a..52b8f188 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -139,7 +139,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: next_ = now + delay delay *= 2 - await aiozc.async_wait((min(next_, last) - now) / 1000) + await aiozc.async_wait(min(next_, last) - now) now = current_time_millis() finally: aiozc.zeroconf.remove_listener(self) From 849e9bc792c6cc77b879b4761195192bea1720ce Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 13:05:29 -1000 Subject: [PATCH 191/608] Provide a helper function to convert milliseconds to seconds (#481) --- zeroconf/__init__.py | 9 +++++++-- zeroconf/asyncio.py | 7 ++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9a5e125e..9bb488b5 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -227,6 +227,11 @@ def current_time_millis() -> float: return time.time() * 1000 +def millis_to_seconds(millis: float) -> float: + """Convert milliseconds to seconds.""" + return millis / 1000.0 + + def _is_v6_address(addr: bytes) -> bool: return len(addr) == 16 @@ -539,7 +544,7 @@ def get_expiration_time(self, percent: int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" - return max(0, (self._expiration_time - now) / 1000.0) + return max(0, millis_to_seconds(self._expiration_time - now)) def is_expired(self, now: float) -> bool: """Returns true if this record has expired.""" @@ -2690,7 +2695,7 @@ def wait(self, timeout: float) -> None: """Calling thread waits for a given number of milliseconds or until notified.""" with self.condition: - self.condition.wait(timeout / 1000.0) + self.condition.wait(millis_to_seconds(timeout)) def notify_all(self) -> None: """Notifies all waiting threads""" diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 52b8f188..1f159c92 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -44,6 +44,7 @@ _UNREGISTER_TIME, current_time_millis, instance_name_from_service_info, + millis_to_seconds, ) @@ -248,7 +249,7 @@ async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: """Send a broadcasts to announce a service at intervals.""" for i in range(3): if i != 0: - await asyncio.sleep(interval / 1000) + await asyncio.sleep(millis_to_seconds(interval)) self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl)) async def async_register_service( @@ -278,7 +279,7 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b self._raise_on_name_conflict(info) for i in range(3): if i != 0: - await asyncio.sleep(_CHECK_TIME / 1000) + await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) self.sender.send(self.zeroconf.generate_service_query(info)) self._raise_on_name_conflict(info) @@ -333,7 +334,7 @@ async def async_get_service_info( async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(self.async_notify.event.wait(), timeout / 1000) + await asyncio.wait_for(self.async_notify.event.wait(), millis_to_seconds(timeout)) async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: """Adds a listener for a particular service type. This object From 8da00caf31e007153e10a8038a0a484edea03c2f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 13:25:11 -1000 Subject: [PATCH 192/608] ServiceBrowser must recheck for handlers to call when holding condition (#477) --- zeroconf/__init__.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9bb488b5..a842e806 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1785,17 +1785,33 @@ def cancel(self) -> None: super().cancel() self.join() + def _wait_for_next_event(self) -> None: + """Wait for the next handler or time to send queries.""" + # If there are handlers to call + # we want to process them right away + if self._handlers_to_call: + return + + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + + if next_time <= now: + return + + with self.zc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + self.zc.condition.wait(millis_to_seconds(next_time - now)) + def run(self) -> None: """Run the browser thread.""" super().run() while True: - if not self._handlers_to_call: - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - now = current_time_millis() - if next_time > now: - self.zc.wait(next_time - now) - + self._wait_for_next_event() if self.zc.done or self.done: return From 393910b67ac667a660ee9351cc8f94310937f654 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 13:30:10 -1000 Subject: [PATCH 193/608] Switch from using an asyncio.Event to asyncio.Condition for waiting (#482) --- zeroconf/asyncio.py | 56 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 12 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 1f159c92..0227949e 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -55,6 +55,35 @@ def _get_best_available_queue() -> queue.Queue: return queue.Queue() +# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed +async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: + """Wait for a condition or timeout.""" + loop = asyncio.get_event_loop() + future = loop.create_future() + + def _handle_timeout() -> None: + if not future.done(): + future.set_result(None) + + timer_handle = loop.call_later(timeout, _handle_timeout) + condition_wait = loop.create_task(condition.wait()) + + def _handle_wait_complete(_: asyncio.Task) -> None: + if not future.done(): + future.set_result(None) + + condition_wait.add_done_callback(_handle_wait_complete) + + try: + await future + finally: + timer_handle.cancel() + if not condition_wait.done(): + condition_wait.cancel() + with contextlib.suppress(asyncio.CancelledError): + await condition_wait + + class _AsyncSender(threading.Thread): """A thread to handle sending DNSOutgoing for asyncio.""" @@ -87,19 +116,19 @@ def run(self) -> None: class AsyncNotifyListener(NotifyListener): """A NotifyListener that async code can use to wait for events.""" - def __init__(self) -> None: + def __init__(self, aiozc: 'AsyncZeroconf') -> None: """Create an event for async listeners to wait for.""" - self.event = asyncio.Event() + self.aiozc = aiozc self.loop = asyncio.get_event_loop() def notify_all(self) -> None: """Schedule an async_notify_all.""" - self.loop.call_soon_threadsafe(self.async_notify_all) + self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all()) - def async_notify_all(self) -> None: + async def _async_notify_all(self) -> None: """Notify all async listeners.""" - self.event.set() - self.event.clear() + async with self.aiozc.condition: + self.aiozc.condition.notify_all() class AsyncServiceListener: @@ -169,10 +198,12 @@ def __init__( super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore self._browser_task = asyncio.ensure_future(self.async_run()) - def cancel(self) -> None: + async def async_cancel(self) -> None: """Cancel the browser.""" - super().cancel() + self.cancel() self._browser_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._browser_task async def async_run(self) -> None: """Run the browser task.""" @@ -240,10 +271,11 @@ def __init__( apple_p2p=apple_p2p, ) self.loop = asyncio.get_event_loop() - self.async_notify = AsyncNotifyListener() + self.async_notify = AsyncNotifyListener(self) self.zeroconf.add_notify_listener(self.async_notify) self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} self.sender = _AsyncSender(self.zeroconf) + self.condition = asyncio.Condition() async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -333,8 +365,8 @@ async def async_get_service_info( async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" - with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for(self.async_notify.event.wait(), millis_to_seconds(timeout)) + async with self.condition: + await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout)) async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: """Adds a listener for a particular service type. This object @@ -346,7 +378,7 @@ async def async_add_service_listener(self, type_: str, listener: AsyncServiceLis async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None: """Removes a listener from the set that is currently listening.""" if listener in self.async_browsers: - self.async_browsers[listener].cancel() + await self.async_browsers[listener].async_cancel() del self.async_browsers[listener] async def async_remove_all_service_listeners(self) -> None: From 9c06ce15db31ebffe3a556896393d48cb786b5d9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 14:04:36 -1000 Subject: [PATCH 194/608] Relocate ServiceBrowser wait time calculation to seperate function (#484) - Eliminate the need to duplicate code between the ServiceBrowser and AsyncServiceBrowser to calculate the wait time. --- zeroconf/__init__.py | 49 +++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a842e806..86b828d3 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1754,6 +1754,22 @@ def generate_ready_queries(self) -> Optional[DNSOutgoing]: self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) return out + def _seconds_to_wait(self) -> Optional[float]: + """Returns the number of seconds to wait for the next event.""" + # If there are handlers to call + # we want to process them right away + if self._handlers_to_call: + return None + + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + + if next_time <= now: + return None + + return millis_to_seconds(next_time - now) + class ServiceBrowser(_ServiceBrowserBase, threading.Thread): """Used to browse for a service of a specific type. @@ -1785,33 +1801,20 @@ def cancel(self) -> None: super().cancel() self.join() - def _wait_for_next_event(self) -> None: - """Wait for the next handler or time to send queries.""" - # If there are handlers to call - # we want to process them right away - if self._handlers_to_call: - return - - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - now = current_time_millis() - - if next_time <= now: - return - - with self.zc.condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - self.zc.condition.wait(millis_to_seconds(next_time - now)) - def run(self) -> None: """Run the browser thread.""" super().run() while True: - self._wait_for_next_event() + timeout = self._seconds_to_wait() + if timeout: + with self.zc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + self.zc.condition.wait(timeout) + if self.zc.done or self.done: return From 960693628006e23fd13fcaefef915ca0c84401b9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 14:14:17 -1000 Subject: [PATCH 195/608] AsyncServiceBrowser must recheck for handlers to call when holding condition (#483) - There was a short race condition window where the AsyncServiceBrowser could add to _handlers_to_call in the Engine thread, have the condition notify_all called, but since the AsyncServiceBrowser was not yet holding the condition it would not know to stop waiting and process the handlers to call. --- zeroconf/asyncio.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 0227949e..61119747 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -209,12 +209,15 @@ async def async_run(self) -> None: """Run the browser task.""" self.run() while True: - if not self._handlers_to_call: - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - now = current_time_millis() - if next_time > now: - await self.aiozc.async_wait(next_time - now) + timeout = self._seconds_to_wait() + if timeout: + async with self.aiozc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + await wait_condition_or_timeout(self.aiozc.condition, timeout) out = self.generate_ready_queries() if out: From 0a69aa0d37e13cb2c65ceb5cc3ab0fd7e9d34b22 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 14:18:46 -1000 Subject: [PATCH 196/608] RecordUpdateListener now uses update_records instead of update_record (#419) --- zeroconf/__init__.py | 159 ++++++++++++++++++++++++++------------- zeroconf/test.py | 54 +++++++++++++ zeroconf/test_asyncio.py | 7 +- 3 files changed, 167 insertions(+), 53 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 86b828d3..b010c923 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -35,8 +35,9 @@ import time import warnings from collections import OrderedDict +from contextlib import contextmanager from types import TracebackType # noqa # used in type hints -from typing import Dict, Iterable, List, Optional, Type, Union, cast +from typing import Dict, Generator, Iterable, List, Optional, Type, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints import ifaddr @@ -1424,8 +1425,8 @@ def run(self) -> None: now = current_time_millis() if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms: self._last_cache_cleanup = now - for record in self.zc.cache.expire(now): - self.zc.update_record(now, record) + with self.zc.update_records(now, list(self.zc.cache.expire(now))): + pass self.socketpair[0].close() self.socketpair[1].close() @@ -1548,8 +1549,37 @@ def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistratio class RecordUpdateListener: - def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: - raise NotImplementedError() + def update_record( # pylint: disable=no-self-use + self, zc: 'Zeroconf', now: float, record: DNSRecord + ) -> None: + """Update a single record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + raise RuntimeError("update_record is deprecated and will be removed in a future version.") + + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Update multiple records in one shot. + + All records that are received in a single packet are passed + to update_records. + + This implementation is a compatiblity shim to ensure older code + that uses RecordUpdateListener as a base class will continue to + get calls to update_record. This method will raise + NotImplementedError in a future version. + + At this point the cache will not have the new records + """ + for record in records: + self.update_record(zc, now, record) + + def update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + """ class ServiceListener: @@ -1601,6 +1631,7 @@ def __init__( current_time = current_time_millis() self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} + self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] self._service_state_changed = Signal() @@ -1649,30 +1680,32 @@ def _record_matching_type(self, record: DNSRecord) -> Optional[str]: """Return the type if the record matches one of the types we are browsing.""" return next((type_ for type_ in self.types if record.name.endswith(type_)), None) - def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None: - """Callback invoked by Zeroconf when new information arrives. - - Updates information required by browser in the Zeroconf cache. - - Ensures that there is are no unecessary duplicates in the list - - """ - - def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> None: - - # Code to ensure we only do a single update message - # Precedence is; Added, Remove, Update - key = (name, type_) - if ( - state_change is ServiceStateChange.Added - or ( - state_change is ServiceStateChange.Removed - and self._handlers_to_call.get(key) != ServiceStateChange.Added - ) - or (state_change is ServiceStateChange.Updated and key not in self._handlers_to_call) - ): - self._handlers_to_call[key] = state_change + def _enqueue_callback( + self, + state_change: ServiceStateChange, + type_: str, + name: str, + ) -> None: + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + key = (name, type_) + if ( + state_change is ServiceStateChange.Added + or ( + state_change is ServiceStateChange.Removed + and self._pending_handlers.get(key) != ServiceStateChange.Added + ) + or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) + ): + self._pending_handlers[key] = state_change + def _process_record_update( + self, + zc: 'Zeroconf', + now: float, + record: DNSRecord, + ) -> None: + """Process a single record update from a batch of updates.""" expired = record.is_expired(now) if isinstance(record, DNSPointer): @@ -1683,10 +1716,10 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> old_record = services_by_type.get(service_key) if old_record is None: services_by_type[service_key] = record - enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) elif expired: del services_by_type[service_key] - enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) + self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: old_record.reset_ttl(record) expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) @@ -1711,14 +1744,32 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> for service in self.zc.cache.entries_with_server(record.name): type_ = self._record_matching_type(service) if type_: - enqueue_callback(ServiceStateChange.Updated, type_, service.name) + self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) break return type_ = self._record_matching_type(record) if type_: - enqueue_callback(ServiceStateChange.Updated, type_, record.name) + self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) + + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Callback invoked by Zeroconf when new information arrives. + + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unecessary duplicates in the list. + """ + for record in records: + self._process_record_update(zc, now, record) + + def update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + """ + self._handlers_to_call.update(self._pending_handlers) + self._pending_handlers.clear() def cancel(self) -> None: """Cancel the browser.""" @@ -1825,9 +1876,7 @@ def run(self) -> None: if not self._handlers_to_call: continue - with self.zc._handlers_lock: # pylint: disable=protected-access - (name_type, state_change) = self._handlers_to_call.popitem(False) - + (name_type, state_change) = self._handlers_to_call.popitem(False) self._service_state_changed.fire( zeroconf=self.zc, service_type=name_type[1], @@ -2689,11 +2738,6 @@ def __init__( self.condition = threading.Condition() - # Ensure we create the lock before - # we add the listener as we could get - # a message before the lock is created. - self._handlers_lock = threading.Lock() # ensure we process a full message in one go - self.engine = Engine(self) self.listener = Listener(self) if not unicast: @@ -2902,12 +2946,17 @@ def add_listener( answer the question(s).""" now = current_time_millis() self.listeners.append(listener) + records = [] if question is not None: questions = [question] if isinstance(question, DNSQuestion) else question for single_question in questions: for record in self.cache.entries_with_name(single_question.name): if single_question.answered_by(record) and not record.is_expired(now): - listener.update_record(self, now, record) + records.append(record) + + if records: + listener.update_records(self, now, records) + listener.update_records_complete() self.notify_all() def remove_listener(self, listener: RecordUpdateListener) -> None: @@ -2918,14 +2967,23 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions log.exception('Unknown error, possibly benign: %r', e) - def update_record(self, now: float, rec: DNSRecord) -> None: + @contextmanager + def update_records(self, now: float, rec: List[DNSRecord]) -> Generator: """Used to notify listeners of new information that has updated - a record.""" - for listener in self.listeners: - listener.update_record(self, now, rec) - self.notify_all() + a record. + + This method must be called before the cache is updated. + """ + try: + for listener in self.listeners: + listener.update_records(self, now, rec) + yield + finally: + for listener in self.listeners: + listener.update_records_complete() + self.notify_all() - def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many-branches + def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" updates = [] # type: List[DNSRecord] @@ -2967,10 +3025,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many if not updates and not address_adds and not other_adds and not removes: return - # Only hold the lock if we have updates - with self._handlers_lock: - for record in updates: - self.update_record(now, record) + with self.update_records(now, updates): # The cache adds must be processed AFTER we trigger # the updates since we compare existing data # with the new data and updating the cache @@ -2981,7 +3036,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many # otherwise a fetch of ServiceInfo may miss an address # because it thinks the cache is complete # - # The cache is processed under the lock to ensure + # The cache is processed under the context manager to ensure # that any ServiceBrowser that is going to call # zc.get_service_info will see the cached value # but ONLY after all the record updates have been diff --git a/zeroconf/test.py b/zeroconf/test.py index 02195240..27e03b53 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2552,3 +2552,57 @@ def on_service_state_change(zeroconf, service_type, state_change, name): assert not notify_called zc.close() + + +def test_legacy_record_update_listener(): + """Test a RecordUpdateListener that does not implement update_records.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(RuntimeError): + r.RecordUpdateListener().update_record( + zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + ) + + updates = [] + + class LegacyRecordUpdateListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: + nonlocal updates + updates.append(record) + + zc.add_listener(LegacyRecordUpdateListener(), None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + zc.wait(1) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.close() diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index e6bcf5a2..ddf3fbee 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -435,9 +435,14 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: await task task = await aiozc.async_unregister_service(new_info) await task + await aiozc.async_wait(1) await aiozc.async_close() - assert calls[0] == ('add', type_, registration_name) + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] @pytest.mark.asyncio From 49db96dae466a602662f4fde1537f62a8c8d3110 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 14:38:38 -1000 Subject: [PATCH 197/608] Enable test_integration_with_listener_class test on PyPy (#485) --- zeroconf/test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 27e03b53..64175e4f 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -9,7 +9,6 @@ import itertools import logging import os -import platform import socket import struct import threading @@ -1263,7 +1262,6 @@ def test_integration_with_subtype_and_listener(self): class ListenerTest(unittest.TestCase): - @pytest.mark.skipif(platform.python_implementation() == 'PyPy', reason="Flaky on PyPy") def test_integration_with_listener_class(self): service_added = Event() From 275765a4fd3b477b79163c04f6411709e14506b9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 15:28:44 -1000 Subject: [PATCH 198/608] Move threading daemon property into ServiceBrowser class (#486) --- zeroconf/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b010c923..6bd797d1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1620,7 +1620,6 @@ def __init__( for check_type_ in self.types: if not check_type_.endswith(service_type_name(check_type_, strict=False)): raise BadTypeInNameException - self.daemon = True self.zc = zc self.addr = addr self.port = port @@ -1841,6 +1840,7 @@ def __init__( ) -> None: threading.Thread.__init__(self) super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + self.daemon = True self.start() self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), From ef9334f1279d029752186bc6f4a1ebff6229bf5b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 16:34:14 -1000 Subject: [PATCH 199/608] Add AsyncServiceBrowser example (#487) --- examples/async_browser.py | 78 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 examples/async_browser.py diff --git a/examples/async_browser.py b/examples/async_browser.py new file mode 100644 index 00000000..b2a1916d --- /dev/null +++ b/examples/async_browser.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +""" Example of browsing for a service. + +The default is HTTP and HAP; use --find to search for all available services in the network +""" + +import argparse +import asyncio +import logging +from typing import cast + +from zeroconf import IPVersion, ServiceStateChange +from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf + + +def async_on_service_state_change( + zeroconf: AsyncZeroconf, service_type: str, name: str, state_change: ServiceStateChange +) -> None: + print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + if state_change is not ServiceStateChange.Added: + return + asyncio.ensure_future(async_display_service_info(zeroconf, service_type, name)) + + +async def async_display_service_info(zeroconf: AsyncZeroconf, service_type: str, name: str) -> None: + info = await zeroconf.async_get_service_info(service_type, name) + print("Info from zeroconf.get_service_info: %r" % (info)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + print(" Name: %s" % name) + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(" Server: %s" % (info.server,)) + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(" %s: %s" % (key, value)) + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + aiozc = AsyncZeroconf(ip_version=ip_version) + + services = ["_http._tcp.local.", "_hap._tcp.local."] + print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services) + aiobrowser = AsyncServiceBrowser(aiozc, services, handlers=[async_on_service_state_change]) + + loop = asyncio.get_event_loop() + try: + loop.run_forever() + except KeyboardInterrupt: + pass + finally: + loop.run_until_complete(aiobrowser.async_cancel()) + loop.run_until_complete(aiozc.async_close()) From 69880ae6ca4d4f0a7d476b0271b89adea92b9389 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 16:42:57 -1000 Subject: [PATCH 200/608] Lint before testing in the CI (#488) --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 37d4bff6..b766fcf7 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ env: cp requirements-dev.txt ./env/requirements.built .PHONY: ci -ci: test_coverage lint +ci: lint test_coverage .PHONY: lint lint: $(LINT_TARGETS) From f0c02a02c1a2d7c914c62479bad4957b06471661 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 17:29:31 -1000 Subject: [PATCH 201/608] Remove unused __ne__ code from Python 2 era (#492) --- zeroconf/__init__.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 6bd797d1..9560e649 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -455,10 +455,6 @@ def __eq__(self, other: Any) -> bool: and isinstance(other, DNSEntry) ) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - @staticmethod def get_class_(class_: int) -> str: """Class accessor""" @@ -520,10 +516,6 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" raise AbstractMethodException - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def suppressed_by(self, msg: 'DNSIncoming') -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" @@ -591,10 +583,6 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address ) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" try: @@ -630,10 +618,6 @@ def __eq__(self, other: Any) -> bool: and self.os == other.os ) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" return self.to_string(self.cpu + " " + self.os) @@ -655,10 +639,6 @@ def __eq__(self, other: Any) -> bool: """Tests equality on alias""" return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" return self.to_string(self.alias) @@ -681,10 +661,6 @@ def __eq__(self, other: Any) -> bool: """Tests equality on text""" return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" if len(self.text) > 10: @@ -731,10 +707,6 @@ def __eq__(self, other: Any) -> bool: and DNSEntry.__eq__(self, other) ) - def __ne__(self, other: Any) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" return self.to_string("%s:%s" % (self.server, self.port)) @@ -2227,10 +2199,6 @@ def __eq__(self, other: object) -> bool: """Tests equality of service name""" return isinstance(other, ServiceInfo) and other.name == self.name - def __ne__(self, other: object) -> bool: - """Non-equality test""" - return not self.__eq__(other) - def __repr__(self) -> str: """String representation""" return '%s(%s)' % ( From 20f8b3d6fb8d117b0c3c794c4075a00e117e3f31 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 17:29:41 -1000 Subject: [PATCH 202/608] Update internal version check to match docs (3.6+) (#491) --- zeroconf/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9560e649..5efdf979 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -60,11 +60,12 @@ "IPVersion", ] -if sys.version_info <= (3, 4): +if sys.version_info <= (3, 6): raise ImportError( ''' -Python version > 3.4 required for python-zeroconf. +Python version > 3.6 required for python-zeroconf. If you need support for Python 2 or Python 3.3-3.4 please use version 19.1 +If you need support for Python 3.5 please use version 0.28.0 ''' ) From 38e4b42b847e700db52bc51973210efc485d8c23 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 20:00:23 -1000 Subject: [PATCH 203/608] Make a base class for DNSIncoming and DNSOutgoing (#497) --- zeroconf/__init__.py | 32 ++++++++++++++++++++------------ zeroconf/test.py | 2 ++ 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5efdf979..dcf00dc9 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -713,18 +713,34 @@ def __repr__(self) -> str: return self.to_string("%s:%s" % (self.server, self.port)) -class DNSIncoming(QuietLogger): +class DNSMessage: + """A base class for DNS messages.""" + + def __init__(self, flags: int) -> None: + """Construct a DNS message.""" + self.flags = flags + + def is_query(self) -> bool: + """Returns true if this is a query.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self) -> bool: + """Returns true if this is a response.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + +class DNSIncoming(DNSMessage, QuietLogger): """Object representation of an incoming DNS packet""" def __init__(self, data: bytes) -> None: """Constructor from string holding bytes of packet""" + super().__init__(0) self.offset = 0 self.data = data self.questions = [] # type: List[DNSQuestion] self.answers = [] # type: List[DNSRecord] self.id = 0 - self.flags = 0 # type: int self.num_questions = 0 self.num_answers = 0 self.num_authorities = 0 @@ -846,14 +862,6 @@ def read_others(self) -> None: if rec is not None: self.answers.append(rec) - def is_query(self) -> bool: - """Returns true if this is a query""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY - - def is_response(self) -> bool: - """Returns true if this is a response""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE - def read_utf(self, offset: int, length: int) -> str: """Reads a UTF-8 string of a given length from the packet""" return str(self.data[offset : offset + length], 'utf-8', 'replace') @@ -892,15 +900,15 @@ def read_name(self) -> str: return result -class DNSOutgoing: +class DNSOutgoing(DNSMessage): """Object representation of an outgoing packet""" def __init__(self, flags: int, multicast: bool = True) -> None: + super().__init__(flags) self.finished = False self.id = 0 self.multicast = multicast - self.flags = flags self.packets_data = [] # type: List[bytes] # these 3 are per-packet -- see also reset_for_next_packet() diff --git a/zeroconf/test.py b/zeroconf/test.py index 64175e4f..31a4d851 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -936,6 +936,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): # query query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) + assert query.is_query() is True query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) @@ -1463,6 +1464,7 @@ def update_service(self, zc, type_, name) -> None: def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + assert generated.is_response() is True if service_state_change == r.ServiceStateChange.Removed: ttl = 0 From e2908c6c89802ba7a0ea51ac351da40bce3f1cb6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 21:28:02 -1000 Subject: [PATCH 204/608] Ensure packets are properly seperated when exceeding maximum size (#498) - Ensure that questions that exceed the max packet size are moved to the next packet. This fixes DNSQuestions being sent in multiple packets in violation of: https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - Ensure only one resource record is sent when a record exceeds _MAX_MSG_TYPICAL https://datatracker.ietf.org/doc/html/rfc6762#section-17 --- zeroconf/__init__.py | 124 ++++++++++++++++++++++++++++++------------- zeroconf/test.py | 112 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 37 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dcf00dc9..bd5c17f2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -915,6 +915,7 @@ def __init__(self, flags: int, multicast: bool = True) -> None: self.names = {} # type: Dict[str, int] self.data = [] # type: List[bytes] self.size = 12 + self.allow_long = True self.state = self.State.init @@ -927,6 +928,7 @@ def reset_for_next_packet(self) -> None: self.names = {} self.data = [] self.size = 12 + self.allow_long = True def __repr__(self) -> str: return '' % ', '.join( @@ -1118,13 +1120,15 @@ def write_name(self, name: str) -> None: # this is the end of a name self.write_byte(0) - def write_question(self, question: DNSQuestion) -> None: + def write_question(self, question: DNSQuestion) -> bool: """Writes a question to the packet""" + start_data_length, start_size = len(self.data), self.size self.write_name(question.name) self.write_short(question.type) self.write_short(question.class_) + return self._check_data_limit_or_rollback(start_data_length, start_size) - def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) -> bool: + def write_record(self, record: DNSRecord, now: float) -> bool: """Writes a record (answer, authoritative answer, additional) to the packet. Returns True on success, or False if we did not (either because the packet was already finished or because the record does @@ -1152,19 +1156,26 @@ def write_record(self, record: DNSRecord, now: float, allow_long: bool = False) # Here we replace the 0 length short we wrote # before with the actual length self.replace_short(index, length) - len_limit = _MAX_MSG_ABSOLUTE if allow_long else _MAX_MSG_TYPICAL + return self._check_data_limit_or_rollback(start_data_length, start_size) - # if we go over, then rollback and quit - if self.size > len_limit: - while len(self.data) > start_data_length: - self.data.pop() - self.size = start_size + def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: + """Check data limit, if we go over, then rollback and return False.""" + len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL + self.allow_long = False - rollback_names = [name for name, idx in self.names.items() if idx >= start_size] - for name in rollback_names: - del self.names[name] - return False - return True + if self.size <= len_limit: + return True + + log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) + + while len(self.data) > start_data_length: + self.data.pop() + self.size = start_size + + rollback_names = [name for name, idx in self.names.items() if idx >= start_size] + for name in rollback_names: + del self.names[name] + return False def packet(self) -> bytes: """Returns a bytestring containing the first packet's bytes. @@ -1181,6 +1192,38 @@ def packet(self) -> bytes: ) return packets[0] + def _write_questions_from_offset(self, questions_offset: int) -> int: + questions_written = 0 + for question in self.questions[questions_offset:]: + if not self.write_question(question): + break + questions_written += 1 + return questions_written + + def _write_answers_from_offset(self, answer_offset: int) -> int: + answers_written = 0 + for answer, time_ in self.answers[answer_offset:]: + if not self.write_record(answer, time_): + break + answers_written += 1 + return answers_written + + def _write_authorities_from_offset(self, authority_offset: int) -> int: + authorities_written = 0 + for authority in self.authorities[authority_offset:]: + if not self.write_record(authority, 0): + break + authorities_written += 1 + return authorities_written + + def _write_additionals_from_offset(self, additional_offset: int) -> int: + additionals_written = 0 + for additional in self.additionals[additional_offset:]: + if not self.write_record(additional, 0): + break + additionals_written += 1 + return additionals_written + def packets(self) -> List[bytes]: """Returns a list of bytestrings containing the packets' bytes @@ -1194,6 +1237,7 @@ def packets(self) -> List[bytes]: if self.state == self.State.finished: return self.packets_data + questions_offset = 0 answer_offset = 0 authority_offset = 0 additional_offset = 0 @@ -1203,32 +1247,31 @@ def packets(self) -> List[bytes]: while ( first_time + or questions_offset < len(self.questions) or answer_offset < len(self.answers) or authority_offset < len(self.authorities) or additional_offset < len(self.additionals) ): first_time = False - log.debug("offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset) - log.debug("lengths = %d, %d, %d", len(self.answers), len(self.authorities), len(self.additionals)) - - additionals_written = 0 - authorities_written = 0 - answers_written = 0 - questions_written = 0 - for question in self.questions: - self.write_question(question) - questions_written += 1 - allow_long = True # at most one answer is allowed to be a long packet - for answer, time_ in self.answers[answer_offset:]: - if self.write_record(answer, time_, allow_long): - answers_written += 1 - allow_long = False - for authority in self.authorities[authority_offset:]: - if self.write_record(authority, 0): - authorities_written += 1 - for additional in self.additionals[additional_offset:]: - if self.write_record(additional, 0): - additionals_written += 1 + log.debug( + "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + log.debug( + "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", + len(self.questions), + len(self.answers), + len(self.authorities), + len(self.additionals), + ) + + questions_written = self._write_questions_from_offset(questions_offset) + answers_written = self._write_answers_from_offset(answer_offset) + authorities_written = self._write_authorities_from_offset(authority_offset) + additionals_written = self._write_additionals_from_offset(additional_offset) self.insert_short_at_start(additionals_written) self.insert_short_at_start(authorities_written) @@ -1242,12 +1285,19 @@ def packets(self) -> List[bytes]: self.packets_data.append(b''.join(self.data)) self.reset_for_next_packet() + questions_offset += questions_written answer_offset += answers_written authority_offset += authorities_written additional_offset += additionals_written - log.debug("now offsets = %d, %d, %d", answer_offset, authority_offset, additional_offset) - if (answers_written + authorities_written + additionals_written) == 0 and ( - len(self.answers) + len(self.authorities) + len(self.additionals) + log.debug( + "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( + len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) ) > 0: log.warning("packets() made no progress adding records; returning") break diff --git a/zeroconf/test.py b/zeroconf/test.py index 31a4d851..b0791a34 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -313,6 +313,118 @@ def test_dns_hinfo(self): generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) self.assertRaises(r.NamePartTooLongException, generated.packet) + def test_many_questions(self): + """Test many questions get seperated into multiple packets.""" + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + questions = [] + for i in range(100): + question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 100 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) < r._MAX_MSG_TYPICAL + assert len(packets[1]) < r._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 85 + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 15 + + def test_only_one_answer_can_by_large(self): + """Test that only the first answer in each packet can be large. + + https://datatracker.ietf.org/doc/html/rfc6762#section-17 + """ + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + query = r.DNSIncoming(r.DNSOutgoing(r._FLAGS_QR_QUERY).packet()) + for i in range(3): + generated.add_answer( + query, + r.DNSText( + "zoom._hap._tcp.local.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + 1200, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, + ), + ) + generated.add_answer( + query, + r.DNSService( + "testname1.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + ) + assert len(generated.answers) == 4 + + packets = generated.packets() + assert len(packets) == 4 + assert len(packets[0]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[0]) > r._MAX_MSG_TYPICAL + + assert len(packets[1]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[1]) > r._MAX_MSG_TYPICAL + + assert len(packets[2]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[2]) > r._MAX_MSG_TYPICAL + + assert len(packets[3]) <= r._MAX_MSG_TYPICAL + + for packet in packets: + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + def test_questions_do_not_end_up_every_packet(self): + """Test that questions are not sent again when multiple packets are needed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + Sometimes a Multicast DNS querier will already have too many answers + to fit in the Known-Answer Section of its query packets.... It MUST + immediately follow the packet with another query packet containing no + questions and as many more Known-Answer records as will fit. + """ + + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + for i in range(35): + question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + answer = r.DNSService( + f"testname{i}.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + f"foo{i}.local.", + ) + generated.add_answer_at_time(answer, 0) + + assert len(generated.questions) == 35 + assert len(generated.answers) == 35 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) <= r._MAX_MSG_TYPICAL + assert len(packets[1]) <= r._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 35 + assert len(parsed1.answers) == 33 + + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 2 + class PacketForm(unittest.TestCase): def test_transaction_id(self): From f04a2eb43745eba7c43c9c56179ed1fceb992bd8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 6 Jun 2021 21:45:57 -1000 Subject: [PATCH 205/608] Set the TC bit for query packets where the known answers span multiple packets (#494) --- zeroconf/__init__.py | 45 ++++++++++++++++++++--------- zeroconf/test.py | 67 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 14 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index bd5c17f2..c362d16b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1224,6 +1224,17 @@ def _write_additionals_from_offset(self, additional_offset: int) -> int: additionals_written += 1 return additionals_written + def _has_more_to_add( + self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int + ) -> bool: + """Check if all questions, answers, authority, and additionals have been written to the packet.""" + return ( + questions_offset < len(self.questions) + or answer_offset < len(self.answers) + or authority_offset < len(self.authorities) + or additional_offset < len(self.additionals) + ) + def packets(self) -> List[bytes]: """Returns a list of bytestrings containing the packets' bytes @@ -1241,16 +1252,11 @@ def packets(self) -> List[bytes]: answer_offset = 0 authority_offset = 0 additional_offset = 0 - # we have to at least write out the question first_time = True - while ( - first_time - or questions_offset < len(self.questions) - or answer_offset < len(self.answers) - or authority_offset < len(self.authorities) - or additional_offset < len(self.additionals) + while first_time or self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset ): first_time = False log.debug( @@ -1277,13 +1283,6 @@ def packets(self) -> List[bytes]: self.insert_short_at_start(authorities_written) self.insert_short_at_start(answers_written) self.insert_short_at_start(questions_written) - self.insert_short_at_start(self.flags) - if self.multicast: - self.insert_short_at_start(0) - else: - self.insert_short_at_start(self.id) - self.packets_data.append(b''.join(self.data)) - self.reset_for_next_packet() questions_offset += questions_written answer_offset += answers_written @@ -1296,6 +1295,24 @@ def packets(self) -> List[bytes]: authority_offset, additional_offset, ) + + if self.is_query() and self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + log.debug("Setting TC flag") + self.insert_short_at_start(self.flags | _FLAGS_TC) + else: + self.insert_short_at_start(self.flags) + + if self.multicast: + self.insert_short_at_start(0) + else: + self.insert_short_at_start(self.id) + + self.packets_data.append(b''.join(self.data)) + self.reset_for_next_packet() + if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) ) > 0: diff --git a/zeroconf/test.py b/zeroconf/test.py index b0791a34..060b80a0 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2613,6 +2613,73 @@ def test_dns_compression_rollback_for_corruption(): assert incoming.valid is True +def test_tc_bit_in_query_packet(): + """Verify the TC bit is set when known answers exceed the packet size.""" + out = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, r._TYPE_PTR, r._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.valid is True + + +def test_tc_bit_not_set_in_answer_packet(): + """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" + out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.flags & r._FLAGS_TC == 0 + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.flags & r._FLAGS_TC == 0 + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.valid is True + + @pytest.mark.parametrize( "errno,expected_result", [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], From 9b480bc1abb2c2702f60796f2edae76ce03ca5d4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 7 Jun 2021 08:27:36 -1000 Subject: [PATCH 206/608] Update changelog, move breaking changes to the top of the list (#501) --- README.rst | 154 ++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index 10472ba6..10b25f14 100644 --- a/README.rst +++ b/README.rst @@ -137,6 +137,152 @@ Changelog 0.32.0 (Unreleased) =================== +* Breaking change: Update internal version check to match docs (3.6+) (#491) @bdraco + + Python version eariler then 3.6 were likely broken with zeroconf + already, however the version is now explictly checked. + +* Breaking change: RecordUpdateListener now uses update_records instead of update_record (#419) @bdraco + + This allows the listener to receive all the records that have + been updated in a single transaction such as a packet or + cache expiry. + + update_record has been deprecated in favor of update_records + A compatibility shim exists to ensure classes that use + RecordUpdateListener as a base class continue to have + update_record called, however they should be updated + as soon as possible. + + A new method update_records_complete is now called on each + listener when all listeners have completed processing updates + and the cache has been updated. This allows ServiceBrowsers + to delay calling handlers until they are sure the cache + has been updated as its a common pattern to call for + ServiceInfo when a ServiceBrowser handler fires. + +* Breaking change: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco + + When manually creating a zeroconf.Engine object, it is no longer started automatically. + It must manually be started by calling .start() on the created object. + + The Engine thread is now started after all the listeners have been added to avoid a + race condition where packets could be missed at startup. + +* Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco + +* Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco + + Ensure that questions that exceed the max packet size are + moved to the next packet. This fixes DNSQuestions being + sent in multiple packets in violation of: + datatracker.ietf.org/doc/html/rfc6762#section-7.2 + + Ensure only one resource record is sent when a record + exceeds _MAX_MSG_TYPICAL + datatracker.ietf.org/doc/html/rfc6762#section-17 + +* Make a base class for DNSIncoming and DNSOutgoing (#497) @bdraco + +* Remove unused __ne__ code from Python 2 era (#492) @bdraco + +* Lint before testing in the CI (#488) @bdraco + +* Add AsyncServiceBrowser example (#487) @bdraco + +* Move threading daemon property into ServiceBrowser class (#486) @bdraco + +* Enable test_integration_with_listener_class test on PyPy (#485) @bdraco + +* AsyncServiceBrowser must recheck for handlers to call when holding condition (#483) + + There was a short race condition window where the AsyncServiceBrowser + could add to _handlers_to_call in the Engine thread, have the + condition notify_all called, but since the AsyncServiceBrowser was + not yet holding the condition it would not know to stop waiting + and process the handlers to call. + +* Relocate ServiceBrowser wait time calculation to seperate function (#484) @bdraco + + Eliminate the need to duplicate code between the ServiceBrowser + and AsyncServiceBrowser to calculate the wait time. + +* Switch from using an asyncio.Event to asyncio.Condition for waiting (#482) @bdraco + +* ServiceBrowser must recheck for handlers to call when holding condition (#477) @bdraco + + There was a short race condition window where the ServiceBrowser + could add to _handlers_to_call in the Engine thread, have the + condition notify_all called, but since the ServiceBrowser was + not yet holding the condition it would not know to stop waiting + and process the handlers to call. + +* Provide a helper function to convert milliseconds to seconds (#481) @bdraco + +* Fix AsyncServiceInfo.async_request not waiting long enough (#480) @bdraco + +* Add support for updating multiple records at once to ServiceInfo (#474) @bdraco + +* Narrow exception catch in DNSAddress.__repr__ to only expected exceptions (#473) @bdraco + +* Add test coverage to ensure ServiceInfo rejects expired records (#468) @bdraco + +* Reduce branching in service_type_name (#472) @bdraco + +* Fix flakey test_update_record (#470) @bdraco + +* Reduce branching in Zeroconf.handle_response (#467) @bdraco + +* Ensure PTR questions asked in uppercase are answered (#465) @bdraco + +* Clear cache between ServiceTypesQuery tests (#466) @bdraco + +* Break apart Zeroconf.handle_query to reduce branching (#462) @bdraco + +* Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek + +* Use constant for service type enumeration (#461) @bdraco + +* Reduce branching in Zeroconf.handle_response (#459) @bdraco + +* Reduce branching in Zeroconf.handle_query (#460) @bdraco + +* Enable pylint (#438) @bdraco + +* Trap OSError directly in Zeroconf.send instead of checking isinstance (#453) @bdraco + +* Disable protected-access on the ServiceBrowser usage of _handlers_lock (#452) @bdraco + +* Mark functions with too many branches in need of refactoring (#455) @bdraco + +* Disable pylint no-self-use check on abstract methods (#451) @bdraco + +* Use unique name in test_async_service_browser test (#450) @bdraco + +* Disable no-member check for WSAEINVAL false positive (#454) @bdraco + +* Mark methods used by asyncio without self use (#447) @bdraco + +* Extract _get_queue from zeroconf.asyncio._AsyncSender (#444) @bdraco + +* Fix redefining argument with the local name 'record' in ServiceInfo.update_record (#448) @bdraco + +* Remove unneeded-not in new_socket (#445) @bdraco + +* Disable broad except checks in places we still catch broad exceptions (#443) @bdraco + +* Merge _TYPE_CNAME and _TYPE_PTR comparison in DNSIncoming.read_others (#442) @bdraco + +* Convert unnecessary use of a comprehension to a list (#441) @bdraco + +* Remove unused now argument from ServiceInfo._process_record (#440) @bdraco + +* Disable pylint too-many-branches for functions that need refactoring (#439) @bdraco + +* Cleanup unused variables (#437) @bdraco + +* Cleanup unnecessary else after returns (#436) @bdraco + * Add zeroconf.asyncio to the docs (#434) @bdraco * Fix warning when generating sphinx docs (#432) @bdraco @@ -183,14 +329,6 @@ Changelog * Simplify DNSPointer processing in ServiceBrowser (#386) @bdraco -* Breaking change: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco - - When manually creating a zeroconf.Engine object, it is no longer started automatically. - It must manually be started by calling .start() on the created object. - - The Engine thread is now started after all the listeners have been added to avoid a - race condition where packets could be missed at startup. - * Ensure the cache is checked for name conflict after final service query with asyncio (#382) @bdraco * Complete ServiceInfo request as soon as all questions are answered (#380) @bdraco From bfca3b46fd9a395f387bd90b68c523a3ca84bde4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 09:19:24 -1000 Subject: [PATCH 207/608] Rename zeroconf.asyncio to zeroconf.aio (#503) - The asyncio name could shadow system asyncio in some cases. If zeroconf is in sys.path, this would result in loading zeroconf.asyncio when system asyncio was intended. - An `zeroconf.asyncio` shim module has been added that imports `zeroconf.aio` that was available in 0.31 to provide backwards compatibility in 0.32. This module will be removed in 0.33 to fix the underlying problem detailed in #502 --- Makefile | 6 +- docs/api.rst | 2 +- examples/async_browser.py | 2 +- examples/async_registration.py | 2 +- examples/async_service_info_request.py | 2 +- zeroconf/aio.py | 403 +++++++++++++++++++++ zeroconf/asyncio.py | 386 +------------------- zeroconf/test_aio.py | 469 +++++++++++++++++++++++++ zeroconf/test_asyncio.py | 456 +----------------------- 9 files changed, 890 insertions(+), 838 deletions(-) create mode 100644 zeroconf/aio.py create mode 100644 zeroconf/test_aio.py diff --git a/Makefile b/Makefile index b766fcf7..66bc85f1 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,7 @@ flake8: flake8 --max-line-length=$(MAX_LINE_LENGTH) setup.py examples zeroconf pylint: - pylint zeroconf/__init__.py zeroconf/asyncio.py + pylint zeroconf/__init__.py zeroconf/aio.py zeroconf/asyncio.py .PHONY: black_check black_check: @@ -39,10 +39,10 @@ mypy: mypy examples/*.py zeroconf/*.py test: - pytest -v zeroconf/test.py zeroconf/test_asyncio.py + pytest -v zeroconf/test.py zeroconf/test_aio.py zeroconf/test_asyncio.py test_coverage: - pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py zeroconf/test_asyncio.py + pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py zeroconf/test_aio.py zeroconf/test_asyncio.py autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf diff --git a/docs/api.rst b/docs/api.rst index 1704db5a..20c53727 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -6,7 +6,7 @@ python-zeroconf API reference :undoc-members: :show-inheritance: -.. automodule:: zeroconf.asyncio +.. automodule:: zeroconf.aio :members: :undoc-members: :show-inheritance: diff --git a/examples/async_browser.py b/examples/async_browser.py index b2a1916d..4a3861cb 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -11,7 +11,7 @@ from typing import cast from zeroconf import IPVersion, ServiceStateChange -from zeroconf.asyncio import AsyncServiceBrowser, AsyncZeroconf +from zeroconf.aio import AsyncServiceBrowser, AsyncZeroconf def async_on_service_state_change( diff --git a/examples/async_registration.py b/examples/async_registration.py index 53d14ce1..7e02ea7c 100644 --- a/examples/async_registration.py +++ b/examples/async_registration.py @@ -9,7 +9,7 @@ from typing import List from zeroconf import IPVersion -from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf +from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf async def register_services(infos: List[AsyncServiceInfo]) -> None: diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py index c0f953c2..838545ce 100644 --- a/examples/async_service_info_request.py +++ b/examples/async_service_info_request.py @@ -13,7 +13,7 @@ from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf -from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf +from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf HAP_TYPE = "_hap._tcp.local." diff --git a/zeroconf/aio.py b/zeroconf/aio.py new file mode 100644 index 00000000..61119747 --- /dev/null +++ b/zeroconf/aio.py @@ -0,0 +1,403 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" +import asyncio +import contextlib +import queue +import threading +from types import TracebackType # noqa # used in type hints +from typing import Awaitable, Callable, Dict, List, Optional, Type, Union + +from . import ( + DNSOutgoing, + IPVersion, + InterfaceChoice, + InterfacesType, + NonUniqueNameException, + NotifyListener, + ServiceInfo, + Zeroconf, + _BROWSER_TIME, + _CHECK_TIME, + _LISTENER_TIME, + _MDNS_PORT, + _REGISTER_TIME, + _ServiceBrowserBase, + _UNREGISTER_TIME, + current_time_millis, + instance_name_from_service_info, + millis_to_seconds, +) + + +def _get_best_available_queue() -> queue.Queue: + """Create the best available queue type.""" + if hasattr(queue, "SimpleQueue"): + return queue.SimpleQueue() # type: ignore # pylint: disable=all + return queue.Queue() + + +# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed +async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: + """Wait for a condition or timeout.""" + loop = asyncio.get_event_loop() + future = loop.create_future() + + def _handle_timeout() -> None: + if not future.done(): + future.set_result(None) + + timer_handle = loop.call_later(timeout, _handle_timeout) + condition_wait = loop.create_task(condition.wait()) + + def _handle_wait_complete(_: asyncio.Task) -> None: + if not future.done(): + future.set_result(None) + + condition_wait.add_done_callback(_handle_wait_complete) + + try: + await future + finally: + timer_handle.cancel() + if not condition_wait.done(): + condition_wait.cancel() + with contextlib.suppress(asyncio.CancelledError): + await condition_wait + + +class _AsyncSender(threading.Thread): + """A thread to handle sending DNSOutgoing for asyncio.""" + + def __init__(self, zc: 'Zeroconf'): + """Create the sender thread.""" + super().__init__() + self.zc = zc + self.queue = _get_best_available_queue() + self.start() + self.name = "AsyncZeroconfSender" + + def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + """Queue a send to be processed by the thread.""" + self.queue.put((out, addr, port)) + + def close(self) -> None: + """Close the instance.""" + self.queue.put(None) + self.join() + + def run(self) -> None: + """Runner that processes sends FIFO.""" + while True: + event = self.queue.get() + if event is None: + return + self.zc.send(*event) + + +class AsyncNotifyListener(NotifyListener): + """A NotifyListener that async code can use to wait for events.""" + + def __init__(self, aiozc: 'AsyncZeroconf') -> None: + """Create an event for async listeners to wait for.""" + self.aiozc = aiozc + self.loop = asyncio.get_event_loop() + + def notify_all(self) -> None: + """Schedule an async_notify_all.""" + self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all()) + + async def _async_notify_all(self) -> None: + """Notify all async listeners.""" + async with self.aiozc.condition: + self.aiozc.condition.notify_all() + + +class AsyncServiceListener: + def add_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def remove_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def update_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + +class AsyncServiceInfo(ServiceInfo): + """An async version of ServiceInfo.""" + + async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(aiozc.zeroconf): + return True + + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + aiozc.zeroconf.add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query(aiozc.zeroconf, now) + if not out.questions: + return self.load_from_cache(aiozc.zeroconf) + aiozc.sender.send(out) + next_ = now + delay + delay *= 2 + + await aiozc.async_wait(min(next_, last) - now) + now = current_time_millis() + finally: + aiozc.zeroconf.remove_listener(self) + + return True + + +class AsyncServiceBrowser(_ServiceBrowserBase): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + aiozc: 'AsyncZeroconf', + type_: Union[str, list], + handlers: Optional[Union[AsyncServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[AsyncServiceListener] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + self.aiozc = aiozc + super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore + self._browser_task = asyncio.ensure_future(self.async_run()) + + async def async_cancel(self) -> None: + """Cancel the browser.""" + self.cancel() + self._browser_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._browser_task + + async def async_run(self) -> None: + """Run the browser task.""" + self.run() + while True: + timeout = self._seconds_to_wait() + if timeout: + async with self.aiozc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + await wait_condition_or_timeout(self.aiozc.condition, timeout) + + out = self.generate_ready_queries() + if out: + self.aiozc.sender.send(out, addr=self.addr, port=self.port) + + if not self._handlers_to_call: + continue + + (name_type, state_change) = self._handlers_to_call.popitem(False) + self._service_state_changed.fire( + zeroconf=self.aiozc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + +class AsyncZeroconf: + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + + The async version is currently a wrapper around the sync version + with I/O being done in the executor for backwards compatibility. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + zc: Optional[Zeroconf] = None, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. + :param apple_p2p: use AWDL interface (only macOS) + """ + self.zeroconf = zc or Zeroconf( + interfaces=interfaces, + unicast=unicast, + ip_version=ip_version, + apple_p2p=apple_p2p, + ) + self.loop = asyncio.get_event_loop() + self.async_notify = AsyncNotifyListener(self) + self.zeroconf.add_notify_listener(self.async_notify) + self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} + self.sender = _AsyncSender(self.zeroconf) + self.condition = asyncio.Condition() + + async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + """Send a broadcasts to announce a service at intervals.""" + for i in range(3): + if i != 0: + await asyncio.sleep(millis_to_seconds(interval)) + self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl)) + + async def async_register_service( + self, + info: ServiceInfo, + cooperating_responders: bool = False, + ) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`). + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + await self.async_check_service(info, cooperating_responders) + self.zeroconf.registry.add(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: + """Checks the network for a unique service name.""" + instance_name_from_service_info(info) + if cooperating_responders: + return + self._raise_on_name_conflict(info) + for i in range(3): + if i != 0: + await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) + self.sender.send(self.zeroconf.generate_service_query(info)) + self._raise_on_name_conflict(info) + + def _raise_on_name_conflict(self, info: ServiceInfo) -> None: + """Raise NonUniqueNameException if the ServiceInfo has a conflict.""" + if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): + raise NonUniqueNameException + + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: + """Unregister a service. + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + self.zeroconf.registry.remove(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. + + The service will be broadcast in a task. This task is returned + and therefore can be awaited if necessary. + """ + self.zeroconf.registry.update(info) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + + def _close(self) -> None: + """Shutdown zeroconf and the sender.""" + self.sender.close() + self.zeroconf.remove_notify_listener(self.async_notify) + self.zeroconf.close() + + async def async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries.""" + await self.async_remove_all_service_listeners() + await self.loop.run_in_executor(None, self._close) + + async def async_get_service_info( + self, type_: str, name: str, timeout: int = 3000 + ) -> Optional[AsyncServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = AsyncServiceInfo(type_, name) + if await info.async_request(self, timeout): + return info + return None + + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + async with self.condition: + await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout)) + + async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + await self.async_remove_service_listener(listener) + self.async_browsers[listener] = AsyncServiceBrowser(self, type_, listener) + + async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.async_browsers: + await self.async_browsers[listener].async_cancel() + del self.async_browsers[listener] + + async def async_remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + await asyncio.gather( + *[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)] + ) + + async def __aenter__(self) -> 'AsyncZeroconf': + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + await self.async_close() + return None diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 61119747..bdca1c0d 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -19,385 +19,17 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ -import asyncio -import contextlib -import queue -import threading -from types import TracebackType # noqa # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Type, Union - -from . import ( - DNSOutgoing, - IPVersion, - InterfaceChoice, - InterfacesType, - NonUniqueNameException, - NotifyListener, - ServiceInfo, - Zeroconf, - _BROWSER_TIME, - _CHECK_TIME, - _LISTENER_TIME, - _MDNS_PORT, - _REGISTER_TIME, - _ServiceBrowserBase, - _UNREGISTER_TIME, - current_time_millis, - instance_name_from_service_info, - millis_to_seconds, -) - - -def _get_best_available_queue() -> queue.Queue: - """Create the best available queue type.""" - if hasattr(queue, "SimpleQueue"): - return queue.SimpleQueue() # type: ignore # pylint: disable=all - return queue.Queue() - - -# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed -async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: - """Wait for a condition or timeout.""" - loop = asyncio.get_event_loop() - future = loop.create_future() - - def _handle_timeout() -> None: - if not future.done(): - future.set_result(None) - - timer_handle = loop.call_later(timeout, _handle_timeout) - condition_wait = loop.create_task(condition.wait()) - - def _handle_wait_complete(_: asyncio.Task) -> None: - if not future.done(): - future.set_result(None) - - condition_wait.add_done_callback(_handle_wait_complete) - - try: - await future - finally: - timer_handle.cancel() - if not condition_wait.done(): - condition_wait.cancel() - with contextlib.suppress(asyncio.CancelledError): - await condition_wait - - -class _AsyncSender(threading.Thread): - """A thread to handle sending DNSOutgoing for asyncio.""" - - def __init__(self, zc: 'Zeroconf'): - """Create the sender thread.""" - super().__init__() - self.zc = zc - self.queue = _get_best_available_queue() - self.start() - self.name = "AsyncZeroconfSender" - - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: - """Queue a send to be processed by the thread.""" - self.queue.put((out, addr, port)) - - def close(self) -> None: - """Close the instance.""" - self.queue.put(None) - self.join() - - def run(self) -> None: - """Runner that processes sends FIFO.""" - while True: - event = self.queue.get() - if event is None: - return - self.zc.send(*event) - - -class AsyncNotifyListener(NotifyListener): - """A NotifyListener that async code can use to wait for events.""" - - def __init__(self, aiozc: 'AsyncZeroconf') -> None: - """Create an event for async listeners to wait for.""" - self.aiozc = aiozc - self.loop = asyncio.get_event_loop() - - def notify_all(self) -> None: - """Schedule an async_notify_all.""" - self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all()) - - async def _async_notify_all(self) -> None: - """Notify all async listeners.""" - async with self.aiozc.condition: - self.aiozc.condition.notify_all() - - -class AsyncServiceListener: - def add_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def remove_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def update_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - -class AsyncServiceInfo(ServiceInfo): - """An async version of ServiceInfo.""" - - async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - if self.load_from_cache(aiozc.zeroconf): - return True - - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - try: - aiozc.zeroconf.add_listener(self, None) - while not self._is_complete: - if last <= now: - return False - if next_ <= now: - out = self.generate_request_query(aiozc.zeroconf, now) - if not out.questions: - return self.load_from_cache(aiozc.zeroconf) - aiozc.sender.send(out) - next_ = now + delay - delay *= 2 - - await aiozc.async_wait(min(next_, last) - now) - now = current_time_millis() - finally: - aiozc.zeroconf.remove_listener(self) - - return True +import logging -class AsyncServiceBrowser(_ServiceBrowserBase): - """Used to browse for a service of a specific type. +from .aio import AsyncZeroconf # pylint: disable=unused-import # noqa - The listener object will have its add_service() and - remove_service() methods called when this browser - discovers changes in the services availability.""" +log = logging.getLogger(__name__) - def __init__( - self, - aiozc: 'AsyncZeroconf', - type_: Union[str, list], - handlers: Optional[Union[AsyncServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[AsyncServiceListener] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - self.aiozc = aiozc - super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore - self._browser_task = asyncio.ensure_future(self.async_run()) +# The asyncio module would shadow system asyncio in some import cases +# to resolve this, the module has been renamed zeroconf.aio - async def async_cancel(self) -> None: - """Cancel the browser.""" - self.cancel() - self._browser_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._browser_task - - async def async_run(self) -> None: - """Run the browser task.""" - self.run() - while True: - timeout = self._seconds_to_wait() - if timeout: - async with self.aiozc.condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - await wait_condition_or_timeout(self.aiozc.condition, timeout) - - out = self.generate_ready_queries() - if out: - self.aiozc.sender.send(out, addr=self.addr, port=self.port) - - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.aiozc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) - - -class AsyncZeroconf: - """Implementation of Zeroconf Multicast DNS Service Discovery - - Supports registration, unregistration, queries and browsing. - - The async version is currently a wrapper around the sync version - with I/O being done in the executor for backwards compatibility. - """ - - def __init__( - self, - interfaces: InterfacesType = InterfaceChoice.All, - unicast: bool = False, - ip_version: Optional[IPVersion] = None, - apple_p2p: bool = False, - zc: Optional[Zeroconf] = None, - ) -> None: - """Creates an instance of the Zeroconf class, establishing - multicast communications, listening and reaping threads. - - :param interfaces: :class:`InterfaceChoice` or a list of IP addresses - (IPv4 and IPv6) and interface indexes (IPv6 only). - - IPv6 notes for non-POSIX systems: - * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` - on Python versions before 3.8. - - Also listening on loopback (``::1``) doesn't work, use a real address. - :param ip_version: IP versions to support. If `choice` is a list, the default is detected - from it. Otherwise defaults to V4 only for backward compatibility. - :param apple_p2p: use AWDL interface (only macOS) - """ - self.zeroconf = zc or Zeroconf( - interfaces=interfaces, - unicast=unicast, - ip_version=ip_version, - apple_p2p=apple_p2p, - ) - self.loop = asyncio.get_event_loop() - self.async_notify = AsyncNotifyListener(self) - self.zeroconf.add_notify_listener(self.async_notify) - self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} - self.sender = _AsyncSender(self.zeroconf) - self.condition = asyncio.Condition() - - async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: - """Send a broadcasts to announce a service at intervals.""" - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(interval)) - self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl)) - - async def async_register_service( - self, - info: ServiceInfo, - cooperating_responders: bool = False, - ) -> Awaitable: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service. The name of the service may be changed if needed to make - it unique on the network. Additionally multiple cooperating responders - can register the same service on the network for resilience - (if you want this behavior set `cooperating_responders` to `True`). - - The service will be broadcast in a task. This task is returned - and therefore can be awaited if necessary. - """ - await self.async_check_service(info, cooperating_responders) - self.zeroconf.registry.add(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - - async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: - """Checks the network for a unique service name.""" - instance_name_from_service_info(info) - if cooperating_responders: - return - self._raise_on_name_conflict(info) - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) - self.sender.send(self.zeroconf.generate_service_query(info)) - self._raise_on_name_conflict(info) - - def _raise_on_name_conflict(self, info: ServiceInfo) -> None: - """Raise NonUniqueNameException if the ServiceInfo has a conflict.""" - if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): - raise NonUniqueNameException - - async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: - """Unregister a service. - - The service will be broadcast in a task. This task is returned - and therefore can be awaited if necessary. - """ - self.zeroconf.registry.remove(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) - - async def async_update_service(self, info: ServiceInfo) -> Awaitable: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service. - - The service will be broadcast in a task. This task is returned - and therefore can be awaited if necessary. - """ - self.zeroconf.registry.update(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - - def _close(self) -> None: - """Shutdown zeroconf and the sender.""" - self.sender.close() - self.zeroconf.remove_notify_listener(self.async_notify) - self.zeroconf.close() - - async def async_close(self) -> None: - """Ends the background threads, and prevent this instance from - servicing further queries.""" - await self.async_remove_all_service_listeners() - await self.loop.run_in_executor(None, self._close) - - async def async_get_service_info( - self, type_: str, name: str, timeout: int = 3000 - ) -> Optional[AsyncServiceInfo]: - """Returns network's service information for a particular - name and type, or None if no service matches by the timeout, - which defaults to 3 seconds.""" - info = AsyncServiceInfo(type_, name) - if await info.async_request(self, timeout): - return info - return None - - async def async_wait(self, timeout: float) -> None: - """Calling task waits for a given number of milliseconds or until notified.""" - async with self.condition: - await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout)) - - async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: - """Adds a listener for a particular service type. This object - will then have its add_service and remove_service methods called when - services of that type become available and unavailable.""" - await self.async_remove_service_listener(listener) - self.async_browsers[listener] = AsyncServiceBrowser(self, type_, listener) - - async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None: - """Removes a listener from the set that is currently listening.""" - if listener in self.async_browsers: - await self.async_browsers[listener].async_cancel() - del self.async_browsers[listener] - - async def async_remove_all_service_listeners(self) -> None: - """Removes a listener from the set that is currently listening.""" - await asyncio.gather( - *[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)] - ) - - async def __aenter__(self) -> 'AsyncZeroconf': - return self - - async def __aexit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - await self.async_close() - return None +log.warning( + "zeroconf.asyncio namespace has changed to zeroconf.aio; " + "This compatibility module will be removed in the next version" +) diff --git a/zeroconf/test_aio.py b/zeroconf/test_aio.py new file mode 100644 index 00000000..b05d88b5 --- /dev/null +++ b/zeroconf/test_aio.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for aio.py.""" + +import asyncio +import socket +import unittest.mock + +import pytest + +from . import ( + BadTypeInNameException, + NonUniqueNameException, + ServiceInfo, + ServiceListener, + ServiceNameAlreadyRegistered, + Zeroconf, + _LISTENER_TIME, + current_time_millis, +) +from .aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf + + +@pytest.mark.asyncio +async def test_async_basic_usage() -> None: + """Test we can create and close the instance.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_with_sync_passed_in() -> None: + """Test we can create and close the instance when passing in a sync Zeroconf.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(zc=zc) + assert aiozc.zeroconf is zc + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_registration() -> None: + """Test registering services broadcasts the registration by default.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_name_conflict() -> None: + """Test registering services throws on name conflict.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc2-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + + with pytest.raises(NonUniqueNameException): + task = await aiozc.async_register_service(info) + await task + + with pytest.raises(ServiceNameAlreadyRegistered): + task = await aiozc.async_register_service(info, cooperating_responders=True) + await task + + conflicting_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-3.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + + with pytest.raises(NonUniqueNameException): + task = await aiozc.async_register_service(conflicting_info) + await task + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_registration_name_does_not_match_type() -> None: + """Test registering services throws when the name does not match the type.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc3-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info.type = "_wrong._tcp.local." + with pytest.raises(BadTypeInNameException): + task = await aiozc.async_register_service(info) + await task + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_tasks() -> None: + """Test awaiting broadcast tasks""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc4-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + assert isinstance(task, asyncio.Task) + await task + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + assert isinstance(task, asyncio.Task) + await task + + task = await aiozc.async_unregister_service(new_info) + assert isinstance(task, asyncio.Task) + await task + + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_wait_unblocks_on_update() -> None: + """Test async_wait will unblock on update.""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test-srvc4-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + + # Should unblock due to update from the + # registration + now = current_time_millis() + await aiozc.async_wait(50000) + assert current_time_millis() - now < 3000 + await task + + now = current_time_millis() + await aiozc.async_wait(50) + assert current_time_millis() - now < 1000 + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_info_async_request() -> None: + """Test registering services broadcasts and query with AsyncServceInfo.async_request.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "abc" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name2, type_) + + # Start a tasks BEFORE the registration that will keep trying + # and see the registration a bit later + get_service_info_task1 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + await asyncio.sleep(_LISTENER_TIME / 1000 / 2) + get_service_info_task2 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-1.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-5.local.", + addresses=[socket.inet_aton("10.0.1.5")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + aiosinfo = await get_service_info_task1 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await get_service_info_task2 + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] + + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3"), socket.inet_pton(socket.AF_INET6, "6001:db8::1")], + ) + + task = await aiozc.async_update_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + aiosinfos = await asyncio.gather( + aiozc.async_get_service_info(type_, registration_name), + aiozc.async_get_service_info(type_, registration_name2), + ) + assert aiosinfos[0] is not None + assert aiosinfos[0].addresses == [socket.inet_aton("10.0.1.3")] + assert aiosinfos[1] is not None + assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")] + + aiosinfo = AsyncServiceInfo(type_, registration_name) + zc_cache = aiozc.zeroconf.cache + for name in zc_cache.names(): + for record in zc_cache.entries_with_name(name): + zc_cache.remove(record) + # Generating the race condition is almost impossible + # without patching since its a TOCTOU race + with unittest.mock.patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): + await aiosinfo.async_request(aiozc, 3000) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + + task = await aiozc.async_unregister_service(new_info) + await task + + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is None + + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_service_browser() -> None: + """Test AsyncServiceBrowser.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test9-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + calls = [] + + with pytest.raises(NotImplementedError): + AsyncServiceListener().add_service(aiozc, "_type", "name._type") + + with pytest.raises(NotImplementedError): + AsyncServiceListener().remove_service(aiozc, "_type", "name._type") + + with pytest.raises(NotImplementedError): + AsyncServiceListener().update_service(aiozc, "_type", "name._type") + + class MyListener(AsyncServiceListener): + def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + await aiozc.async_add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + new_info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.3")], + ) + task = await aiozc.async_update_service(new_info) + await task + task = await aiozc.async_unregister_service(new_info) + await task + await aiozc.async_wait(1) + await aiozc.async_close() + + assert calls == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ('remove', type_, registration_name), + ] + + +@pytest.mark.asyncio +async def test_async_context_manager() -> None: + """Test using an async context manager.""" + type_ = "_test10-sr-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc: + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await aiozc.async_register_service(info) + await task + aiosinfo = await aiozc.async_get_service_info(type_, registration_name) + assert aiosinfo is not None diff --git a/zeroconf/test_asyncio.py b/zeroconf/test_asyncio.py index ddf3fbee..28e3a327 100644 --- a/zeroconf/test_asyncio.py +++ b/zeroconf/test_asyncio.py @@ -2,25 +2,12 @@ # -*- coding: utf-8 -*- -"""Unit tests for async.py.""" +"""Unit tests for asyncio.py.""" -import asyncio -import socket -import unittest.mock import pytest -from . import ( - BadTypeInNameException, - NonUniqueNameException, - ServiceInfo, - ServiceListener, - ServiceNameAlreadyRegistered, - Zeroconf, - _LISTENER_TIME, - current_time_millis, -) -from .asyncio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf +from .asyncio import AsyncZeroconf @pytest.mark.asyncio @@ -28,442 +15,3 @@ async def test_async_basic_usage() -> None: """Test we can create and close the instance.""" aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_async_with_sync_passed_in() -> None: - """Test we can create and close the instance when passing in a sync Zeroconf.""" - zc = Zeroconf(interfaces=['127.0.0.1']) - aiozc = AsyncZeroconf(zc=zc) - assert aiozc.zeroconf is zc - await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_async_service_registration() -> None: - """Test registering services broadcasts the registration by default.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test1-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - calls = [] - - class MyListener(ServiceListener): - def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("add", type, name)) - - def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("remove", type, name)) - - def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("update", type, name)) - - listener = MyListener() - aiozc.zeroconf.add_service_listener(type_, listener) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - await task - new_info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.3")], - ) - task = await aiozc.async_update_service(new_info) - await task - task = await aiozc.async_unregister_service(new_info) - await task - await aiozc.async_close() - - assert calls == [ - ('add', type_, registration_name), - ('update', type_, registration_name), - ('remove', type_, registration_name), - ] - - -@pytest.mark.asyncio -async def test_async_service_registration_name_conflict() -> None: - """Test registering services throws on name conflict.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc2-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - await task - - with pytest.raises(NonUniqueNameException): - task = await aiozc.async_register_service(info) - await task - - with pytest.raises(ServiceNameAlreadyRegistered): - task = await aiozc.async_register_service(info, cooperating_responders=True) - await task - - conflicting_info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-3.local.", - addresses=[socket.inet_aton("10.0.1.3")], - ) - - with pytest.raises(NonUniqueNameException): - task = await aiozc.async_register_service(conflicting_info) - await task - - await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_async_service_registration_name_does_not_match_type() -> None: - """Test registering services throws when the name does not match the type.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc3-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - info.type = "_wrong._tcp.local." - with pytest.raises(BadTypeInNameException): - task = await aiozc.async_register_service(info) - await task - await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_async_tasks() -> None: - """Test awaiting broadcast tasks""" - - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc4-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - calls = [] - - class MyListener(ServiceListener): - def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("add", type, name)) - - def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("remove", type, name)) - - def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: - calls.append(("update", type, name)) - - listener = MyListener() - aiozc.zeroconf.add_service_listener(type_, listener) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - assert isinstance(task, asyncio.Task) - await task - - new_info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.3")], - ) - task = await aiozc.async_update_service(new_info) - assert isinstance(task, asyncio.Task) - await task - - task = await aiozc.async_unregister_service(new_info) - assert isinstance(task, asyncio.Task) - await task - - await aiozc.async_close() - - assert calls == [ - ('add', type_, registration_name), - ('update', type_, registration_name), - ('remove', type_, registration_name), - ] - - -@pytest.mark.asyncio -async def test_async_wait_unblocks_on_update() -> None: - """Test async_wait will unblock on update.""" - - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test-srvc4-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - - # Should unblock due to update from the - # registration - now = current_time_millis() - await aiozc.async_wait(50000) - assert current_time_millis() - now < 3000 - await task - - now = current_time_millis() - await aiozc.async_wait(50) - assert current_time_millis() - now < 1000 - - await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_service_info_async_request() -> None: - """Test registering services broadcasts and query with AsyncServceInfo.async_request.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test1-srvc-type._tcp.local." - name = "xxxyyy" - name2 = "abc" - registration_name = "%s.%s" % (name, type_) - registration_name2 = "%s.%s" % (name2, type_) - - # Start a tasks BEFORE the registration that will keep trying - # and see the registration a bit later - get_service_info_task1 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) - await asyncio.sleep(_LISTENER_TIME / 1000 / 2) - get_service_info_task2 = asyncio.ensure_future(aiozc.async_get_service_info(type_, registration_name)) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-1.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - info2 = ServiceInfo( - type_, - registration_name2, - 80, - 0, - 0, - desc, - "ash-5.local.", - addresses=[socket.inet_aton("10.0.1.5")], - ) - tasks = [] - tasks.append(await aiozc.async_register_service(info)) - tasks.append(await aiozc.async_register_service(info2)) - await asyncio.gather(*tasks) - - aiosinfo = await get_service_info_task1 - assert aiosinfo is not None - assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] - - aiosinfo = await get_service_info_task2 - assert aiosinfo is not None - assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] - - aiosinfo = await aiozc.async_get_service_info(type_, registration_name) - assert aiosinfo is not None - assert aiosinfo.addresses == [socket.inet_aton("10.0.1.2")] - - new_info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.3"), socket.inet_pton(socket.AF_INET6, "6001:db8::1")], - ) - - task = await aiozc.async_update_service(new_info) - await task - - aiosinfo = await aiozc.async_get_service_info(type_, registration_name) - assert aiosinfo is not None - assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] - - aiosinfos = await asyncio.gather( - aiozc.async_get_service_info(type_, registration_name), - aiozc.async_get_service_info(type_, registration_name2), - ) - assert aiosinfos[0] is not None - assert aiosinfos[0].addresses == [socket.inet_aton("10.0.1.3")] - assert aiosinfos[1] is not None - assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")] - - aiosinfo = AsyncServiceInfo(type_, registration_name) - zc_cache = aiozc.zeroconf.cache - for name in zc_cache.names(): - for record in zc_cache.entries_with_name(name): - zc_cache.remove(record) - # Generating the race condition is almost impossible - # without patching since its a TOCTOU race - with unittest.mock.patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): - await aiosinfo.async_request(aiozc, 3000) - assert aiosinfo is not None - assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] - - task = await aiozc.async_unregister_service(new_info) - await task - - aiosinfo = await aiozc.async_get_service_info(type_, registration_name) - assert aiosinfo is None - - await aiozc.async_close() - - -@pytest.mark.asyncio -async def test_async_service_browser() -> None: - """Test AsyncServiceBrowser.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - type_ = "_test9-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - calls = [] - - with pytest.raises(NotImplementedError): - AsyncServiceListener().add_service(aiozc, "_type", "name._type") - - with pytest.raises(NotImplementedError): - AsyncServiceListener().remove_service(aiozc, "_type", "name._type") - - with pytest.raises(NotImplementedError): - AsyncServiceListener().update_service(aiozc, "_type", "name._type") - - class MyListener(AsyncServiceListener): - def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: - calls.append(("add", type, name)) - - def remove_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: - calls.append(("remove", type, name)) - - def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: - calls.append(("update", type, name)) - - listener = MyListener() - await aiozc.async_add_service_listener(type_, listener) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - await task - new_info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.3")], - ) - task = await aiozc.async_update_service(new_info) - await task - task = await aiozc.async_unregister_service(new_info) - await task - await aiozc.async_wait(1) - await aiozc.async_close() - - assert calls == [ - ('add', type_, registration_name), - ('update', type_, registration_name), - ('remove', type_, registration_name), - ] - - -@pytest.mark.asyncio -async def test_async_context_manager() -> None: - """Test using an async context manager.""" - type_ = "_test10-sr-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc: - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - task = await aiozc.async_register_service(info) - await task - aiosinfo = await aiozc.async_get_service_info(type_, registration_name) - assert aiosinfo is not None From 26b70050ffe7dee4fb34428f285be377d1d8f210 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 11:43:27 -1000 Subject: [PATCH 208/608] Update changelog for zeroconf.asyncio -> zeroconf.aio (#506) --- README.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.rst b/README.rst index 10b25f14..b83fc65f 100644 --- a/README.rst +++ b/README.rst @@ -134,9 +134,29 @@ See examples directory for more. Changelog ========= +0.33.0 (Unreleased) +=================== + +* Breaking change: zeroconf.asyncio has been removed in favor of zeroconf.aio - TBD + + The asyncio name could shadow system asyncio in some cases. If + zeroconf is in sys.path, this would result in loading zeroconf.asyncio + when system asyncio was intended. + 0.32.0 (Unreleased) =================== +* Breaking change: zeroconf.asyncio has been renamed zeroconf.aio (#503) @bdraco + + The asyncio name could shadow system asyncio in some cases. If + zeroconf is in sys.path, this would result in loading zeroconf.asyncio + when system asyncio was intended. + + An `zeroconf.asyncio` shim module has been added that imports `zeroconf.aio` + that was available in 0.31 to provide backwards compatibility in 0.32.0 + This module will be removed in 0.33.0 to fix the underlying problem + detailed in #502 + * Breaking change: Update internal version check to match docs (3.6+) (#491) @bdraco Python version eariler then 3.6 were likely broken with zeroconf From 1cfcc5636a845924eb683ad4acf4d9a36ef85fb7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 12:18:41 -1000 Subject: [PATCH 209/608] Extract code for handling queries into QueryHandler (#507) --- zeroconf/__init__.py | 161 +++++++++++++++++++++++-------------------- 1 file changed, 88 insertions(+), 73 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index c362d16b..0b0ca19b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2717,6 +2717,91 @@ def _remove(self, info: ServiceInfo) -> None: del self.services[lower_name] +class QueryHandler: + """Query the ServiceRegistry.""" + + def __init__(self, registry: ServiceRegistry): + """Init the query handler.""" + self.registry = registry + + def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in self.registry.get_types(): + out.add_answer( + msg, + DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, + ), + ) + + def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a PTR query.""" + for service in self.registry.get_infos_type(question.name.lower()): + out.add_answer(msg, service.dns_pointer()) + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + out.add_additional_answer(service.dns_service()) + out.add_additional_answer(service.dns_text()) + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a query any query other then PTR. + + Add answer(s) for A, AAAA, SRV, or TXT queries. + """ + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): + for dns_address in service.dns_addresses(): + out.add_answer(msg, dns_address) + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + return + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer(msg, service.dns_service()) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer(msg, service.dns_text()) + if question.type == _TYPE_SRV: + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: + """Deal with incoming query packets. Provides a response if possible.""" + if unicast: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) + for question in msg.questions: + out.add_question(question) + else: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + + for question in msg.questions: + if question.type == _TYPE_PTR: + if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._answer_service_type_enumeration_query(msg, out) + else: + self._answer_ptr_query(msg, out, question) + continue + + self._answer_non_ptr_query(msg, out, question) + + if out is not None and out.answers: + out.id = msg.id + return out + + return None + + class Zeroconf(QuietLogger): """Implementation of Zeroconf Multicast DNS Service Discovery @@ -2777,6 +2862,7 @@ def __init__( self._notify_listeners = [] # type: List[NotifyListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] self.registry = ServiceRegistry() + self.query_handler = QueryHandler(self.registry) self.cache = DNSCache() @@ -3091,82 +3177,11 @@ def handle_response(self, msg: DNSIncoming) -> None: # because the data was not yet populated. self.cache.remove_records(removes) - def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: - """Provide an answer to a service type enumeration query. - - https://datatracker.ietf.org/doc/html/rfc6763#section-9 - """ - for stype in self.registry.get_types(): - out.add_answer( - msg, - DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - - def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name.lower()): - out.add_answer(msg, service.dns_pointer()) - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer(service.dns_service()) - out.add_additional_answer(service.dns_text()) - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a query any query other then PTR. - - Add answer(s) for A, AAAA, SRV, or TXT queries. - """ - name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(name_to_find) # type: ignore - if service is None: - return - - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if possible.""" - # Support unicast client responses - # - if port != _MDNS_PORT: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) - for question in msg.questions: - out.add_question(question) - else: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._answer_service_type_enumeration_query(msg, out) - else: - self._answer_ptr_query(msg, out, question) - continue - - self._answer_non_ptr_query(msg, out, question) - - if out is not None and out.answers: - out.id = msg.id + out = self.query_handler.response(msg, port != _MDNS_PORT) + if out: self.send(out, addr, port) def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: From db866f7d032ed031e6aa5e14fba24b3dafeafa8d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 12:34:00 -1000 Subject: [PATCH 210/608] Stop monkey patching send in the PTR optimization test (#509) --- zeroconf/test.py | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 060b80a0..00e03302 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2417,33 +2417,9 @@ def test_ptr_optimization(): type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zc.send - nbr_answers = nbr_additionals = nbr_authorities = 0 has_srv = has_txt = has_a = False - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal nbr_answers, nbr_additionals, nbr_authorities - nonlocal has_srv, has_txt, has_a - - nbr_answers += len(out.answers) - nbr_authorities += len(out.authorities) - for answer in out.additionals: - nbr_additionals += 1 - if answer.type == r._TYPE_SRV: - has_srv = True - elif answer.type == r._TYPE_TXT: - has_txt = True - elif answer.type == r._TYPE_A: - has_a = True - - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zc, "send", send) - # register zc.register_service(info) nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -2451,7 +2427,18 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): # query query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) + out = zc.query_handler.response(r.DNSIncoming(query.packet()), False) + assert out is not None + nbr_answers += len(out.answers) + nbr_authorities += len(out.authorities) + for answer in out.additionals: + nbr_additionals += 1 + if answer.type == r._TYPE_SRV: + has_srv = True + elif answer.type == r._TYPE_TXT: + has_txt = True + elif answer.type == r._TYPE_A: + has_a = True assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 assert has_srv and has_txt and has_a From 954ca3fb498bdc7cd5a6a168c40ad5b6b2476e71 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 13:08:15 -1000 Subject: [PATCH 211/608] Stop monkey patching send in the TTL test (#510) --- zeroconf/test.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 00e03302..817557dd 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1009,9 +1009,6 @@ def test_ttl(self): addresses=[socket.inet_aton("10.0.1.2")], ) - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zc.send - nbr_answers = nbr_additionals = nbr_authorities = 0 def get_ttl(record_type): @@ -1022,7 +1019,7 @@ def get_ttl(record_type): else: return r._DNS_OTHER_TTL - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def _process_outgoing_packet(out): """Sends an outgoing packet.""" nonlocal nbr_answers, nbr_additionals, nbr_authorities @@ -1035,14 +1032,14 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): for answer in out.authorities: nbr_authorities += 1 assert answer.ttl == get_ttl(answer.type) - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zc, "send", send) # register service with default TTL expected_ttl = None - zc.register_service(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.add(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, None)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -1053,36 +1050,46 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister expected_ttl = 0 - zc.unregister_service(info) + zc.registry.remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 + expected_ttl = None + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.add(info) # register service with custom TTL expected_ttl = r._DNS_HOST_TTL * 2 assert expected_ttl != r._DNS_HOST_TTL - zc.register_service(info, ttl=expected_ttl) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 nbr_answers = nbr_additionals = nbr_authorities = 0 # query + expected_ttl = None query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) - zc.handle_query(r.DNSIncoming(query.packet()), r._MDNS_ADDR, r._MDNS_PORT) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister expected_ttl = 0 - zc.unregister_service(info) + zc.registry.remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 zc.close() From 70b455ba53ce43e9280c02612e8a89665abd57f6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 14:59:02 -1000 Subject: [PATCH 212/608] Remove uneeded wait in the Engine thread (#511) - It is not longer necessary to wait since the socketpair was added in #243 which will cause the select to unblock when a new socket is added or removed. --- zeroconf/__init__.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0b0ca19b..6474e51e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -1440,17 +1440,8 @@ def __init__(self, zc: 'Zeroconf') -> None: def run(self) -> None: while not self.zc.done: - rs = list(self.readers.keys()) - if not rs: - # No sockets to manage, but we wait for the timeout - # or addition of a socket - with self.condition: - self.condition.wait(self.timeout) - continue - try: - rs.append(self.socketpair[0]) - rr, _wr, _er = select.select(rs, [], [], self.timeout) + rr, _wr, _er = select.select([*self.readers.keys(), self.socketpair[0]], [], [], self.timeout) if self.zc.done: return From 9a766a2a96abd0f105056839b5c30f2ede31ea2e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 17:19:55 -1000 Subject: [PATCH 213/608] Break out record updating into RecordManager (#512) --- zeroconf/__init__.py | 179 +++++++++++++++++++++++-------------------- 1 file changed, 98 insertions(+), 81 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 6474e51e..71e07182 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -35,9 +35,8 @@ import time import warnings from collections import OrderedDict -from contextlib import contextmanager from types import TracebackType # noqa # used in type hints -from typing import Dict, Generator, Iterable, List, Optional, Type, Union, cast +from typing import Dict, Iterable, List, Optional, Type, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints import ifaddr @@ -1464,8 +1463,8 @@ def run(self) -> None: now = current_time_millis() if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms: self._last_cache_cleanup = now - with self.zc.update_records(now, list(self.zc.cache.expire(now))): - pass + self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) + self.zc.record_manager.updates_complete() self.socketpair[0].close() self.socketpair[1].close() @@ -2793,6 +2792,99 @@ def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: return None +class RecordManager: + """Process records into the cache and notify listeners.""" + + def __init__(self, zeroconf: 'Zeroconf'): + """Init the record manager.""" + self.zc = zeroconf + self.cache = zeroconf.cache + + def updates(self, now: float, rec: List[DNSRecord]) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called before the cache is updated. + """ + for listener in self.zc.listeners: + listener.update_records(self.zc, now, rec) + + def updates_complete(self) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called after the cache is updated. + """ + for listener in self.zc.listeners: + listener.update_records_complete() + self.zc.notify_all() + + def updates_from_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + updates: List[DNSRecord] = [] + address_adds: List[DNSAddress] = [] + other_adds: List[DNSRecord] = [] + removes: List[DNSRecord] = [] + now = current_time_millis() + for record in msg.answers: + + updated = True + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): + if entry == record: + updated = False + if record.created - entry.created > 1000 and entry not in msg.answers: + removes.append(entry) + + expired = record.is_expired(now) + maybe_entry = self.cache.get(record) + if not expired: + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + if isinstance(record, DNSAddress): + address_adds.append(record) + else: + other_adds.append(record) + if updated: + updates.append(record) + elif maybe_entry is not None: + updates.append(record) + removes.append(record) + + if not updates and not address_adds and not other_adds and not removes: + return + + self.updates(now, updates) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the context manager to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + self.cache.add_records(itertools.chain(address_adds, other_adds)) + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. + self.cache.remove_records(removes) + self.updates_complete() + + class Zeroconf(QuietLogger): """Implementation of Zeroconf Multicast DNS Service Discovery @@ -2854,8 +2946,8 @@ def __init__( self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] self.registry = ServiceRegistry() self.query_handler = QueryHandler(self.registry) - self.cache = DNSCache() + self.record_manager = RecordManager(self) self.condition = threading.Condition() @@ -3088,85 +3180,10 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions log.exception('Unknown error, possibly benign: %r', e) - @contextmanager - def update_records(self, now: float, rec: List[DNSRecord]) -> Generator: - """Used to notify listeners of new information that has updated - a record. - - This method must be called before the cache is updated. - """ - try: - for listener in self.listeners: - listener.update_records(self, now, rec) - yield - finally: - for listener in self.listeners: - listener.update_records_complete() - self.notify_all() - def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" - updates = [] # type: List[DNSRecord] - address_adds = [] # type: List[DNSAddress] - other_adds = [] # type: List[DNSRecord] - removes = [] # type: List[DNSRecord] - now = current_time_millis() - for record in msg.answers: - - updated = True - - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False - if record.created - entry.created > 1000 and entry not in msg.answers: - removes.append(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) - else: - if isinstance(record, DNSAddress): - address_adds.append(record) - else: - other_adds.append(record) - if updated: - updates.append(record) - elif maybe_entry is not None: - updates.append(record) - removes.append(record) - - if not updates and not address_adds and not other_adds and not removes: - return - - with self.update_records(now, updates): - # The cache adds must be processed AFTER we trigger - # the updates since we compare existing data - # with the new data and updating the cache - # ahead of update_record will cause listeners - # to miss changes - # - # We must process address adds before non-addresses - # otherwise a fetch of ServiceInfo may miss an address - # because it thinks the cache is complete - # - # The cache is processed under the context manager to ensure - # that any ServiceBrowser that is going to call - # zc.get_service_info will see the cached value - # but ONLY after all the record updates have been - # processsed. - self.cache.add_records(itertools.chain(address_adds, other_adds)) - # Removes are processed last since - # ServiceInfo could generate an un-needed query - # because the data was not yet populated. - self.cache.remove_records(removes) + self.record_manager.updates_from_response(msg) def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if From 3d6c68278713a2ca66e27938feedcc451a078369 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 17:22:04 -1000 Subject: [PATCH 214/608] Update changelog (#513) --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index b83fc65f..f37b6de5 100644 --- a/README.rst +++ b/README.rst @@ -189,6 +189,12 @@ Changelog The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. +* Break out record updating into RecordManager (#512) @bdraco + +* Remove uneeded wait in the Engine thread (#511) @bdraco + +* Extract code for handling queries into QueryHandler (#507) @bdraco + * Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco * Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco From 6cc3adb020115ef9626caf61bb5f7550a2da8b4c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 19:02:29 -1000 Subject: [PATCH 215/608] Move RecordUpdateListener management into RecordManager (#514) --- zeroconf/__init__.py | 60 ++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 71e07182..dd047cca 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2799,6 +2799,7 @@ def __init__(self, zeroconf: 'Zeroconf'): """Init the record manager.""" self.zc = zeroconf self.cache = zeroconf.cache + self.listeners: List[RecordUpdateListener] = [] def updates(self, now: float, rec: List[DNSRecord]) -> None: """Used to notify listeners of new information that has updated @@ -2806,7 +2807,7 @@ def updates(self, now: float, rec: List[DNSRecord]) -> None: This method must be called before the cache is updated. """ - for listener in self.zc.listeners: + for listener in self.listeners: listener.update_records(self.zc, now, rec) def updates_complete(self) -> None: @@ -2815,7 +2816,7 @@ def updates_complete(self) -> None: This method must be called after the cache is updated. """ - for listener in self.zc.listeners: + for listener in self.listeners: listener.update_records_complete() self.zc.notify_all() @@ -2884,6 +2885,35 @@ def updates_from_response(self, msg: DNSIncoming) -> None: self.cache.remove_records(removes) self.updates_complete() + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s).""" + now = current_time_millis() + self.listeners.append(listener) + records = [] + if question is not None: + questions = [question] if isinstance(question, DNSQuestion) else question + for single_question in questions: + for record in self.cache.entries_with_name(single_question.name): + if single_question.answered_by(record) and not record.is_expired(now): + records.append(record) + + if records: + listener.update_records(self.zc, now, records) + listener.update_records_complete() + self.zc.notify_all() + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener.""" + try: + self.listeners.remove(listener) + self.zc.notify_all() + except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions + log.exception('Unknown error, possibly benign: %r', e) + class Zeroconf(QuietLogger): @@ -2941,7 +2971,6 @@ def __init__( log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) self.multi_socket = unicast or interfaces is not InterfaceChoice.Default - self.listeners = [] # type: List[RecordUpdateListener] self._notify_listeners = [] # type: List[NotifyListener] self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] self.registry = ServiceRegistry() @@ -2967,6 +2996,10 @@ def __init__( def done(self) -> bool: return self._GLOBAL_DONE + @property + def listeners(self) -> List[RecordUpdateListener]: + return self.record_manager.listeners + def wait(self, timeout: float) -> None: """Calling thread waits for a given number of milliseconds or until notified.""" @@ -3157,28 +3190,11 @@ def add_listener( """Adds a listener for a given question. The listener will have its update_record method called when information is available to answer the question(s).""" - now = current_time_millis() - self.listeners.append(listener) - records = [] - if question is not None: - questions = [question] if isinstance(question, DNSQuestion) else question - for single_question in questions: - for record in self.cache.entries_with_name(single_question.name): - if single_question.answered_by(record) and not record.is_expired(now): - records.append(record) - - if records: - listener.update_records(self, now, records) - listener.update_records_complete() - self.notify_all() + self.record_manager.add_listener(listener, question) def remove_listener(self, listener: RecordUpdateListener) -> None: """Removes a listener.""" - try: - self.listeners.remove(listener) - self.notify_all() - except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions - log.exception('Unknown error, possibly benign: %r', e) + self.record_manager.remove_listener(listener) def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers From f80a0515cf73b1e304d0615f8cee91ae38ac1ae8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 21:59:42 -1000 Subject: [PATCH 216/608] Small cleanups to RecordManager.add_listener (#516) --- zeroconf/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dd047cca..8ddfaa7d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2795,7 +2795,7 @@ def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: class RecordManager: """Process records into the cache and notify listeners.""" - def __init__(self, zeroconf: 'Zeroconf'): + def __init__(self, zeroconf: 'Zeroconf') -> None: """Init the record manager.""" self.zc = zeroconf self.cache = zeroconf.cache @@ -2891,19 +2891,20 @@ def add_listener( """Adds a listener for a given question. The listener will have its update_record method called when information is available to answer the question(s).""" - now = current_time_millis() self.listeners.append(listener) - records = [] + if question is not None: + now = current_time_millis() + records = [] questions = [question] if isinstance(question, DNSQuestion) else question for single_question in questions: for record in self.cache.entries_with_name(single_question.name): if single_question.answered_by(record) and not record.is_expired(now): records.append(record) + if records: + listener.update_records(self.zc, now, records) + listener.update_records_complete() - if records: - listener.update_records(self.zc, now, records) - listener.update_records_complete() self.zc.notify_all() def remove_listener(self, listener: RecordUpdateListener) -> None: From e12523933819087d2a087b8388e79b24af058a58 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 22:22:12 -1000 Subject: [PATCH 217/608] Remove broad exception catch from RecordManager.remove_listener (#517) --- zeroconf/__init__.py | 4 ++-- zeroconf/test.py | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 8ddfaa7d..10a13d3c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2912,8 +2912,8 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: try: self.listeners.remove(listener) self.zc.notify_all() - except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions - log.exception('Unknown error, possibly benign: %r', e) + except ValueError as e: + log.exception('Failed to remove listener: %r', e) class Zeroconf(QuietLogger): diff --git a/zeroconf/test.py b/zeroconf/test.py index 817557dd..21118cf8 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2747,7 +2747,9 @@ def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None nonlocal updates updates.append(record) - zc.add_listener(LegacyRecordUpdateListener(), None) + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): @@ -2778,4 +2780,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): assert len(updates) assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + zc.close() From ef7aa250e140d70b8c62abf4d13dcaa36f128c63 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 11 Jun 2021 23:02:03 -1000 Subject: [PATCH 218/608] Add test helper to inject DNSIncoming (#518) --- zeroconf/test.py | 86 +++++++++++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 21118cf8..ad4017cb 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -26,6 +26,7 @@ import zeroconf as r from zeroconf import ( DNSHinfo, + DNSIncoming, DNSText, ServiceBrowser, ServiceInfo, @@ -59,6 +60,11 @@ def teardown_module(): log.setLevel(original_logging_level) +def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: + """Inject a DNSIncoming response.""" + zc.handle_response(msg) + + @lru_cache(maxsize=None) def has_working_ipv6(): """Return True if if the system can bind an IPv6 address.""" @@ -788,7 +794,7 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS try: # service added - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) assert dns_text is not None assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' @@ -805,7 +811,7 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS # service updated. currently only text record can be updated service_text = b'path=/~humingchun/' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) assert dns_text is not None assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' @@ -814,14 +820,14 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS # The split message only has a SRV and A record. # This should not evict TXT records from the cache - zeroconf.handle_response(mock_split_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated)) time.sleep(1.1) dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) assert dns_text is not None assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' # service removed - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) assert dns_text is None @@ -1471,6 +1477,9 @@ def update_service(self, zeroconf, type, name): cached_info.load_from_cache(zeroconf_browser) assert cached_info.properties is not None + # Populate the cache + zeroconf_browser.get_service_info(subtype, registration_name) + # get service info with only the cache cached_info = ServiceInfo(subtype, registration_name) cached_info.load_from_cache(zeroconf_browser) @@ -1649,7 +1658,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi wait_time = 3 # service added - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Added)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) service_add_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 0 @@ -1658,7 +1667,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi # service SRV updated service_updated_event.clear() service_server = 'ash-2.local.' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 1 @@ -1667,7 +1676,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi # service TXT updated service_updated_event.clear() service_text = b'path=/~matt2/' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 2 @@ -1676,7 +1685,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi # service TXT updated - duplicate update should not trigger another service_updated service_updated_event.clear() service_text = b'path=/~matt2/' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 2 @@ -1685,7 +1694,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi # service A updated service_updated_event.clear() service_address = '10.0.1.3' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 3 @@ -1696,14 +1705,14 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi service_server = 'ash-3.local.' service_text = b'path=/~matt3/' service_address = '10.0.1.3' - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Updated)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 4 assert service_removed_count == 0 # service removed - zeroconf.handle_response(mock_incoming_msg(r.ServiceStateChange.Removed)) + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) service_removed_event.wait(wait_time) assert service_added_count == 1 assert service_updated_count == 4 @@ -1933,10 +1942,11 @@ def get_service_info_helper(zc, type, name): # Expext query for SRV, A, AAAA last_sent = None send_event.clear() - zc.handle_response( + _inject_response( + zc, mock_incoming_msg( [r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text)] - ) + ), ) send_event.wait(wait_time) assert last_sent is not None @@ -1949,7 +1959,8 @@ def get_service_info_helper(zc, type, name): # Expext query for A, AAAA last_sent = None send_event.clear() - zc.handle_response( + _inject_response( + zc, mock_incoming_msg( [ r.DNSService( @@ -1963,7 +1974,7 @@ def get_service_info_helper(zc, type, name): service_server, ) ] - ) + ), ) send_event.wait(wait_time) assert last_sent is not None @@ -1976,7 +1987,8 @@ def get_service_info_helper(zc, type, name): # Expext no further queries last_sent = None send_event.clear() - zc.handle_response( + _inject_response( + zc, mock_incoming_msg( [ r.DNSAddress( @@ -1987,7 +1999,7 @@ def get_service_info_helper(zc, type, name): socket.inet_pton(socket.AF_INET, service_address), ) ] - ) + ), ) send_event.wait(wait_time) assert last_sent is None @@ -2059,7 +2071,8 @@ def get_service_info_helper(zc, type, name): # Expext no further queries last_sent = None send_event.clear() - zc.handle_response( + _inject_response( + zc, mock_incoming_msg( [ r.DNSText( @@ -2083,7 +2096,7 @@ def get_service_info_helper(zc, type, name): socket.inet_pton(socket.AF_INET, service_address), ), ] - ) + ), ) send_event.wait(wait_time) assert last_sent is None @@ -2135,11 +2148,13 @@ def mock_incoming_msg( wait_time = 3 # all three services added - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), ) - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), ) called_with_refresh_time_check = False @@ -2153,14 +2168,16 @@ def _mock_get_expiration_time(self, percent): # Set an expire time that will force a refresh with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), ) # Add the last record after updating the first one # to ensure the service_add_event only gets set # after the update - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), ) service_add_event.wait(wait_time) assert called_with_refresh_time_check is True @@ -2168,14 +2185,17 @@ def _mock_get_expiration_time(self, percent): assert service_removed_count == 0 # all three services removed - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), ) - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), ) - zeroconf.handle_response( - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), ) service_removed_event.wait(wait_time) assert service_added_count == 3 From 7ce29a2f736af13886aa66dc1c49e15768e6fdcc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 09:07:34 -1000 Subject: [PATCH 219/608] Make the cache cleanup interval a constant (#522) --- zeroconf/__init__.py | 4 ++-- zeroconf/test.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 10a13d3c..b3a5f451 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -82,6 +82,7 @@ _LISTENER_TIME = 200 # ms _BROWSER_TIME = 1000 # ms _BROWSER_BACKOFF_LIMIT = 3600 # s +_CACHE_CLEANUP_INTERVAL = 10000 # ms # Some DNS constants @@ -1431,7 +1432,6 @@ def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.readers = {} # type: Dict[socket.socket, Listener] self.timeout = 5 - self.cache_cleanup_interval_ms = 10000.0 self.condition = threading.Condition() self.socketpair = socket.socketpair() self._last_cache_cleanup = 0.0 @@ -1461,7 +1461,7 @@ def run(self) -> None: raise now = current_time_millis() - if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms: + if now - self._last_cache_cleanup >= _CACHE_CLEANUP_INTERVAL: self._last_cache_cleanup = now self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) self.zc.record_manager.updates_complete() diff --git a/zeroconf/test.py b/zeroconf/test.py index ad4017cb..e862f5aa 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -1236,6 +1236,7 @@ def test_cache_empty_multiple_calls_does_not_throw(self): class TestReaper(unittest.TestCase): + @unittest.mock.patch.object(r, "_CACHE_CLEANUP_INTERVAL", 10) def test_reaper(self): zeroconf = Zeroconf(interfaces=['127.0.0.1']) cache = zeroconf.cache @@ -1245,7 +1246,6 @@ def test_reaper(self): zeroconf.cache.add(record_with_10s_ttl) zeroconf.cache.add(record_with_1s_ttl) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - zeroconf.engine.cache_cleanup_interval_ms = 10 time.sleep(1) with zeroconf.engine.condition: zeroconf.engine._notify() From b37d115a233b61e2989d1439f65cdd911b86f407 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 09:09:48 -1000 Subject: [PATCH 220/608] Update python compatibility as PyPy3 7.2 is required (#523) - When the version requirement changed to cpython 3.6, PyPy was not bumped as well --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f37b6de5..228f0fac 100644 --- a/README.rst +++ b/README.rst @@ -45,7 +45,7 @@ Python compatibility -------------------- * CPython 3.6+ -* PyPy3 5.8+ +* PyPy3 7.2+ Versioning ---------- From f49342cdaff2d012ad23635b49ae746ad71333df Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 09:34:46 -1000 Subject: [PATCH 221/608] Fix flakey test_update_record (#525) - Ensure enough time has past that the first record update was processed before sending the second one --- zeroconf/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index e862f5aa..44e7fcd7 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2156,7 +2156,6 @@ def mock_incoming_msg( zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), ) - called_with_refresh_time_check = False def _mock_get_expiration_time(self, percent): @@ -2172,6 +2171,7 @@ def _mock_get_expiration_time(self, percent): zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), ) + zeroconf.wait(100) # Add the last record after updating the first one # to ensure the service_add_event only gets set # after the update From 16d40b50ccab6a8d53fe4aeb7b0006f7fd67ef53 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 09:41:01 -1000 Subject: [PATCH 222/608] Move ipversion auto detection code into its own function (#524) --- zeroconf/__init__.py | 29 +++++++++++++++++------------ zeroconf/test.py | 8 ++++++++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b3a5f451..7bbae60a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -2617,6 +2617,22 @@ def can_send_to(sock: socket.socket, address: str) -> bool: return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) +def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: + """Auto detect the IP version when it is not provided.""" + if isinstance(interfaces, list): + has_v6 = any( + isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) + for i in interfaces + ) + has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) + if has_v4 and has_v6: + return IPVersion.All + if has_v6: + return IPVersion.V6Only + + return IPVersion.V4Only + + class ServiceRegistry: """A registry to keep track of services. @@ -2945,19 +2961,8 @@ def __init__( from it. Otherwise defaults to V4 only for backward compatibility. :param apple_p2p: use AWDL interface (only macOS) """ - if ip_version is None and isinstance(interfaces, list): - has_v6 = any( - isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) - for i in interfaces - ) - has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) - if has_v4 and has_v6: - ip_version = IPVersion.All - elif has_v6: - ip_version = IPVersion.V6Only - if ip_version is None: - ip_version = IPVersion.V4Only + ip_version = autodetect_ip_version(interfaces) # hook for threads self._GLOBAL_DONE = False diff --git a/zeroconf/test.py b/zeroconf/test.py index 44e7fcd7..448f78da 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2805,3 +2805,11 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.remove_listener(listener) zc.close() + + +def test_autodetect_ip_version(): + """Tests for auto detecting IPVersion based on interface ips.""" + assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only + assert r.autodetect_ip_version([]) is r.IPVersion.V4Only + assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All + assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only From 14542bd2bd327fd9b3d93cfb48a3bf09d6c89e15 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 10:18:10 -1000 Subject: [PATCH 223/608] Fix flakey test_update_record test (round 2) (#528) --- zeroconf/test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zeroconf/test.py b/zeroconf/test.py index 448f78da..f0092ea4 100644 --- a/zeroconf/test.py +++ b/zeroconf/test.py @@ -2156,6 +2156,8 @@ def mock_incoming_msg( zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), ) + zeroconf.wait(100) + called_with_refresh_time_check = False def _mock_get_expiration_time(self, percent): @@ -2171,7 +2173,6 @@ def _mock_get_expiration_time(self, percent): zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), ) - zeroconf.wait(100) # Add the last record after updating the first one # to ensure the service_add_event only gets set # after the update From 3f1a5a7b7a929d5f699812a809347b0c2f799fbf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 10:26:59 -1000 Subject: [PATCH 224/608] Relocate tests to tests directory (#527) --- Makefile | 4 ++-- setup.cfg | 3 +++ tests/__init__.py | 21 +++++++++++++++++++++ {zeroconf => tests}/test_aio.py | 4 ++-- {zeroconf => tests}/test_asyncio.py | 2 +- zeroconf/test.py => tests/test_init.py | 0 6 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 tests/__init__.py rename {zeroconf => tests}/test_aio.py (99%) rename {zeroconf => tests}/test_asyncio.py (87%) rename zeroconf/test.py => tests/test_init.py (100%) diff --git a/Makefile b/Makefile index 66bc85f1..a627c37e 100644 --- a/Makefile +++ b/Makefile @@ -39,10 +39,10 @@ mypy: mypy examples/*.py zeroconf/*.py test: - pytest -v zeroconf/test.py zeroconf/test_aio.py zeroconf/test_asyncio.py + pytest -v tests test_coverage: - pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing zeroconf/test.py zeroconf/test_aio.py zeroconf/test_asyncio.py + pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf diff --git a/setup.cfg b/setup.cfg index d4354ef4..a9dddb26 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ +[tool:pytest] +testpaths = tests + [flake8] show-source = 1 application-import-names=zeroconf diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/test_aio.py b/tests/test_aio.py similarity index 99% rename from zeroconf/test_aio.py rename to tests/test_aio.py index b05d88b5..b50e5bc7 100644 --- a/zeroconf/test_aio.py +++ b/tests/test_aio.py @@ -10,7 +10,7 @@ import pytest -from . import ( +from zeroconf import ( BadTypeInNameException, NonUniqueNameException, ServiceInfo, @@ -20,7 +20,7 @@ _LISTENER_TIME, current_time_millis, ) -from .aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf +from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf @pytest.mark.asyncio diff --git a/zeroconf/test_asyncio.py b/tests/test_asyncio.py similarity index 87% rename from zeroconf/test_asyncio.py rename to tests/test_asyncio.py index 28e3a327..ee8f8053 100644 --- a/zeroconf/test_asyncio.py +++ b/tests/test_asyncio.py @@ -7,7 +7,7 @@ import pytest -from .asyncio import AsyncZeroconf +from zeroconf.asyncio import AsyncZeroconf @pytest.mark.asyncio diff --git a/zeroconf/test.py b/tests/test_init.py similarity index 100% rename from zeroconf/test.py rename to tests/test_init.py From 2d8a27a54aee298af74121986b4ea76f1f50b421 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 10:40:08 -1000 Subject: [PATCH 225/608] Move asyncio utils into zeroconf.utils.aio (#530) --- Makefile | 6 ++-- tests/utils/__init__.py | 21 ++++++++++++ tests/utils/test_aio.py | 22 +++++++++++++ zeroconf/aio.py | 30 +---------------- zeroconf/utils/__init__.py | 21 ++++++++++++ zeroconf/utils/aio.py | 67 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 136 insertions(+), 31 deletions(-) create mode 100644 tests/utils/__init__.py create mode 100644 tests/utils/test_aio.py create mode 100644 zeroconf/utils/__init__.py create mode 100644 zeroconf/utils/aio.py diff --git a/Makefile b/Makefile index a627c37e..de816c1c 100644 --- a/Makefile +++ b/Makefile @@ -29,14 +29,16 @@ flake8: flake8 --max-line-length=$(MAX_LINE_LENGTH) setup.py examples zeroconf pylint: - pylint zeroconf/__init__.py zeroconf/aio.py zeroconf/asyncio.py + pylint zeroconf .PHONY: black_check black_check: black --check setup.py examples zeroconf mypy: - mypy examples/*.py zeroconf/*.py +# --no-warn-redundant-casts --no-warn-unused-ignores is needed since we support multiple python versions +# We should be able to drop this once python 3.6 goes away + mypy --no-warn-redundant-casts --no-warn-unused-ignores examples/*.py zeroconf test: pytest -v tests diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py new file mode 100644 index 00000000..e38eb583 --- /dev/null +++ b/tests/utils/test_aio.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf.utils.aio.""" + +import asyncio + +import pytest + +from zeroconf.utils import aio as aioutils + + +@pytest.mark.asyncio +async def test_get_running_loop_from_async() -> None: + """Test we can get the event loop.""" + assert isinstance(aioutils.get_running_loop(), asyncio.AbstractEventLoop) + + +def test_get_running_loop_no_loop() -> None: + """Test we get None when there is no loop running.""" + assert aioutils.get_running_loop() is None diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 61119747..92e11228 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -46,6 +46,7 @@ instance_name_from_service_info, millis_to_seconds, ) +from .utils.aio import wait_condition_or_timeout def _get_best_available_queue() -> queue.Queue: @@ -55,35 +56,6 @@ def _get_best_available_queue() -> queue.Queue: return queue.Queue() -# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed -async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: - """Wait for a condition or timeout.""" - loop = asyncio.get_event_loop() - future = loop.create_future() - - def _handle_timeout() -> None: - if not future.done(): - future.set_result(None) - - timer_handle = loop.call_later(timeout, _handle_timeout) - condition_wait = loop.create_task(condition.wait()) - - def _handle_wait_complete(_: asyncio.Task) -> None: - if not future.done(): - future.set_result(None) - - condition_wait.add_done_callback(_handle_wait_complete) - - try: - await future - finally: - timer_handle.cancel() - if not condition_wait.done(): - condition_wait.cancel() - with contextlib.suppress(asyncio.CancelledError): - await condition_wait - - class _AsyncSender(threading.Thread): """A thread to handle sending DNSOutgoing for asyncio.""" diff --git a/zeroconf/utils/__init__.py b/zeroconf/utils/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/zeroconf/utils/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/zeroconf/utils/aio.py b/zeroconf/utils/aio.py new file mode 100644 index 00000000..87a79f26 --- /dev/null +++ b/zeroconf/utils/aio.py @@ -0,0 +1,67 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import asyncio +import contextlib +from typing import Optional, cast + + +# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed +async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: + """Wait for a condition or timeout.""" + loop = asyncio.get_event_loop() + future = loop.create_future() + + def _handle_timeout() -> None: + if not future.done(): + future.set_result(None) + + timer_handle = loop.call_later(timeout, _handle_timeout) + condition_wait = loop.create_task(condition.wait()) + + def _handle_wait_complete(_: asyncio.Task) -> None: + if not future.done(): + future.set_result(None) + + condition_wait.add_done_callback(_handle_wait_complete) + + try: + await future + finally: + timer_handle.cancel() + if not condition_wait.done(): + condition_wait.cancel() + with contextlib.suppress(asyncio.CancelledError): + await condition_wait + + +# Remove the call to _get_running_loop once we drop python 3.6 support +def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: + """Check if an event loop is already running.""" + with contextlib.suppress(RuntimeError): + if hasattr(asyncio, "get_running_loop"): + return cast( + asyncio.AbstractEventLoop, + asyncio.get_running_loop(), # type: ignore # pylint: disable=no-member # noqa + ) + return asyncio._get_running_loop() # pylint: disable=no-member,protected-access + return None From 89d4755106a6c3bced395b0a26eb3082c1268fa1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 11:11:22 -1000 Subject: [PATCH 226/608] Move constants into const.py (#531) --- zeroconf/__init__.py | 173 +++++++++++++------------------------------ zeroconf/const.py | 145 ++++++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+), 121 deletions(-) create mode 100644 zeroconf/const.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 7bbae60a..46251828 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -26,7 +26,6 @@ import itertools import logging import platform -import re import select import socket import struct @@ -41,6 +40,58 @@ import ifaddr +from .const import ( # noqa # import needed for backwards compat + _BROWSER_BACKOFF_LIMIT, + _BROWSER_TIME, + _CACHE_CLEANUP_INTERVAL, + _CHECK_TIME, + _CLASSES, + _CLASS_IN, + _CLASS_NONE, + _CLASS_MASK, + _CLASS_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _DNS_PORT, + _EXPIRE_FULL_TIME_PERCENT, + _EXPIRE_REFRESH_TIME_PERCENT, + _EXPIRE_STALE_TIME_PERCENT, + _FLAGS_AA, + _FLAGS_QR_MASK, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _FLAGS_TC, + _HAS_ASCII_CONTROL_CHARS, + _HAS_A_TO_Z, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE, + _IPPROTO_IPV6, + _LISTENER_TIME, + _LOCAL_TRAILER, + _MAX_MSG_ABSOLUTE, + _MAX_MSG_TYPICAL, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_ADDR6_BYTES, + _MDNS_ADDR_BYTES, + _MDNS_PORT, + _NONTCP_PROTOCOL_LOCAL_TRAILER, + _REGISTER_TIME, + _SERVICE_TYPE_ENUMERATION_NAME, + _TCP_PROTOCOL_LOCAL_TRAILER, + _TYPES, + _TYPE_A, + _TYPE_AAAA, + _TYPE_ANY, + _TYPE_CNAME, + _TYPE_HINFO, + _TYPE_PTR, + _TYPE_SOA, + _TYPE_SRV, + _TYPE_TXT, + _UNREGISTER_TIME, +) + __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' __version__ = '0.31.0' @@ -74,126 +125,6 @@ if log.level == logging.NOTSET: log.setLevel(logging.WARN) -# Some timing constants - -_UNREGISTER_TIME = 125 # ms -_CHECK_TIME = 175 # ms -_REGISTER_TIME = 225 # ms -_LISTENER_TIME = 200 # ms -_BROWSER_TIME = 1000 # ms -_BROWSER_BACKOFF_LIMIT = 3600 # s -_CACHE_CLEANUP_INTERVAL = 10000 # ms - -# Some DNS constants - -_MDNS_ADDR = '224.0.0.251' -_MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) -_MDNS_ADDR6 = 'ff02::fb' -_MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) -_MDNS_PORT = 5353 -_DNS_PORT = 53 -_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 -_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 - -_MAX_MSG_TYPICAL = 1460 # unused -_MAX_MSG_ABSOLUTE = 8966 - -_FLAGS_QR_MASK = 0x8000 # query response mask -_FLAGS_QR_QUERY = 0x0000 # query -_FLAGS_QR_RESPONSE = 0x8000 # response - -_FLAGS_AA = 0x0400 # Authoritative answer -_FLAGS_TC = 0x0200 # Truncated -_FLAGS_RD = 0x0100 # Recursion desired -_FLAGS_RA = 0x8000 # Recursion available - -_FLAGS_Z = 0x0040 # Zero -_FLAGS_AD = 0x0020 # Authentic data -_FLAGS_CD = 0x0010 # Checking disabled - -_CLASS_IN = 1 -_CLASS_CS = 2 -_CLASS_CH = 3 -_CLASS_HS = 4 -_CLASS_NONE = 254 -_CLASS_ANY = 255 -_CLASS_MASK = 0x7FFF -_CLASS_UNIQUE = 0x8000 - -_TYPE_A = 1 -_TYPE_NS = 2 -_TYPE_MD = 3 -_TYPE_MF = 4 -_TYPE_CNAME = 5 -_TYPE_SOA = 6 -_TYPE_MB = 7 -_TYPE_MG = 8 -_TYPE_MR = 9 -_TYPE_NULL = 10 -_TYPE_WKS = 11 -_TYPE_PTR = 12 -_TYPE_HINFO = 13 -_TYPE_MINFO = 14 -_TYPE_MX = 15 -_TYPE_TXT = 16 -_TYPE_AAAA = 28 -_TYPE_SRV = 33 -_TYPE_ANY = 255 - -# Mapping constants to names - -_CLASSES = { - _CLASS_IN: "in", - _CLASS_CS: "cs", - _CLASS_CH: "ch", - _CLASS_HS: "hs", - _CLASS_NONE: "none", - _CLASS_ANY: "any", -} - -_TYPES = { - _TYPE_A: "a", - _TYPE_NS: "ns", - _TYPE_MD: "md", - _TYPE_MF: "mf", - _TYPE_CNAME: "cname", - _TYPE_SOA: "soa", - _TYPE_MB: "mb", - _TYPE_MG: "mg", - _TYPE_MR: "mr", - _TYPE_NULL: "null", - _TYPE_WKS: "wks", - _TYPE_PTR: "ptr", - _TYPE_HINFO: "hinfo", - _TYPE_MINFO: "minfo", - _TYPE_MX: "mx", - _TYPE_TXT: "txt", - _TYPE_AAAA: "quada", - _TYPE_SRV: "srv", - _TYPE_ANY: "any", -} - -_HAS_A_TO_Z = re.compile(r'[A-Za-z]') -_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$') -_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') -_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') - -_EXPIRE_FULL_TIME_PERCENT = 100 -_EXPIRE_STALE_TIME_PERCENT = 50 -_EXPIRE_REFRESH_TIME_PERCENT = 75 - -_LOCAL_TRAILER = '.local.' -_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' -_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' - -# https://datatracker.ietf.org/doc/html/rfc6763#section-9 -_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local." - -try: - _IPPROTO_IPV6 = socket.IPPROTO_IPV6 -except AttributeError: - # Sigh: https://bugs.python.org/issue29515 - _IPPROTO_IPV6 = 41 int2byte = struct.Struct(">B").pack diff --git a/zeroconf/const.py b/zeroconf/const.py new file mode 100644 index 00000000..365fee09 --- /dev/null +++ b/zeroconf/const.py @@ -0,0 +1,145 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import re +import socket + +# Some timing constants + +_UNREGISTER_TIME = 125 # ms +_CHECK_TIME = 175 # ms +_REGISTER_TIME = 225 # ms +_LISTENER_TIME = 200 # ms +_BROWSER_TIME = 1000 # ms +_BROWSER_BACKOFF_LIMIT = 3600 # s +_CACHE_CLEANUP_INTERVAL = 10000 # ms + +# Some DNS constants + +_MDNS_ADDR = '224.0.0.251' +_MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) +_MDNS_ADDR6 = 'ff02::fb' +_MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) +_MDNS_PORT = 5353 +_DNS_PORT = 53 +_DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 +_DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 + +_MAX_MSG_TYPICAL = 1460 # unused +_MAX_MSG_ABSOLUTE = 8966 + +_FLAGS_QR_MASK = 0x8000 # query response mask +_FLAGS_QR_QUERY = 0x0000 # query +_FLAGS_QR_RESPONSE = 0x8000 # response + +_FLAGS_AA = 0x0400 # Authoritative answer +_FLAGS_TC = 0x0200 # Truncated +_FLAGS_RD = 0x0100 # Recursion desired +_FLAGS_RA = 0x8000 # Recursion available + +_FLAGS_Z = 0x0040 # Zero +_FLAGS_AD = 0x0020 # Authentic data +_FLAGS_CD = 0x0010 # Checking disabled + +_CLASS_IN = 1 +_CLASS_CS = 2 +_CLASS_CH = 3 +_CLASS_HS = 4 +_CLASS_NONE = 254 +_CLASS_ANY = 255 +_CLASS_MASK = 0x7FFF +_CLASS_UNIQUE = 0x8000 + +_TYPE_A = 1 +_TYPE_NS = 2 +_TYPE_MD = 3 +_TYPE_MF = 4 +_TYPE_CNAME = 5 +_TYPE_SOA = 6 +_TYPE_MB = 7 +_TYPE_MG = 8 +_TYPE_MR = 9 +_TYPE_NULL = 10 +_TYPE_WKS = 11 +_TYPE_PTR = 12 +_TYPE_HINFO = 13 +_TYPE_MINFO = 14 +_TYPE_MX = 15 +_TYPE_TXT = 16 +_TYPE_AAAA = 28 +_TYPE_SRV = 33 +_TYPE_ANY = 255 + +# Mapping constants to names + +_CLASSES = { + _CLASS_IN: "in", + _CLASS_CS: "cs", + _CLASS_CH: "ch", + _CLASS_HS: "hs", + _CLASS_NONE: "none", + _CLASS_ANY: "any", +} + +_TYPES = { + _TYPE_A: "a", + _TYPE_NS: "ns", + _TYPE_MD: "md", + _TYPE_MF: "mf", + _TYPE_CNAME: "cname", + _TYPE_SOA: "soa", + _TYPE_MB: "mb", + _TYPE_MG: "mg", + _TYPE_MR: "mr", + _TYPE_NULL: "null", + _TYPE_WKS: "wks", + _TYPE_PTR: "ptr", + _TYPE_HINFO: "hinfo", + _TYPE_MINFO: "minfo", + _TYPE_MX: "mx", + _TYPE_TXT: "txt", + _TYPE_AAAA: "quada", + _TYPE_SRV: "srv", + _TYPE_ANY: "any", +} + +_HAS_A_TO_Z = re.compile(r'[A-Za-z]') +_HAS_ONLY_A_TO_Z_NUM_HYPHEN = re.compile(r'^[A-Za-z0-9\-]+$') +_HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') +_HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') + +_EXPIRE_FULL_TIME_PERCENT = 100 +_EXPIRE_STALE_TIME_PERCENT = 50 +_EXPIRE_REFRESH_TIME_PERCENT = 75 + +_LOCAL_TRAILER = '.local.' +_TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' +_NONTCP_PROTOCOL_LOCAL_TRAILER = '._udp.local.' + +# https://datatracker.ietf.org/doc/html/rfc6763#section-9 +_SERVICE_TYPE_ENUMERATION_NAME = "_services._dns-sd._udp.local." + +try: + _IPPROTO_IPV6 = socket.IPPROTO_IPV6 +except AttributeError: + # Sigh: https://bugs.python.org/issue29515 + _IPPROTO_IPV6 = 41 From 5100506f896b649e6a6a8e2efb592362cd2644d3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 11:27:38 -1000 Subject: [PATCH 227/608] Move exceptions into zeroconf.exceptions (#532) --- zeroconf/__init__.py | 41 +++++++++------------------------ zeroconf/exceptions.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 31 deletions(-) create mode 100644 zeroconf/exceptions.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 46251828..90439eab 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -91,6 +91,16 @@ _TYPE_TXT, _UNREGISTER_TIME, ) +from .exceptions import ( + AbstractMethodException, + BadTypeInNameException, + Error, + IncomingDecodeError, + NamePartTooLongException, + NonUniqueNameException, + ServiceNameAlreadyRegistered, +) + __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' @@ -306,37 +316,6 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str: return info.name[: -len(service_name) - 1] -# Exceptions - - -class Error(Exception): - pass - - -class IncomingDecodeError(Error): - pass - - -class NonUniqueNameException(Error): - pass - - -class NamePartTooLongException(Error): - pass - - -class AbstractMethodException(Error): - pass - - -class BadTypeInNameException(Error): - pass - - -class ServiceNameAlreadyRegistered(Error): - pass - - # implementation classes diff --git a/zeroconf/exceptions.py b/zeroconf/exceptions.py new file mode 100644 index 00000000..ea468659 --- /dev/null +++ b/zeroconf/exceptions.py @@ -0,0 +1,51 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +# Exceptions + + +class Error(Exception): + pass + + +class IncomingDecodeError(Error): + pass + + +class NonUniqueNameException(Error): + pass + + +class NamePartTooLongException(Error): + pass + + +class AbstractMethodException(Error): + pass + + +class BadTypeInNameException(Error): + pass + + +class ServiceNameAlreadyRegistered(Error): + pass From e2e4eede9117827f47c66a4852dd2d236b46ecda Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 11:37:50 -1000 Subject: [PATCH 228/608] Move logger into zeroconf.logger (#533) --- zeroconf/__init__.py | 36 +-------------------------- zeroconf/logger.py | 58 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 35 deletions(-) create mode 100644 zeroconf/logger.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 90439eab..212b4b7b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -24,7 +24,6 @@ import errno import ipaddress import itertools -import logging import platform import select import socket @@ -100,7 +99,7 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) - +from .logger import QuietLogger, log __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' @@ -129,12 +128,6 @@ ''' ) -log = logging.getLogger(__name__) -log.addHandler(logging.NullHandler()) - -if log.level == logging.NOTSET: - log.setLevel(logging.WARN) - int2byte = struct.Struct(">B").pack @@ -319,33 +312,6 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str: # implementation classes -class QuietLogger: - _seen_logs = {} # type: Dict[str, Union[int, tuple]] - - @classmethod - def log_exception_warning(cls, *logger_data: Any) -> None: - exc_info = sys.exc_info() - exc_str = str(exc_info[1]) - if exc_str not in cls._seen_logs: - # log at warning level the first time this is seen - cls._seen_logs[exc_str] = exc_info - logger = log.warning - else: - logger = log.debug - logger(*(logger_data or ['Exception occurred']), exc_info=True) - - @classmethod - def log_warning_once(cls, *args: Any) -> None: - msg_str = args[0] - if msg_str not in cls._seen_logs: - cls._seen_logs[msg_str] = 0 - logger = log.warning - else: - logger = log.debug - cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 - logger(*args) - - class DNSEntry: """A DNS entry""" diff --git a/zeroconf/logger.py b/zeroconf/logger.py new file mode 100644 index 00000000..b7cb745a --- /dev/null +++ b/zeroconf/logger.py @@ -0,0 +1,58 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import logging +import sys +from typing import Any, Dict, Union, cast + +log = logging.getLogger(__name__.split('.')[0]) +log.addHandler(logging.NullHandler()) + +if log.level == logging.NOTSET: + log.setLevel(logging.WARN) + + +class QuietLogger: + _seen_logs = {} # type: Dict[str, Union[int, tuple]] + + @classmethod + def log_exception_warning(cls, *logger_data: Any) -> None: + exc_info = sys.exc_info() + exc_str = str(exc_info[1]) + if exc_str not in cls._seen_logs: + # log at warning level the first time this is seen + cls._seen_logs[exc_str] = exc_info + logger = log.warning + else: + logger = log.debug + logger(*(logger_data or ['Exception occurred']), exc_info=True) + + @classmethod + def log_warning_once(cls, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args) From 328c1b9acdcd5cafa2df3e5b4b833b908d299500 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 11:58:09 -1000 Subject: [PATCH 229/608] Add missing coverage for QuietLogger (#534) --- tests/test_logger.py | 48 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/test_logger.py diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 00000000..52bf830f --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for logger.py.""" + +from unittest.mock import patch +from zeroconf.logger import QuietLogger + + +def test_log_warning_once(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( + "zeroconf.logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_warning_once("the warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( + "zeroconf.logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_warning_once("the warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls + + +def test_log_exception_warning(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( + "zeroconf.logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_warning("the exception warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( + "zeroconf.logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_warning("the exception warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls From 2976cc2001cbba2c0afc57b9a3d301f382ddac8a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 12:20:23 -1000 Subject: [PATCH 230/608] Avoid making DNSOutgoing aware of the Zeroconf object (#535) - This is not a breaking change since this code has not yet shipped --- zeroconf/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 212b4b7b..8083dbb0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -880,22 +880,22 @@ def add_additional_answer(self, record: DNSRecord) -> None: self.additionals.append(record) def add_question_or_one_cache( - self, zc: "Zeroconf", now: float, name: str, type_: int, class_: int + self, cache: "DNSCache", now: float, name: str, type_: int, class_: int ) -> None: """Add a question if it is not already cached.""" - cached_entry = zc.cache.get_by_details(name, type_, class_) + cached_entry = cache.get_by_details(name, type_, class_) if not cached_entry: self.add_question(DNSQuestion(name, type_, class_)) else: self.add_answer_at_time(cached_entry, now) def add_question_or_all_cache( - self, zc: "Zeroconf", now: float, name: str, type_: int, class_: int + self, cache: "DNSCache", now: float, name: str, type_: int, class_: int ) -> None: """Add a question if it is not already cached. This is currently only used for IPv6 addresses. """ - cached_entries = zc.cache.get_all_by_details(name, type_, class_) + cached_entries = cache.get_all_by_details(name, type_, class_) if not cached_entries: self.add_question(DNSQuestion(name, type_, class_)) return @@ -2131,10 +2131,10 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_one_cache(zc, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc, now, self.server, _TYPE_AAAA, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) return out def __eq__(self, other: object) -> bool: From 7ff810a02e608fae39634be09d6c3ce0a93485b8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 12:30:21 -1000 Subject: [PATCH 231/608] Move time utility functions into zeroconf.utils.time (#536) --- zeroconf/__init__.py | 11 +---------- zeroconf/aio.py | 3 +-- zeroconf/utils/time.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 12 deletions(-) create mode 100644 zeroconf/utils/time.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 8083dbb0..ef5168a4 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -100,6 +100,7 @@ ServiceNameAlreadyRegistered, ) from .logger import QuietLogger, log +from .utils.time import current_time_millis, millis_to_seconds __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' @@ -158,16 +159,6 @@ class IPVersion(enum.Enum): # utility functions -def current_time_millis() -> float: - """Current system time in milliseconds""" - return time.time() * 1000 - - -def millis_to_seconds(millis: float) -> float: - """Convert milliseconds to seconds.""" - return millis / 1000.0 - - def _is_v6_address(addr: bytes) -> bool: return len(addr) == 16 diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 92e11228..55c4c2cb 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -42,11 +42,10 @@ _REGISTER_TIME, _ServiceBrowserBase, _UNREGISTER_TIME, - current_time_millis, instance_name_from_service_info, - millis_to_seconds, ) from .utils.aio import wait_condition_or_timeout +from .utils.time import current_time_millis, millis_to_seconds def _get_best_available_queue() -> queue.Queue: diff --git a/zeroconf/utils/time.py b/zeroconf/utils/time.py new file mode 100644 index 00000000..0ba91ead --- /dev/null +++ b/zeroconf/utils/time.py @@ -0,0 +1,34 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + + +import time + + +def current_time_millis() -> float: + """Current system time in milliseconds""" + return time.time() * 1000 + + +def millis_to_seconds(millis: float) -> float: + """Convert milliseconds to seconds.""" + return millis / 1000.0 From 5af3eb58bfdc1736e6db175c4c6f7c6f2c05b694 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 13:00:13 -1000 Subject: [PATCH 232/608] Breakout network utils into zeroconf.utils.net (#537) --- zeroconf/__init__.py | 346 ++------------------------------------- zeroconf/aio.py | 3 +- zeroconf/utils/net.py | 365 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 380 insertions(+), 334 deletions(-) create mode 100644 zeroconf/utils/net.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ef5168a4..f04b98b0 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -22,7 +22,6 @@ import enum import errno -import ipaddress import itertools import platform import select @@ -37,8 +36,6 @@ from typing import Dict, Iterable, List, Optional, Type, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints -import ifaddr - from .const import ( # noqa # import needed for backwards compat _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -100,6 +97,20 @@ ServiceNameAlreadyRegistered, ) from .logger import QuietLogger, log +from .utils.net import ( # noqa # import needed for backwards compat + add_multicast_member, + can_send_to, + autodetect_ip_version, + create_sockets, + get_all_addresses_v6, + InterfaceChoice, + InterfacesType, + ServiceStateChange, + IPVersion, + _is_v6_address, + _encode_address, + get_all_addresses, +) from .utils.time import current_time_millis, millis_to_seconds __author__ = 'Paul Scott-Murphy, William McBrine' @@ -132,43 +143,9 @@ int2byte = struct.Struct(">B").pack - -@enum.unique -class InterfaceChoice(enum.Enum): - Default = 1 - All = 2 - - -InterfacesType = Union[List[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice] - - -@enum.unique -class ServiceStateChange(enum.Enum): - Added = 1 - Removed = 2 - Updated = 3 - - -@enum.unique -class IPVersion(enum.Enum): - V4Only = 1 - V6Only = 2 - All = 3 - - # utility functions -def _is_v6_address(addr: bytes) -> bool: - return len(addr) == 16 - - -def _encode_address(address: str) -> bytes: - is_ipv6 = ':' in address - address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET - return socket.inet_pton(address_family, address) - - def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches """ Validate a fully qualified service name, instance or subtype. [rfc6763] @@ -2205,301 +2182,6 @@ def find( return tuple(sorted(listener.found_services)) -def get_all_addresses() -> List[str]: - return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4)) - - -def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: - # IPv6 multicast uses positive indexes for interfaces - # TODO: What about multi-address interfaces? - return list( - set((addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6) - ) - - -def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: - ipaddr = ipaddress.ip_address(ip) - for adapter in adapters: - for adapter_ip in adapter.ips: - # IPv6 addresses are represented as tuples - if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: - return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index)) - - raise RuntimeError('No adapter found for IP address %s' % ip) - - -def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: - for adapter in adapters: - if adapter.index == index: - for adapter_ip in adapter.ips: - # IPv6 addresses are represented as tuples - if isinstance(adapter_ip.ip, tuple): - return cast(Tuple[str, int, int], adapter_ip.ip) - - raise RuntimeError('No adapter found for index %s' % index) - - -def ip6_addresses_to_indexes( - interfaces: List[Union[str, int, Tuple[Tuple[str, int, int], int]]] -) -> List[Tuple[Tuple[str, int, int], int]]: - """Convert IPv6 interface addresses to interface indexes. - - IPv4 addresses are ignored. - - :param interfaces: List of IP addresses and indexes. - :returns: List of indexes. - """ - result = [] - adapters = ifaddr.get_adapters() - - for iface in interfaces: - if isinstance(iface, int): - result.append((interface_index_to_ip6_address(adapters, iface), iface)) - elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: - result.append(ip6_to_address_and_index(adapters, iface)) - - return result - - -def normalize_interface_choice( - choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only -) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: - """Convert the interfaces choice into internal representation. - - :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). - :param ip_address: IP version to use (ignored if `choice` is a list). - :returns: List of IP addresses (for IPv4) and indexes (for IPv6). - """ - result = [] # type: List[Union[str, Tuple[Tuple[str, int, int], int]]] - if choice is InterfaceChoice.Default: - if ip_version != IPVersion.V4Only: - # IPv6 multicast uses interface 0 to mean the default - result.append((('', 0, 0), 0)) - if ip_version != IPVersion.V6Only: - result.append('0.0.0.0') - elif choice is InterfaceChoice.All: - if ip_version != IPVersion.V4Only: - result.extend(get_all_addresses_v6()) - if ip_version != IPVersion.V6Only: - result.extend(get_all_addresses()) - if not result: - raise RuntimeError( - 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version - ) - elif isinstance(choice, list): - # First, take IPv4 addresses. - result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] - # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. - result += ip6_addresses_to_indexes(choice) - else: - raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice) - return result - - -def new_socket( # pylint: disable=too-many-branches - bind_addr: Union[Tuple[str], Tuple[str, int, int]], - port: int = _MDNS_PORT, - ip_version: IPVersion = IPVersion.V4Only, - apple_p2p: bool = False, -) -> socket.socket: - log.debug( - 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r', - port, - ip_version, - apple_p2p, - bind_addr, - ) - if ip_version == IPVersion.V4Only: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - else: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - - if ip_version == IPVersion.All: - # make V6 sockets work for both V4 and V6 (required for Windows) - try: - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) - except OSError: - log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') - raise - - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - - # SO_REUSEADDR should be equivalent to SO_REUSEPORT for - # multicast UDP sockets (p 731, "TCP/IP Illustrated, - # Volume 2"), but some BSD-derived systems require - # SO_REUSEPORT to be specified explicitly. Also, not all - # versions of Python have SO_REUSEPORT available. - # Catch OSError and socket.error for kernel versions <3.9 because lacking - # SO_REUSEPORT support. - try: - reuseport = socket.SO_REUSEPORT - except AttributeError: - pass - else: - try: - s.setsockopt(socket.SOL_SOCKET, reuseport, 1) - except OSError as err: - if err.errno != errno.ENOPROTOOPT: - raise - - if port == _MDNS_PORT: - ttl = struct.pack(b'B', 255) - loop = struct.pack(b'B', 1) - if ip_version != IPVersion.V6Only: - # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and - # IP_MULTICAST_LOOP socket options as an unsigned char. - try: - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) - except socket.error as e: - if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS - raise - if ip_version != IPVersion.V4Only: - # However, char doesn't work here (at least on Linux) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) - - if apple_p2p: - # SO_RECV_ANYIF = 0x1104 - # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h - s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) - - s.bind((bind_addr[0], port, *bind_addr[1:])) - log.debug('Created socket %s', s) - return s - - -def add_multicast_member( - listen_socket: socket.socket, - interface: Union[str, Tuple[Tuple[str, int, int], int]], -) -> bool: - # This is based on assumptions in normalize_interface_choice - is_v6 = isinstance(interface, tuple) - err_einval = {errno.EINVAL} - if sys.platform == 'win32': - # No WSAEINVAL definition in typeshed - err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member - log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) - try: - if is_v6: - iface_bin = struct.pack('@I', cast(int, interface[1])) - _value = _MDNS_ADDR6_BYTES + iface_bin - listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) - else: - _value = _MDNS_ADDR_BYTES + socket.inet_aton(cast(str, interface)) - listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) - except socket.error as e: - _errno = get_errno(e) - if _errno == errno.EADDRINUSE: - log.info( - 'Address in use when adding %s to multicast group, ' - 'it is expected to happen on some systems', - interface, - ) - return False - if _errno == errno.EADDRNOTAVAIL: - log.info( - 'Address not available when adding %s to multicast ' - 'group, it is expected to happen on some systems', - interface, - ) - return False - if _errno in err_einval: - log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) - return False - raise - return True - - -def new_respond_socket( - interface: Union[str, Tuple[Tuple[str, int, int], int]], - apple_p2p: bool = False, -) -> Optional[socket.socket]: - is_v6 = isinstance(interface, tuple) - respond_socket = new_socket( - ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), - apple_p2p=apple_p2p, - bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),), - ) - log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) - if is_v6: - iface_bin = struct.pack('@I', cast(int, interface[1])) - respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) - else: - respond_socket.setsockopt( - socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface)) - ) - return respond_socket - - -def create_sockets( - interfaces: InterfacesType = InterfaceChoice.All, - unicast: bool = False, - ip_version: IPVersion = IPVersion.V4Only, - apple_p2p: bool = False, -) -> Tuple[Optional[socket.socket], List[socket.socket]]: - if unicast: - listen_socket = None - else: - listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',)) - - normalized_interfaces = normalize_interface_choice(interfaces, ip_version) - - # If we are using InterfaceChoice.Default we can use - # a single socket to listen and respond. - if not unicast and interfaces is InterfaceChoice.Default: - for i in normalized_interfaces: - add_multicast_member(cast(socket.socket, listen_socket), i) - return listen_socket, [cast(socket.socket, listen_socket)] - - respond_sockets = [] - - for i in normalized_interfaces: - if not unicast: - if add_multicast_member(cast(socket.socket, listen_socket), i): - respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) - else: - respond_socket = None - else: - respond_socket = new_socket( - port=0, - ip_version=ip_version, - apple_p2p=apple_p2p, - bind_addr=i[0] if isinstance(i, tuple) else (i,), - ) - - if respond_socket is not None: - respond_sockets.append(respond_socket) - - return listen_socket, respond_sockets - - -def get_errno(e: Exception) -> int: - assert isinstance(e, socket.error) - return cast(int, e.args[0]) - - -def can_send_to(sock: socket.socket, address: str) -> bool: - addr = ipaddress.ip_address(address) - return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) - - -def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: - """Auto detect the IP version when it is not provided.""" - if isinstance(interfaces, list): - has_v6 = any( - isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) - for i in interfaces - ) - has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) - if has_v4 and has_v6: - return IPVersion.All - if has_v6: - return IPVersion.V6Only - - return IPVersion.V4Only - - class ServiceRegistry: """A registry to keep track of services. diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 55c4c2cb..1df18243 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -29,8 +29,6 @@ from . import ( DNSOutgoing, IPVersion, - InterfaceChoice, - InterfacesType, NonUniqueNameException, NotifyListener, ServiceInfo, @@ -45,6 +43,7 @@ instance_name_from_service_info, ) from .utils.aio import wait_condition_or_timeout +from .utils.net import InterfaceChoice, InterfacesType from .utils.time import current_time_millis, millis_to_seconds diff --git a/zeroconf/utils/net.py b/zeroconf/utils/net.py new file mode 100644 index 00000000..5ea49924 --- /dev/null +++ b/zeroconf/utils/net.py @@ -0,0 +1,365 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import errno +import ipaddress +import socket +import struct +import sys +from typing import Any, List, Optional, Tuple, Union, cast + +import ifaddr + +from ..const import _IPPROTO_IPV6, _MDNS_ADDR6_BYTES, _MDNS_ADDR_BYTES, _MDNS_PORT +from ..logger import log + + +@enum.unique +class InterfaceChoice(enum.Enum): + Default = 1 + All = 2 + + +InterfacesType = Union[List[Union[str, int, Tuple[Tuple[str, int, int], int]]], InterfaceChoice] + + +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + +@enum.unique +class IPVersion(enum.Enum): + V4Only = 1 + V6Only = 2 + All = 3 + + +# utility functions + + +def _is_v6_address(addr: bytes) -> bool: + return len(addr) == 16 + + +def _encode_address(address: str) -> bytes: + is_ipv6 = ':' in address + address_family = socket.AF_INET6 if is_ipv6 else socket.AF_INET + return socket.inet_pton(address_family, address) + + +def get_all_addresses() -> List[str]: + return list(set(addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4)) + + +def get_all_addresses_v6() -> List[Tuple[Tuple[str, int, int], int]]: + # IPv6 multicast uses positive indexes for interfaces + # TODO: What about multi-address interfaces? + return list( + set((addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6) + ) + + +def ip6_to_address_and_index(adapters: List[Any], ip: str) -> Tuple[Tuple[str, int, int], int]: + ipaddr = ipaddress.ip_address(ip) + for adapter in adapters: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple) and ipaddress.ip_address(adapter_ip.ip[0]) == ipaddr: + return (cast(Tuple[str, int, int], adapter_ip.ip), cast(int, adapter.index)) + + raise RuntimeError('No adapter found for IP address %s' % ip) + + +def interface_index_to_ip6_address(adapters: List[Any], index: int) -> Tuple[str, int, int]: + for adapter in adapters: + if adapter.index == index: + for adapter_ip in adapter.ips: + # IPv6 addresses are represented as tuples + if isinstance(adapter_ip.ip, tuple): + return cast(Tuple[str, int, int], adapter_ip.ip) + + raise RuntimeError('No adapter found for index %s' % index) + + +def ip6_addresses_to_indexes( + interfaces: List[Union[str, int, Tuple[Tuple[str, int, int], int]]] +) -> List[Tuple[Tuple[str, int, int], int]]: + """Convert IPv6 interface addresses to interface indexes. + + IPv4 addresses are ignored. + + :param interfaces: List of IP addresses and indexes. + :returns: List of indexes. + """ + result = [] + adapters = ifaddr.get_adapters() + + for iface in interfaces: + if isinstance(iface, int): + result.append((interface_index_to_ip6_address(adapters, iface), iface)) + elif isinstance(iface, str) and ipaddress.ip_address(iface).version == 6: + result.append(ip6_to_address_and_index(adapters, iface)) + + return result + + +def normalize_interface_choice( + choice: InterfacesType, ip_version: IPVersion = IPVersion.V4Only +) -> List[Union[str, Tuple[Tuple[str, int, int], int]]]: + """Convert the interfaces choice into internal representation. + + :param choice: `InterfaceChoice` or list of interface addresses or indexes (IPv6 only). + :param ip_address: IP version to use (ignored if `choice` is a list). + :returns: List of IP addresses (for IPv4) and indexes (for IPv6). + """ + result = [] # type: List[Union[str, Tuple[Tuple[str, int, int], int]]] + if choice is InterfaceChoice.Default: + if ip_version != IPVersion.V4Only: + # IPv6 multicast uses interface 0 to mean the default + result.append((('', 0, 0), 0)) + if ip_version != IPVersion.V6Only: + result.append('0.0.0.0') + elif choice is InterfaceChoice.All: + if ip_version != IPVersion.V4Only: + result.extend(get_all_addresses_v6()) + if ip_version != IPVersion.V6Only: + result.extend(get_all_addresses()) + if not result: + raise RuntimeError( + 'No interfaces to listen on, check that any interfaces have IP version %s' % ip_version + ) + elif isinstance(choice, list): + # First, take IPv4 addresses. + result = [i for i in choice if isinstance(i, str) and ipaddress.ip_address(i).version == 4] + # Unlike IP_ADD_MEMBERSHIP, IPV6_JOIN_GROUP requires interface indexes. + result += ip6_addresses_to_indexes(choice) + else: + raise TypeError("choice must be a list or InterfaceChoice, got %r" % choice) + return result + + +def new_socket( # pylint: disable=too-many-branches + bind_addr: Union[Tuple[str], Tuple[str, int, int]], + port: int = _MDNS_PORT, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, +) -> socket.socket: + log.debug( + 'Creating new socket with port %s, ip_version %s, apple_p2p %s and bind_addr %r', + port, + ip_version, + apple_p2p, + bind_addr, + ) + if ip_version == IPVersion.V4Only: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + else: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + + if ip_version == IPVersion.All: + # make V6 sockets work for both V4 and V6 (required for Windows) + try: + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + except OSError: + log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') + raise + + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + # SO_REUSEADDR should be equivalent to SO_REUSEPORT for + # multicast UDP sockets (p 731, "TCP/IP Illustrated, + # Volume 2"), but some BSD-derived systems require + # SO_REUSEPORT to be specified explicitly. Also, not all + # versions of Python have SO_REUSEPORT available. + # Catch OSError and socket.error for kernel versions <3.9 because lacking + # SO_REUSEPORT support. + try: + reuseport = socket.SO_REUSEPORT + except AttributeError: + pass + else: + try: + s.setsockopt(socket.SOL_SOCKET, reuseport, 1) + except OSError as err: + if err.errno != errno.ENOPROTOOPT: + raise + + if port == _MDNS_PORT: + ttl = struct.pack(b'B', 255) + loop = struct.pack(b'B', 1) + if ip_version != IPVersion.V6Only: + # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and + # IP_MULTICAST_LOOP socket options as an unsigned char. + try: + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + except socket.error as e: + if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS + raise + if ip_version != IPVersion.V4Only: + # However, char doesn't work here (at least on Linux) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) + + if apple_p2p: + # SO_RECV_ANYIF = 0x1104 + # https://opensource.apple.com/source/xnu/xnu-4570.41.2/bsd/sys/socket.h + s.setsockopt(socket.SOL_SOCKET, 0x1104, 1) + + s.bind((bind_addr[0], port, *bind_addr[1:])) + log.debug('Created socket %s', s) + return s + + +def add_multicast_member( + listen_socket: socket.socket, + interface: Union[str, Tuple[Tuple[str, int, int], int]], +) -> bool: + # This is based on assumptions in normalize_interface_choice + is_v6 = isinstance(interface, tuple) + err_einval = {errno.EINVAL} + if sys.platform == 'win32': + # No WSAEINVAL definition in typeshed + err_einval |= {cast(Any, errno).WSAEINVAL} # pylint: disable=no-member + log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) + try: + if is_v6: + iface_bin = struct.pack('@I', cast(int, interface[1])) + _value = _MDNS_ADDR6_BYTES + iface_bin + listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) + else: + _value = _MDNS_ADDR_BYTES + socket.inet_aton(cast(str, interface)) + listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) + except socket.error as e: + _errno = get_errno(e) + if _errno == errno.EADDRINUSE: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen on some systems', + interface, + ) + return False + if _errno == errno.EADDRNOTAVAIL: + log.info( + 'Address not available when adding %s to multicast ' + 'group, it is expected to happen on some systems', + interface, + ) + return False + if _errno in err_einval: + log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) + return False + raise + return True + + +def new_respond_socket( + interface: Union[str, Tuple[Tuple[str, int, int], int]], + apple_p2p: bool = False, +) -> Optional[socket.socket]: + is_v6 = isinstance(interface, tuple) + respond_socket = new_socket( + ip_version=(IPVersion.V6Only if is_v6 else IPVersion.V4Only), + apple_p2p=apple_p2p, + bind_addr=cast(Tuple[Tuple[str, int, int], int], interface)[0] if is_v6 else (cast(str, interface),), + ) + log.debug('Configuring socket %s with multicast interface %s', respond_socket, interface) + if is_v6: + iface_bin = struct.pack('@I', cast(int, interface[1])) + respond_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, iface_bin) + else: + respond_socket.setsockopt( + socket.IPPROTO_IP, socket.IP_MULTICAST_IF, socket.inet_aton(cast(str, interface)) + ) + return respond_socket + + +def create_sockets( + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: IPVersion = IPVersion.V4Only, + apple_p2p: bool = False, +) -> Tuple[Optional[socket.socket], List[socket.socket]]: + if unicast: + listen_socket = None + else: + listen_socket = new_socket(ip_version=ip_version, apple_p2p=apple_p2p, bind_addr=('',)) + + normalized_interfaces = normalize_interface_choice(interfaces, ip_version) + + # If we are using InterfaceChoice.Default we can use + # a single socket to listen and respond. + if not unicast and interfaces is InterfaceChoice.Default: + for i in normalized_interfaces: + add_multicast_member(cast(socket.socket, listen_socket), i) + return listen_socket, [cast(socket.socket, listen_socket)] + + respond_sockets = [] + + for i in normalized_interfaces: + if not unicast: + if add_multicast_member(cast(socket.socket, listen_socket), i): + respond_socket = new_respond_socket(i, apple_p2p=apple_p2p) + else: + respond_socket = None + else: + respond_socket = new_socket( + port=0, + ip_version=ip_version, + apple_p2p=apple_p2p, + bind_addr=i[0] if isinstance(i, tuple) else (i,), + ) + + if respond_socket is not None: + respond_sockets.append(respond_socket) + + return listen_socket, respond_sockets + + +def get_errno(e: Exception) -> int: + assert isinstance(e, socket.error) + return cast(int, e.args[0]) + + +def can_send_to(sock: socket.socket, address: str) -> bool: + addr = ipaddress.ip_address(address) + return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) + + +def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: + """Auto detect the IP version when it is not provided.""" + if isinstance(interfaces, list): + has_v6 = any( + isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6) + for i in interfaces + ) + has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces) + if has_v4 and has_v6: + return IPVersion.All + if has_v6: + return IPVersion.V6Only + + return IPVersion.V4Only From 6af42b54640ebba541302bfcf7688b3926453b15 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 13:11:12 -1000 Subject: [PATCH 233/608] Move int2byte to zeroconf.utils.struct (#540) --- zeroconf/__init__.py | 3 +-- zeroconf/utils/struct.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 zeroconf/utils/struct.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f04b98b0..1b335679 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -111,6 +111,7 @@ _encode_address, get_all_addresses, ) +from .utils.struct import int2byte from .utils.time import current_time_millis, millis_to_seconds __author__ = 'Paul Scott-Murphy, William McBrine' @@ -141,8 +142,6 @@ ) -int2byte = struct.Struct(">B").pack - # utility functions diff --git a/zeroconf/utils/struct.py b/zeroconf/utils/struct.py new file mode 100644 index 00000000..6ec99988 --- /dev/null +++ b/zeroconf/utils/struct.py @@ -0,0 +1,25 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import struct + +int2byte = struct.Struct(">B").pack From 8733cad2eae71ebdf94ecadc6fd5439882477235 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 13:15:37 -1000 Subject: [PATCH 234/608] Update zeroconf.aio import locations (#539) --- zeroconf/aio.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 1df18243..9a23be93 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -28,22 +28,16 @@ from . import ( DNSOutgoing, - IPVersion, - NonUniqueNameException, NotifyListener, ServiceInfo, Zeroconf, - _BROWSER_TIME, - _CHECK_TIME, - _LISTENER_TIME, - _MDNS_PORT, - _REGISTER_TIME, _ServiceBrowserBase, - _UNREGISTER_TIME, instance_name_from_service_info, ) +from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME +from .exceptions import NonUniqueNameException from .utils.aio import wait_condition_or_timeout -from .utils.net import InterfaceChoice, InterfacesType +from .utils.net import IPVersion, InterfaceChoice, InterfacesType from .utils.time import current_time_millis, millis_to_seconds From 1e3e7df8b7fdacd90cf5d864411e5db5a915be94 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 13:22:45 -1000 Subject: [PATCH 235/608] Relocate DNS classes to zeroconf.dns (#541) --- zeroconf/__init__.py | 996 +--------------------------------------- zeroconf/dns.py | 1030 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1045 insertions(+), 981 deletions(-) create mode 100644 zeroconf/dns.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1b335679..2ae88431 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -20,20 +20,18 @@ USA """ -import enum import errno import itertools import platform import select import socket -import struct import sys import threading import time import warnings from collections import OrderedDict from types import TracebackType # noqa # used in type hints -from typing import Dict, Iterable, List, Optional, Type, Union, cast +from typing import Dict, List, Optional, Type, Union, cast from typing import Any, Callable, Set, Tuple # noqa # used in type hints from .const import ( # noqa # import needed for backwards compat @@ -87,7 +85,20 @@ _TYPE_TXT, _UNREGISTER_TIME, ) -from .exceptions import ( +from .dns import ( # noqa # import needed for backwards compat + DNSAddress, + DNSCache, + DNSEntry, + DNSHinfo, + DNSIncoming, + DNSOutgoing, + DNSPointer, + DNSQuestion, + DNSRecord, + DNSService, + DNSText, +) +from .exceptions import ( # noqa # import needed for backwards compat AbstractMethodException, BadTypeInNameException, Error, @@ -279,983 +290,6 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str: # implementation classes -class DNSEntry: - - """A DNS entry""" - - def __init__(self, name: str, type_: int, class_: int) -> None: - self.key = name.lower() - self.name = name - self.type = type_ - self.class_ = class_ & _CLASS_MASK - self.unique = (class_ & _CLASS_UNIQUE) != 0 - - def __eq__(self, other: Any) -> bool: - """Equality test on key (lowercase name), type, and class""" - return ( - self.key == other.key - and self.type == other.type - and self.class_ == other.class_ - and isinstance(other, DNSEntry) - ) - - @staticmethod - def get_class_(class_: int) -> str: - """Class accessor""" - return _CLASSES.get(class_, "?(%s)" % class_) - - @staticmethod - def get_type(t: int) -> str: - """Type accessor""" - return _TYPES.get(t, "?(%s)" % t) - - def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: - """String representation with additional information""" - result = "%s[%s,%s" % (hdr, self.get_type(self.type), self.get_class_(self.class_)) - if self.unique: - result += "-unique," - else: - result += "," - result += self.name - if other is not None: - result += "]=%s" % cast(Any, other) - else: - result += "]" - return result - - -class DNSQuestion(DNSEntry): - - """A DNS question entry""" - - def __init__(self, name: str, type_: int, class_: int) -> None: - DNSEntry.__init__(self, name, type_, class_) - - def answered_by(self, rec: 'DNSRecord') -> bool: - """Returns true if the question is answered by the record""" - return ( - self.class_ == rec.class_ - and (self.type == rec.type or self.type == _TYPE_ANY) - and self.name == rec.name - ) - - def __repr__(self) -> str: - """String representation""" - return DNSEntry.entry_to_string(self, "question", None) - - -class DNSRecord(DNSEntry): - - """A DNS record - like a DNS entry, but has a TTL""" - - # TODO: Switch to just int ttl - def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: - DNSEntry.__init__(self, name, type_, class_) - self.ttl = ttl - self.created = current_time_millis() - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - - def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use - """Abstract method""" - raise AbstractMethodException - - def suppressed_by(self, msg: 'DNSIncoming') -> bool: - """Returns true if any answer in a message can suffice for the - information held in this record.""" - for record in msg.answers: - if self.suppressed_by_answer(record): - return True - return False - - def suppressed_by_answer(self, other: 'DNSRecord') -> bool: - """Returns true if another record has same name, type and class, - and if its TTL is at least half of this record's.""" - return self == other and other.ttl > (self.ttl / 2) - - def get_expiration_time(self, percent: int) -> float: - """Returns the time at which this record will have expired - by a certain percentage.""" - return self.created + (percent * self.ttl * 10) - - # TODO: Switch to just int here - def get_remaining_ttl(self, now: float) -> Union[int, float]: - """Returns the remaining TTL in seconds.""" - return max(0, millis_to_seconds(self._expiration_time - now)) - - def is_expired(self, now: float) -> bool: - """Returns true if this record has expired.""" - return self._expiration_time <= now - - def is_stale(self, now: float) -> bool: - """Returns true if this record is at least half way expired.""" - return self._stale_time <= now - - def reset_ttl(self, other: 'DNSRecord') -> None: - """Sets this record's TTL and created time to that of - another record.""" - self.created = other.created - self.ttl = other.ttl - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - - def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use - """Abstract method""" - raise AbstractMethodException - - def to_string(self, other: Union[bytes, str]) -> str: - """String representation with additional information""" - arg = "%s/%s,%s" % (self.ttl, int(self.get_remaining_ttl(current_time_millis())), cast(Any, other)) - return DNSEntry.entry_to_string(self, "record", arg) - - -class DNSAddress(DNSRecord): - - """A DNS address record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.address = address - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_string(self.address) - - def __eq__(self, other: Any) -> bool: - """Tests equality on address""" - return ( - isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address - ) - - def __repr__(self) -> str: - """String representation""" - try: - return self.to_string( - socket.inet_ntop( - socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address - ) - ) - except (ValueError, OSError): - return self.to_string(str(self.address)) - - -class DNSHinfo(DNSRecord): - - """A DNS host information record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.cpu = cpu - self.os = os - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_character_string(self.cpu.encode('utf-8')) - out.write_character_string(self.os.encode('utf-8')) - - def __eq__(self, other: Any) -> bool: - """Tests equality on cpu and os""" - return ( - isinstance(other, DNSHinfo) - and DNSEntry.__eq__(self, other) - and self.cpu == other.cpu - and self.os == other.os - ) - - def __repr__(self) -> str: - """String representation""" - return self.to_string(self.cpu + " " + self.os) - - -class DNSPointer(DNSRecord): - - """A DNS pointer record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.alias = alias - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_name(self.alias) - - def __eq__(self, other: Any) -> bool: - """Tests equality on alias""" - return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) - - def __repr__(self) -> str: - """String representation""" - return self.to_string(self.alias) - - -class DNSText(DNSRecord): - - """A DNS text record""" - - def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: - assert isinstance(text, (bytes, type(None))) - DNSRecord.__init__(self, name, type_, class_, ttl) - self.text = text - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_string(self.text) - - def __eq__(self, other: Any) -> bool: - """Tests equality on text""" - return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) - - def __repr__(self) -> str: - """String representation""" - if len(self.text) > 10: - return self.to_string(self.text[:7]) + "..." - return self.to_string(self.text) - - -class DNSService(DNSRecord): - - """A DNS service record""" - - def __init__( - self, - name: str, - type_: int, - class_: int, - ttl: Union[float, int], - priority: int, - weight: int, - port: int, - server: str, - ) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) - self.priority = priority - self.weight = weight - self.port = port - self.server = server - - def write(self, out: 'DNSOutgoing') -> None: - """Used in constructing an outgoing packet""" - out.write_short(self.priority) - out.write_short(self.weight) - out.write_short(self.port) - out.write_name(self.server) - - def __eq__(self, other: Any) -> bool: - """Tests equality on priority, weight, port and server""" - return ( - isinstance(other, DNSService) - and self.priority == other.priority - and self.weight == other.weight - and self.port == other.port - and self.server == other.server - and DNSEntry.__eq__(self, other) - ) - - def __repr__(self) -> str: - """String representation""" - return self.to_string("%s:%s" % (self.server, self.port)) - - -class DNSMessage: - """A base class for DNS messages.""" - - def __init__(self, flags: int) -> None: - """Construct a DNS message.""" - self.flags = flags - - def is_query(self) -> bool: - """Returns true if this is a query.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY - - def is_response(self) -> bool: - """Returns true if this is a response.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE - - -class DNSIncoming(DNSMessage, QuietLogger): - - """Object representation of an incoming DNS packet""" - - def __init__(self, data: bytes) -> None: - """Constructor from string holding bytes of packet""" - super().__init__(0) - self.offset = 0 - self.data = data - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[DNSRecord] - self.id = 0 - self.num_questions = 0 - self.num_answers = 0 - self.num_authorities = 0 - self.num_additionals = 0 - self.valid = False - - try: - self.read_header() - self.read_questions() - self.read_others() - self.valid = True - - except (IndexError, struct.error, IncomingDecodeError): - self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'id=%s' % self.id, - 'flags=%s' % self.flags, - 'n_q=%s' % self.num_questions, - 'n_ans=%s' % self.num_answers, - 'n_auth=%s' % self.num_authorities, - 'n_add=%s' % self.num_additionals, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - ] - ) - - def unpack(self, format_: bytes) -> tuple: - length = struct.calcsize(format_) - info = struct.unpack(format_, self.data[self.offset : self.offset + length]) - self.offset += length - return info - - def read_header(self) -> None: - """Reads header portion of packet""" - ( - self.id, - self.flags, - self.num_questions, - self.num_answers, - self.num_authorities, - self.num_additionals, - ) = self.unpack(b'!6H') - - def read_questions(self) -> None: - """Reads questions section of packet""" - for _ in range(self.num_questions): - name = self.read_name() - type_, class_ = self.unpack(b'!HH') - - question = DNSQuestion(name, type_, class_) - self.questions.append(question) - - # def read_int(self): - # """Reads an integer from the packet""" - # return self.unpack(b'!I')[0] - - def read_character_string(self) -> bytes: - """Reads a character string from the packet""" - length = self.data[self.offset] - self.offset += 1 - return self.read_string(length) - - def read_string(self, length: int) -> bytes: - """Reads a string of a given length from the packet""" - info = self.data[self.offset : self.offset + length] - self.offset += length - return info - - def read_unsigned_short(self) -> int: - """Reads an unsigned short from the packet""" - return cast(int, self.unpack(b'!H')[0]) - - def read_others(self) -> None: - """Reads the answers, authorities and additionals section of the - packet""" - n = self.num_answers + self.num_authorities + self.num_additionals - for _ in range(n): - domain = self.read_name() - type_, class_, ttl, length = self.unpack(b'!HHiH') - - rec = None # type: Optional[DNSRecord] - if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) - elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) - elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) - elif type_ == _TYPE_SRV: - rec = DNSService( - domain, - type_, - class_, - ttl, - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_name(), - ) - elif type_ == _TYPE_HINFO: - rec = DNSHinfo( - domain, - type_, - class_, - ttl, - self.read_character_string().decode('utf-8'), - self.read_character_string().decode('utf-8'), - ) - elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) - else: - # Try to ignore types we don't know about - # Skip the payload for the resource record so the next - # records can be parsed correctly - self.offset += length - - if rec is not None: - self.answers.append(rec) - - def read_utf(self, offset: int, length: int) -> str: - """Reads a UTF-8 string of a given length from the packet""" - return str(self.data[offset : offset + length], 'utf-8', 'replace') - - def read_name(self) -> str: - """Reads a domain name from the packet""" - result = '' - off = self.offset - next_ = -1 - first = off - - while True: - length = self.data[off] - off += 1 - if length == 0: - break - t = length & 0xC0 - if t == 0x00: - result += self.read_utf(off, length) + '.' - off += length - elif t == 0xC0: - if next_ < 0: - next_ = off + 1 - off = ((length & 0x3F) << 8) | self.data[off] - if off >= first: - raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) - first = off - else: - raise IncomingDecodeError("Bad domain name at %s" % (off,)) - - if next_ >= 0: - self.offset = next_ - else: - self.offset = off - - return result - - -class DNSOutgoing(DNSMessage): - - """Object representation of an outgoing packet""" - - def __init__(self, flags: int, multicast: bool = True) -> None: - super().__init__(flags) - self.finished = False - self.id = 0 - self.multicast = multicast - self.packets_data = [] # type: List[bytes] - - # these 3 are per-packet -- see also reset_for_next_packet() - self.names = {} # type: Dict[str, int] - self.data = [] # type: List[bytes] - self.size = 12 - self.allow_long = True - - self.state = self.State.init - - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[Tuple[DNSRecord, float]] - self.authorities = [] # type: List[DNSPointer] - self.additionals = [] # type: List[DNSRecord] - - def reset_for_next_packet(self) -> None: - self.names = {} - self.data = [] - self.size = 12 - self.allow_long = True - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'multicast=%s' % self.multicast, - 'flags=%s' % self.flags, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - 'authorities=%s' % self.authorities, - 'additionals=%s' % self.additionals, - ] - ) - - class State(enum.Enum): - init = 0 - finished = 1 - - def add_question(self, record: DNSQuestion) -> None: - """Adds a question""" - self.questions.append(record) - - def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: - """Adds an answer""" - if not record.suppressed_by(inp): - self.add_answer_at_time(record, 0) - - def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: - """Adds an answer if it does not expire by a certain time""" - if record is not None: - if now == 0 or not record.is_expired(now): - self.answers.append((record, now)) - - def add_authorative_answer(self, record: DNSPointer) -> None: - """Adds an authoritative answer""" - self.authorities.append(record) - - def add_additional_answer(self, record: DNSRecord) -> None: - """Adds an additional answer - - From: RFC 6763, DNS-Based Service Discovery, February 2013 - - 12. DNS Additional Record Generation - - DNS has an efficiency feature whereby a DNS server may place - additional records in the additional section of the DNS message. - These additional records are records that the client did not - explicitly request, but the server has reasonable grounds to expect - that the client might request them shortly, so including them can - save the client from having to issue additional queries. - - This section recommends which additional records SHOULD be generated - to improve network efficiency, for both Unicast and Multicast DNS-SD - responses. - - 12.1. PTR Records - - When including a DNS-SD Service Instance Enumeration or Selective - Instance Enumeration (subtype) PTR record in a response packet, the - server/responder SHOULD include the following additional records: - - o The SRV record(s) named in the PTR rdata. - o The TXT record(s) named in the PTR rdata. - o All address records (type "A" and "AAAA") named in the SRV rdata. - - 12.2. SRV Records - - When including an SRV record in a response packet, the - server/responder SHOULD include the following additional records: - - o All address records (type "A" and "AAAA") named in the SRV rdata. - - """ - self.additionals.append(record) - - def add_question_or_one_cache( - self, cache: "DNSCache", now: float, name: str, type_: int, class_: int - ) -> None: - """Add a question if it is not already cached.""" - cached_entry = cache.get_by_details(name, type_, class_) - if not cached_entry: - self.add_question(DNSQuestion(name, type_, class_)) - else: - self.add_answer_at_time(cached_entry, now) - - def add_question_or_all_cache( - self, cache: "DNSCache", now: float, name: str, type_: int, class_: int - ) -> None: - """Add a question if it is not already cached. - This is currently only used for IPv6 addresses. - """ - cached_entries = cache.get_all_by_details(name, type_, class_) - if not cached_entries: - self.add_question(DNSQuestion(name, type_, class_)) - return - for cached_entry in cached_entries: - self.add_answer_at_time(cached_entry, now) - - def pack(self, format_: Union[bytes, str], value: Any) -> None: - self.data.append(struct.pack(format_, value)) - self.size += struct.calcsize(format_) - - def write_byte(self, value: int) -> None: - """Writes a single byte to the packet""" - self.pack(b'!c', int2byte(value)) - - def insert_short_at_start(self, value: int) -> None: - """Inserts an unsigned short at the start of the packet""" - self.data.insert(0, struct.pack(b'!H', value)) - - def replace_short(self, index: int, value: int) -> None: - """Replaces an unsigned short in a certain position in the packet""" - self.data[index] = struct.pack(b'!H', value) - - def write_short(self, value: int) -> None: - """Writes an unsigned short to the packet""" - self.pack(b'!H', value) - - def write_int(self, value: Union[float, int]) -> None: - """Writes an unsigned integer to the packet""" - self.pack(b'!I', int(value)) - - def write_string(self, value: bytes) -> None: - """Writes a string to the packet""" - assert isinstance(value, bytes) - self.data.append(value) - self.size += len(value) - - def write_utf(self, s: str) -> None: - """Writes a UTF-8 string of a given length to the packet""" - utfstr = s.encode('utf-8') - length = len(utfstr) - if length > 64: - raise NamePartTooLongException - self.write_byte(length) - self.write_string(utfstr) - - def write_character_string(self, value: bytes) -> None: - assert isinstance(value, bytes) - length = len(value) - if length > 256: - raise NamePartTooLongException - self.write_byte(length) - self.write_string(value) - - def write_name(self, name: str) -> None: - """ - Write names to packet - - 18.14. Name Compression - - When generating Multicast DNS messages, implementations SHOULD use - name compression wherever possible to compress the names of resource - records, by replacing some or all of the resource record name with a - compact two-byte reference to an appearance of that data somewhere - earlier in the message [RFC1035]. - """ - - # split name into each label - parts = name.split('.') - if not parts[-1]: - parts.pop() - - # construct each suffix - name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] - - # look for an existing name or suffix - for count, sub_name in enumerate(name_suffices): - if sub_name in self.names: - break - else: - count = len(name_suffices) - - # note the new names we are saving into the packet - name_length = len(name.encode('utf-8')) - for suffix in name_suffices[:count]: - self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 - - # write the new names out. - for part in parts[:count]: - self.write_utf(part) - - # if we wrote part of the name, create a pointer to the rest - if count != len(name_suffices): - # Found substring in packet, create pointer - index = self.names[name_suffices[count]] - self.write_byte((index >> 8) | 0xC0) - self.write_byte(index & 0xFF) - else: - # this is the end of a name - self.write_byte(0) - - def write_question(self, question: DNSQuestion) -> bool: - """Writes a question to the packet""" - start_data_length, start_size = len(self.data), self.size - self.write_name(question.name) - self.write_short(question.type) - self.write_short(question.class_) - return self._check_data_limit_or_rollback(start_data_length, start_size) - - def write_record(self, record: DNSRecord, now: float) -> bool: - """Writes a record (answer, authoritative answer, additional) to - the packet. Returns True on success, or False if we did not (either - because the packet was already finished or because the record does - not fit.""" - if self.state == self.State.finished: - return False - - start_data_length, start_size = len(self.data), self.size - self.write_name(record.name) - self.write_short(record.type) - if record.unique and self.multicast: - self.write_short(record.class_ | _CLASS_UNIQUE) - else: - self.write_short(record.class_) - if now == 0: - self.write_int(record.ttl) - else: - self.write_int(record.get_remaining_ttl(now)) - index = len(self.data) - - self.write_short(0) # Will get replaced with the actual size - record.write(self) - # Adjust size for the short we will write before this record - length = sum((len(d) for d in self.data[index + 1 :])) - # Here we replace the 0 length short we wrote - # before with the actual length - self.replace_short(index, length) - return self._check_data_limit_or_rollback(start_data_length, start_size) - - def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: - """Check data limit, if we go over, then rollback and return False.""" - len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL - self.allow_long = False - - if self.size <= len_limit: - return True - - log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) - - while len(self.data) > start_data_length: - self.data.pop() - self.size = start_size - - rollback_names = [name for name, idx in self.names.items() if idx >= start_size] - for name in rollback_names: - del self.names[name] - return False - - def packet(self) -> bytes: - """Returns a bytestring containing the first packet's bytes. - - Generally, you want to use packets() in case the response - does not fit in a single packet, but this exists for - backward compatibility.""" - packets = self.packets() - if len(packets) == 0: - return b'' - if len(packets[0]) > _MAX_MSG_ABSOLUTE: - QuietLogger.log_warning_once( - "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] - ) - return packets[0] - - def _write_questions_from_offset(self, questions_offset: int) -> int: - questions_written = 0 - for question in self.questions[questions_offset:]: - if not self.write_question(question): - break - questions_written += 1 - return questions_written - - def _write_answers_from_offset(self, answer_offset: int) -> int: - answers_written = 0 - for answer, time_ in self.answers[answer_offset:]: - if not self.write_record(answer, time_): - break - answers_written += 1 - return answers_written - - def _write_authorities_from_offset(self, authority_offset: int) -> int: - authorities_written = 0 - for authority in self.authorities[authority_offset:]: - if not self.write_record(authority, 0): - break - authorities_written += 1 - return authorities_written - - def _write_additionals_from_offset(self, additional_offset: int) -> int: - additionals_written = 0 - for additional in self.additionals[additional_offset:]: - if not self.write_record(additional, 0): - break - additionals_written += 1 - return additionals_written - - def _has_more_to_add( - self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int - ) -> bool: - """Check if all questions, answers, authority, and additionals have been written to the packet.""" - return ( - questions_offset < len(self.questions) - or answer_offset < len(self.answers) - or authority_offset < len(self.authorities) - or additional_offset < len(self.additionals) - ) - - def packets(self) -> List[bytes]: - """Returns a list of bytestrings containing the packets' bytes - - No further parts should be added to the packet once this - is done. The packets are each restricted to _MAX_MSG_TYPICAL - or less in length, except for the case of a single answer which - will be written out to a single oversized packet no more than - _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP - fragmentation potentially).""" - - if self.state == self.State.finished: - return self.packets_data - - questions_offset = 0 - answer_offset = 0 - authority_offset = 0 - additional_offset = 0 - # we have to at least write out the question - first_time = True - - while first_time or self._has_more_to_add( - questions_offset, answer_offset, authority_offset, additional_offset - ): - first_time = False - log.debug( - "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", - questions_offset, - answer_offset, - authority_offset, - additional_offset, - ) - log.debug( - "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", - len(self.questions), - len(self.answers), - len(self.authorities), - len(self.additionals), - ) - - questions_written = self._write_questions_from_offset(questions_offset) - answers_written = self._write_answers_from_offset(answer_offset) - authorities_written = self._write_authorities_from_offset(authority_offset) - additionals_written = self._write_additionals_from_offset(additional_offset) - - self.insert_short_at_start(additionals_written) - self.insert_short_at_start(authorities_written) - self.insert_short_at_start(answers_written) - self.insert_short_at_start(questions_written) - - questions_offset += questions_written - answer_offset += answers_written - authority_offset += authorities_written - additional_offset += additionals_written - log.debug( - "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", - questions_offset, - answer_offset, - authority_offset, - additional_offset, - ) - - if self.is_query() and self._has_more_to_add( - questions_offset, answer_offset, authority_offset, additional_offset - ): - # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - log.debug("Setting TC flag") - self.insert_short_at_start(self.flags | _FLAGS_TC) - else: - self.insert_short_at_start(self.flags) - - if self.multicast: - self.insert_short_at_start(0) - else: - self.insert_short_at_start(self.id) - - self.packets_data.append(b''.join(self.data)) - self.reset_for_next_packet() - - if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( - len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) - ) > 0: - log.warning("packets() made no progress adding records; returning") - break - self.state = self.State.finished - return self.packets_data - - -class DNSCache: - - """A cache of DNS entries""" - - def __init__(self) -> None: - self.cache = {} # type: Dict[str, List[DNSRecord]] - self.service_cache = {} # type: Dict[str, List[DNSRecord]] - - def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert last in list, get will return newest entry - # iteration will result in last update winning - self.cache.setdefault(entry.key, []).append(entry) - if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server, []).append(entry) - - def add_records(self, entries: Iterable[DNSRecord]) -> None: - """Add multiple records.""" - for entry in entries: - self.add(entry) - - def remove(self, entry: DNSRecord) -> None: - """Removes an entry.""" - if isinstance(entry, DNSService): - DNSCache.remove_key(self.service_cache, entry.server, entry) - DNSCache.remove_key(self.cache, entry.key, entry) - - def remove_records(self, entries: Iterable[DNSRecord]) -> None: - """Remove multiple records.""" - for entry in entries: - self.remove(entry) - - @staticmethod - def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: - """Forgiving remove of a cache key.""" - try: - cache[key].remove(entry) - if not cache[key]: - del cache[key] - except (KeyError, ValueError): - pass - - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Gets an entry by key. Will return None if there is no - matching entry.""" - for cached_entry in reversed(self.entries_with_name(entry.key)): - if entry.__eq__(cached_entry): - return cached_entry - return None - - def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets the first matching entry by details. Returns None if no entries match.""" - return self.get(DNSEntry(name, type_, class_)) - - def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: - """Gets all matching entries by details.""" - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] - - def entries_with_server(self, server: str) -> List[DNSRecord]: - """Returns a list of entries whose server matches the name.""" - return self.service_cache.get(server, [])[:] - - def entries_with_name(self, name: str) -> List[DNSRecord]: - """Returns a list of entries whose key matches the name.""" - return self.cache.get(name.lower(), [])[:] - - def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: - now = current_time_millis() - for record in reversed(self.entries_with_name(name)): - if ( - record.type == _TYPE_PTR - and not record.is_expired(now) - and cast(DNSPointer, record).alias == alias - ): - return record - return None - - def names(self) -> List[str]: - """Return a copy of the list of current cache names.""" - return list(self.cache) - - def expire(self, now: float) -> Iterable[DNSRecord]: - """Purge expired entries from the cache.""" - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self.remove(record) - yield record - - class Engine(threading.Thread): """An engine wraps read access to sockets, allowing objects that diff --git a/zeroconf/dns.py b/zeroconf/dns.py new file mode 100644 index 00000000..60d3c919 --- /dev/null +++ b/zeroconf/dns.py @@ -0,0 +1,1030 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import socket +import struct +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast + +from .const import ( + _CLASSES, + _CLASS_MASK, + _CLASS_UNIQUE, + _EXPIRE_FULL_TIME_PERCENT, + _EXPIRE_STALE_TIME_PERCENT, + _FLAGS_QR_MASK, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _FLAGS_TC, + _MAX_MSG_ABSOLUTE, + _MAX_MSG_TYPICAL, + _TYPES, + _TYPE_A, + _TYPE_AAAA, + _TYPE_ANY, + _TYPE_CNAME, + _TYPE_HINFO, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) +from .exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException +from .logger import QuietLogger, log +from .utils.net import _is_v6_address +from .utils.struct import int2byte +from .utils.time import current_time_millis, millis_to_seconds + + +class DNSEntry: + + """A DNS entry""" + + def __init__(self, name: str, type_: int, class_: int) -> None: + self.key = name.lower() + self.name = name + self.type = type_ + self.class_ = class_ & _CLASS_MASK + self.unique = (class_ & _CLASS_UNIQUE) != 0 + + def __eq__(self, other: Any) -> bool: + """Equality test on key (lowercase name), type, and class""" + return ( + self.key == other.key + and self.type == other.type + and self.class_ == other.class_ + and isinstance(other, DNSEntry) + ) + + @staticmethod + def get_class_(class_: int) -> str: + """Class accessor""" + return _CLASSES.get(class_, "?(%s)" % class_) + + @staticmethod + def get_type(t: int) -> str: + """Type accessor""" + return _TYPES.get(t, "?(%s)" % t) + + def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: + """String representation with additional information""" + result = "%s[%s,%s" % (hdr, self.get_type(self.type), self.get_class_(self.class_)) + if self.unique: + result += "-unique," + else: + result += "," + result += self.name + if other is not None: + result += "]=%s" % cast(Any, other) + else: + result += "]" + return result + + +class DNSQuestion(DNSEntry): + + """A DNS question entry""" + + def __init__(self, name: str, type_: int, class_: int) -> None: + DNSEntry.__init__(self, name, type_, class_) + + def answered_by(self, rec: 'DNSRecord') -> bool: + """Returns true if the question is answered by the record""" + return ( + self.class_ == rec.class_ + and (self.type == rec.type or self.type == _TYPE_ANY) + and self.name == rec.name + ) + + def __repr__(self) -> str: + """String representation""" + return DNSEntry.entry_to_string(self, "question", None) + + +class DNSRecord(DNSEntry): + + """A DNS record - like a DNS entry, but has a TTL""" + + # TODO: Switch to just int ttl + def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: + DNSEntry.__init__(self, name, type_, class_) + self.ttl = ttl + self.created = current_time_millis() + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) + self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) + + def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use + """Abstract method""" + raise AbstractMethodException + + def suppressed_by(self, msg: 'DNSIncoming') -> bool: + """Returns true if any answer in a message can suffice for the + information held in this record.""" + for record in msg.answers: + if self.suppressed_by_answer(record): + return True + return False + + def suppressed_by_answer(self, other: 'DNSRecord') -> bool: + """Returns true if another record has same name, type and class, + and if its TTL is at least half of this record's.""" + return self == other and other.ttl > (self.ttl / 2) + + def get_expiration_time(self, percent: int) -> float: + """Returns the time at which this record will have expired + by a certain percentage.""" + return self.created + (percent * self.ttl * 10) + + # TODO: Switch to just int here + def get_remaining_ttl(self, now: float) -> Union[int, float]: + """Returns the remaining TTL in seconds.""" + return max(0, millis_to_seconds(self._expiration_time - now)) + + def is_expired(self, now: float) -> bool: + """Returns true if this record has expired.""" + return self._expiration_time <= now + + def is_stale(self, now: float) -> bool: + """Returns true if this record is at least half way expired.""" + return self._stale_time <= now + + def reset_ttl(self, other: 'DNSRecord') -> None: + """Sets this record's TTL and created time to that of + another record.""" + self.created = other.created + self.ttl = other.ttl + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) + self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) + + def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use + """Abstract method""" + raise AbstractMethodException + + def to_string(self, other: Union[bytes, str]) -> str: + """String representation with additional information""" + arg = "%s/%s,%s" % (self.ttl, int(self.get_remaining_ttl(current_time_millis())), cast(Any, other)) + return DNSEntry.entry_to_string(self, "record", arg) + + +class DNSAddress(DNSRecord): + + """A DNS address record""" + + def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: + DNSRecord.__init__(self, name, type_, class_, ttl) + self.address = address + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.address) + + def __eq__(self, other: Any) -> bool: + """Tests equality on address""" + return ( + isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address + ) + + def __repr__(self) -> str: + """String representation""" + try: + return self.to_string( + socket.inet_ntop( + socket.AF_INET6 if _is_v6_address(self.address) else socket.AF_INET, self.address + ) + ) + except (ValueError, OSError): + return self.to_string(str(self.address)) + + +class DNSHinfo(DNSRecord): + + """A DNS host information record""" + + def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: + DNSRecord.__init__(self, name, type_, class_, ttl) + self.cpu = cpu + self.os = os + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_character_string(self.cpu.encode('utf-8')) + out.write_character_string(self.os.encode('utf-8')) + + def __eq__(self, other: Any) -> bool: + """Tests equality on cpu and os""" + return ( + isinstance(other, DNSHinfo) + and DNSEntry.__eq__(self, other) + and self.cpu == other.cpu + and self.os == other.os + ) + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.cpu + " " + self.os) + + +class DNSPointer(DNSRecord): + + """A DNS pointer record""" + + def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: + DNSRecord.__init__(self, name, type_, class_, ttl) + self.alias = alias + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_name(self.alias) + + def __eq__(self, other: Any) -> bool: + """Tests equality on alias""" + return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.alias) + + +class DNSText(DNSRecord): + + """A DNS text record""" + + def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: + assert isinstance(text, (bytes, type(None))) + DNSRecord.__init__(self, name, type_, class_, ttl) + self.text = text + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_string(self.text) + + def __eq__(self, other: Any) -> bool: + """Tests equality on text""" + return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) + + def __repr__(self) -> str: + """String representation""" + if len(self.text) > 10: + return self.to_string(self.text[:7]) + "..." + return self.to_string(self.text) + + +class DNSService(DNSRecord): + + """A DNS service record""" + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: Union[float, int], + priority: int, + weight: int, + port: int, + server: str, + ) -> None: + DNSRecord.__init__(self, name, type_, class_, ttl) + self.priority = priority + self.weight = weight + self.port = port + self.server = server + + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet""" + out.write_short(self.priority) + out.write_short(self.weight) + out.write_short(self.port) + out.write_name(self.server) + + def __eq__(self, other: Any) -> bool: + """Tests equality on priority, weight, port and server""" + return ( + isinstance(other, DNSService) + and self.priority == other.priority + and self.weight == other.weight + and self.port == other.port + and self.server == other.server + and DNSEntry.__eq__(self, other) + ) + + def __repr__(self) -> str: + """String representation""" + return self.to_string("%s:%s" % (self.server, self.port)) + + +class DNSMessage: + """A base class for DNS messages.""" + + def __init__(self, flags: int) -> None: + """Construct a DNS message.""" + self.flags = flags + + def is_query(self) -> bool: + """Returns true if this is a query.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self) -> bool: + """Returns true if this is a response.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + +class DNSIncoming(DNSMessage, QuietLogger): + + """Object representation of an incoming DNS packet""" + + def __init__(self, data: bytes) -> None: + """Constructor from string holding bytes of packet""" + super().__init__(0) + self.offset = 0 + self.data = data + self.questions = [] # type: List[DNSQuestion] + self.answers = [] # type: List[DNSRecord] + self.id = 0 + self.num_questions = 0 + self.num_answers = 0 + self.num_authorities = 0 + self.num_additionals = 0 + self.valid = False + + try: + self.read_header() + self.read_questions() + self.read_others() + self.valid = True + + except (IndexError, struct.error, IncomingDecodeError): + self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'id=%s' % self.id, + 'flags=%s' % self.flags, + 'n_q=%s' % self.num_questions, + 'n_ans=%s' % self.num_answers, + 'n_auth=%s' % self.num_authorities, + 'n_add=%s' % self.num_additionals, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + ] + ) + + def unpack(self, format_: bytes) -> tuple: + length = struct.calcsize(format_) + info = struct.unpack(format_, self.data[self.offset : self.offset + length]) + self.offset += length + return info + + def read_header(self) -> None: + """Reads header portion of packet""" + ( + self.id, + self.flags, + self.num_questions, + self.num_answers, + self.num_authorities, + self.num_additionals, + ) = self.unpack(b'!6H') + + def read_questions(self) -> None: + """Reads questions section of packet""" + for _ in range(self.num_questions): + name = self.read_name() + type_, class_ = self.unpack(b'!HH') + + question = DNSQuestion(name, type_, class_) + self.questions.append(question) + + # def read_int(self): + # """Reads an integer from the packet""" + # return self.unpack(b'!I')[0] + + def read_character_string(self) -> bytes: + """Reads a character string from the packet""" + length = self.data[self.offset] + self.offset += 1 + return self.read_string(length) + + def read_string(self, length: int) -> bytes: + """Reads a string of a given length from the packet""" + info = self.data[self.offset : self.offset + length] + self.offset += length + return info + + def read_unsigned_short(self) -> int: + """Reads an unsigned short from the packet""" + return cast(int, self.unpack(b'!H')[0]) + + def read_others(self) -> None: + """Reads the answers, authorities and additionals section of the + packet""" + n = self.num_answers + self.num_authorities + self.num_additionals + for _ in range(n): + domain = self.read_name() + type_, class_, ttl, length = self.unpack(b'!HHiH') + rec = None # type: Optional[DNSRecord] + if type_ == _TYPE_A: + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) + elif type_ in (_TYPE_CNAME, _TYPE_PTR): + rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) + elif type_ == _TYPE_TXT: + rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) + elif type_ == _TYPE_SRV: + rec = DNSService( + domain, + type_, + class_, + ttl, + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_name(), + ) + elif type_ == _TYPE_HINFO: + rec = DNSHinfo( + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), + ) + elif type_ == _TYPE_AAAA: + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) + else: + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + + if rec is not None: + self.answers.append(rec) + + def read_utf(self, offset: int, length: int) -> str: + """Reads a UTF-8 string of a given length from the packet""" + return str(self.data[offset : offset + length], 'utf-8', 'replace') + + def read_name(self) -> str: + """Reads a domain name from the packet""" + result = '' + off = self.offset + next_ = -1 + first = off + + while True: + length = self.data[off] + off += 1 + if length == 0: + break + t = length & 0xC0 + if t == 0x00: + result += self.read_utf(off, length) + '.' + off += length + elif t == 0xC0: + if next_ < 0: + next_ = off + 1 + off = ((length & 0x3F) << 8) | self.data[off] + if off >= first: + raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) + first = off + else: + raise IncomingDecodeError("Bad domain name at %s" % (off,)) + + if next_ >= 0: + self.offset = next_ + else: + self.offset = off + + return result + + +class DNSOutgoing(DNSMessage): + + """Object representation of an outgoing packet""" + + def __init__(self, flags: int, multicast: bool = True) -> None: + super().__init__(flags) + self.finished = False + self.id = 0 + self.multicast = multicast + self.packets_data = [] # type: List[bytes] + + # these 3 are per-packet -- see also reset_for_next_packet() + self.names = {} # type: Dict[str, int] + self.data = [] # type: List[bytes] + self.size = 12 + self.allow_long = True + + self.state = self.State.init + + self.questions = [] # type: List[DNSQuestion] + self.answers = [] # type: List[Tuple[DNSRecord, float]] + self.authorities = [] # type: List[DNSPointer] + self.additionals = [] # type: List[DNSRecord] + + def reset_for_next_packet(self) -> None: + self.names = {} + self.data = [] + self.size = 12 + self.allow_long = True + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'multicast=%s' % self.multicast, + 'flags=%s' % self.flags, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + 'authorities=%s' % self.authorities, + 'additionals=%s' % self.additionals, + ] + ) + + class State(enum.Enum): + init = 0 + finished = 1 + + def add_question(self, record: DNSQuestion) -> None: + """Adds a question""" + self.questions.append(record) + + def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: + """Adds an answer""" + if not record.suppressed_by(inp): + self.add_answer_at_time(record, 0) + + def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: + """Adds an answer if it does not expire by a certain time""" + if record is not None: + if now == 0 or not record.is_expired(now): + self.answers.append((record, now)) + + def add_authorative_answer(self, record: DNSPointer) -> None: + """Adds an authoritative answer""" + self.authorities.append(record) + + def add_additional_answer(self, record: DNSRecord) -> None: + """Adds an additional answer + + From: RFC 6763, DNS-Based Service Discovery, February 2013 + + 12. DNS Additional Record Generation + + DNS has an efficiency feature whereby a DNS server may place + additional records in the additional section of the DNS message. + These additional records are records that the client did not + explicitly request, but the server has reasonable grounds to expect + that the client might request them shortly, so including them can + save the client from having to issue additional queries. + + This section recommends which additional records SHOULD be generated + to improve network efficiency, for both Unicast and Multicast DNS-SD + responses. + + 12.1. PTR Records + + When including a DNS-SD Service Instance Enumeration or Selective + Instance Enumeration (subtype) PTR record in a response packet, the + server/responder SHOULD include the following additional records: + + o The SRV record(s) named in the PTR rdata. + o The TXT record(s) named in the PTR rdata. + o All address records (type "A" and "AAAA") named in the SRV rdata. + + 12.2. SRV Records + + When including an SRV record in a response packet, the + server/responder SHOULD include the following additional records: + + o All address records (type "A" and "AAAA") named in the SRV rdata. + + """ + self.additionals.append(record) + + def add_question_or_one_cache( + self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached.""" + cached_entry = cache.get_by_details(name, type_, class_) + if not cached_entry: + self.add_question(DNSQuestion(name, type_, class_)) + else: + self.add_answer_at_time(cached_entry, now) + + def add_question_or_all_cache( + self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached. + This is currently only used for IPv6 addresses. + """ + cached_entries = cache.get_all_by_details(name, type_, class_) + if not cached_entries: + self.add_question(DNSQuestion(name, type_, class_)) + return + for cached_entry in cached_entries: + self.add_answer_at_time(cached_entry, now) + + def pack(self, format_: Union[bytes, str], value: Any) -> None: + self.data.append(struct.pack(format_, value)) + self.size += struct.calcsize(format_) + + def write_byte(self, value: int) -> None: + """Writes a single byte to the packet""" + self.pack(b'!c', int2byte(value)) + + def insert_short_at_start(self, value: int) -> None: + """Inserts an unsigned short at the start of the packet""" + self.data.insert(0, struct.pack(b'!H', value)) + + def replace_short(self, index: int, value: int) -> None: + """Replaces an unsigned short in a certain position in the packet""" + self.data[index] = struct.pack(b'!H', value) + + def write_short(self, value: int) -> None: + """Writes an unsigned short to the packet""" + self.pack(b'!H', value) + + def write_int(self, value: Union[float, int]) -> None: + """Writes an unsigned integer to the packet""" + self.pack(b'!I', int(value)) + + def write_string(self, value: bytes) -> None: + """Writes a string to the packet""" + assert isinstance(value, bytes) + self.data.append(value) + self.size += len(value) + + def write_utf(self, s: str) -> None: + """Writes a UTF-8 string of a given length to the packet""" + utfstr = s.encode('utf-8') + length = len(utfstr) + if length > 64: + raise NamePartTooLongException + self.write_byte(length) + self.write_string(utfstr) + + def write_character_string(self, value: bytes) -> None: + assert isinstance(value, bytes) + length = len(value) + if length > 256: + raise NamePartTooLongException + self.write_byte(length) + self.write_string(value) + + def write_name(self, name: str) -> None: + """ + Write names to packet + + 18.14. Name Compression + + When generating Multicast DNS messages, implementations SHOULD use + name compression wherever possible to compress the names of resource + records, by replacing some or all of the resource record name with a + compact two-byte reference to an appearance of that data somewhere + earlier in the message [RFC1035]. + """ + + # split name into each label + parts = name.split('.') + if not parts[-1]: + parts.pop() + + # construct each suffix + name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] + + # look for an existing name or suffix + for count, sub_name in enumerate(name_suffices): + if sub_name in self.names: + break + else: + count = len(name_suffices) + + # note the new names we are saving into the packet + name_length = len(name.encode('utf-8')) + for suffix in name_suffices[:count]: + self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 + + # write the new names out. + for part in parts[:count]: + self.write_utf(part) + + # if we wrote part of the name, create a pointer to the rest + if count != len(name_suffices): + # Found substring in packet, create pointer + index = self.names[name_suffices[count]] + self.write_byte((index >> 8) | 0xC0) + self.write_byte(index & 0xFF) + else: + # this is the end of a name + self.write_byte(0) + + def write_question(self, question: DNSQuestion) -> bool: + """Writes a question to the packet""" + start_data_length, start_size = len(self.data), self.size + self.write_name(question.name) + self.write_short(question.type) + self.write_short(question.class_) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def write_record(self, record: DNSRecord, now: float) -> bool: + """Writes a record (answer, authoritative answer, additional) to + the packet. Returns True on success, or False if we did not (either + because the packet was already finished or because the record does + not fit.""" + if self.state == self.State.finished: + return False + + start_data_length, start_size = len(self.data), self.size + self.write_name(record.name) + self.write_short(record.type) + if record.unique and self.multicast: + self.write_short(record.class_ | _CLASS_UNIQUE) + else: + self.write_short(record.class_) + if now == 0: + self.write_int(record.ttl) + else: + self.write_int(record.get_remaining_ttl(now)) + index = len(self.data) + + self.write_short(0) # Will get replaced with the actual size + record.write(self) + # Adjust size for the short we will write before this record + length = sum((len(d) for d in self.data[index + 1 :])) + # Here we replace the 0 length short we wrote + # before with the actual length + self.replace_short(index, length) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: + """Check data limit, if we go over, then rollback and return False.""" + len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL + self.allow_long = False + + if self.size <= len_limit: + return True + + log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) + + while len(self.data) > start_data_length: + self.data.pop() + self.size = start_size + + rollback_names = [name for name, idx in self.names.items() if idx >= start_size] + for name in rollback_names: + del self.names[name] + return False + + def packet(self) -> bytes: + """Returns a bytestring containing the first packet's bytes. + + Generally, you want to use packets() in case the response + does not fit in a single packet, but this exists for + backward compatibility.""" + packets = self.packets() + if len(packets) == 0: + return b'' + if len(packets[0]) > _MAX_MSG_ABSOLUTE: + QuietLogger.log_warning_once( + "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] + ) + return packets[0] + + def _write_questions_from_offset(self, questions_offset: int) -> int: + questions_written = 0 + for question in self.questions[questions_offset:]: + if not self.write_question(question): + break + questions_written += 1 + return questions_written + + def _write_answers_from_offset(self, answer_offset: int) -> int: + answers_written = 0 + for answer, time_ in self.answers[answer_offset:]: + if not self.write_record(answer, time_): + break + answers_written += 1 + return answers_written + + def _write_authorities_from_offset(self, authority_offset: int) -> int: + authorities_written = 0 + for authority in self.authorities[authority_offset:]: + if not self.write_record(authority, 0): + break + authorities_written += 1 + return authorities_written + + def _write_additionals_from_offset(self, additional_offset: int) -> int: + additionals_written = 0 + for additional in self.additionals[additional_offset:]: + if not self.write_record(additional, 0): + break + additionals_written += 1 + return additionals_written + + def _has_more_to_add( + self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int + ) -> bool: + """Check if all questions, answers, authority, and additionals have been written to the packet.""" + return ( + questions_offset < len(self.questions) + or answer_offset < len(self.answers) + or authority_offset < len(self.authorities) + or additional_offset < len(self.additionals) + ) + + def packets(self) -> List[bytes]: + """Returns a list of bytestrings containing the packets' bytes + + No further parts should be added to the packet once this + is done. The packets are each restricted to _MAX_MSG_TYPICAL + or less in length, except for the case of a single answer which + will be written out to a single oversized packet no more than + _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP + fragmentation potentially).""" + + if self.state == self.State.finished: + return self.packets_data + + questions_offset = 0 + answer_offset = 0 + authority_offset = 0 + additional_offset = 0 + # we have to at least write out the question + first_time = True + + while first_time or self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + first_time = False + log.debug( + "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + log.debug( + "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", + len(self.questions), + len(self.answers), + len(self.authorities), + len(self.additionals), + ) + + questions_written = self._write_questions_from_offset(questions_offset) + answers_written = self._write_answers_from_offset(answer_offset) + authorities_written = self._write_authorities_from_offset(authority_offset) + additionals_written = self._write_additionals_from_offset(additional_offset) + + self.insert_short_at_start(additionals_written) + self.insert_short_at_start(authorities_written) + self.insert_short_at_start(answers_written) + self.insert_short_at_start(questions_written) + + questions_offset += questions_written + answer_offset += answers_written + authority_offset += authorities_written + additional_offset += additionals_written + log.debug( + "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + + if self.is_query() and self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + log.debug("Setting TC flag") + self.insert_short_at_start(self.flags | _FLAGS_TC) + else: + self.insert_short_at_start(self.flags) + + if self.multicast: + self.insert_short_at_start(0) + else: + self.insert_short_at_start(self.id) + + self.packets_data.append(b''.join(self.data)) + self.reset_for_next_packet() + + if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( + len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) + ) > 0: + log.warning("packets() made no progress adding records; returning") + break + self.state = self.State.finished + return self.packets_data + + +class DNSCache: + + """A cache of DNS entries""" + + def __init__(self) -> None: + self.cache = {} # type: Dict[str, List[DNSRecord]] + self.service_cache = {} # type: Dict[str, List[DNSRecord]] + + def add(self, entry: DNSRecord) -> None: + """Adds an entry""" + # Insert last in list, get will return newest entry + # iteration will result in last update winning + self.cache.setdefault(entry.key, []).append(entry) + if isinstance(entry, DNSService): + self.service_cache.setdefault(entry.server, []).append(entry) + + def add_records(self, entries: Iterable[DNSRecord]) -> None: + """Add multiple records.""" + for entry in entries: + self.add(entry) + + def remove(self, entry: DNSRecord) -> None: + """Removes an entry.""" + if isinstance(entry, DNSService): + DNSCache.remove_key(self.service_cache, entry.server, entry) + DNSCache.remove_key(self.cache, entry.key, entry) + + def remove_records(self, entries: Iterable[DNSRecord]) -> None: + """Remove multiple records.""" + for entry in entries: + self.remove(entry) + + @staticmethod + def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: + """Forgiving remove of a cache key.""" + try: + cache[key].remove(entry) + if not cache[key]: + del cache[key] + except (KeyError, ValueError): + pass + + def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry.""" + for cached_entry in reversed(self.entries_with_name(entry.key)): + if entry.__eq__(cached_entry): + return cached_entry + return None + + def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: + """Gets the first matching entry by details. Returns None if no entries match.""" + return self.get(DNSEntry(name, type_, class_)) + + def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details.""" + match_entry = DNSEntry(name, type_, class_) + return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] + + def entries_with_server(self, server: str) -> List[DNSRecord]: + """Returns a list of entries whose server matches the name.""" + return self.service_cache.get(server, [])[:] + + def entries_with_name(self, name: str) -> List[DNSRecord]: + """Returns a list of entries whose key matches the name.""" + return self.cache.get(name.lower(), [])[:] + + def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + now = current_time_millis() + for record in reversed(self.entries_with_name(name)): + if ( + record.type == _TYPE_PTR + and not record.is_expired(now) + and cast(DNSPointer, record).alias == alias + ): + return record + return None + + def names(self) -> List[str]: + """Return a copy of the list of current cache names.""" + return list(self.cache) + + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache.""" + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record From b4814f5f216cd4072bafdd7dd1e68ee522f329c2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 13:43:02 -1000 Subject: [PATCH 236/608] Move service_type_name to zeroconf.utils.name (#543) --- zeroconf/__init__.py | 122 +------------------------------- zeroconf/utils/name.py | 153 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 121 deletions(-) create mode 100644 zeroconf/utils/name.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 2ae88431..6f5d9c9c 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -108,6 +108,7 @@ ServiceNameAlreadyRegistered, ) from .logger import QuietLogger, log +from .utils.name import service_type_name from .utils.net import ( # noqa # import needed for backwards compat add_multicast_member, can_send_to, @@ -156,127 +157,6 @@ # utility functions -def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches - """ - Validate a fully qualified service name, instance or subtype. [rfc6763] - - Returns fully qualified service name. - - Domain names used by mDNS-SD take the following forms: - - . <_tcp|_udp> . local. - . . <_tcp|_udp> . local. - ._sub . . <_tcp|_udp> . local. - - 1) must end with 'local.' - - This is true because we are implementing mDNS and since the 'm' means - multi-cast, the 'local.' domain is mandatory. - - 2) local is preceded with either '_udp.' or '_tcp.' unless - strict is False - - 3) service name precedes <_tcp|_udp> unless - strict is False - - The rules for Service Names [RFC6335] state that they may be no more - than fifteen characters long (not counting the mandatory underscore), - consisting of only letters, digits, and hyphens, must begin and end - with a letter or digit, must not contain consecutive hyphens, and - must contain at least one letter. - - The instance name and sub type may be up to 63 bytes. - - The portion of the Service Instance Name is a user- - friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It - MUST NOT contain ASCII control characters (byte values 0x00-0x1F and - 0x7F) [RFC20] but otherwise is allowed to contain any characters, - without restriction, including spaces, uppercase, lowercase, - punctuation -- including dots -- accented characters, non-Roman text, - and anything else that may be represented using Net-Unicode. - - :param type_: Type, SubType or service name to validate - :return: fully qualified service name (eg: _http._tcp.local.) - """ - - if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): - remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') - trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] - has_protocol = True - elif strict: - raise BadTypeInNameException( - "Type '%s' must end with '%s' or '%s'" - % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER) - ) - elif type_.endswith(_LOCAL_TRAILER): - remaining = type_[: -len(_LOCAL_TRAILER)].split('.') - trailer = type_[-len(_LOCAL_TRAILER) + 1 :] - has_protocol = False - else: - raise BadTypeInNameException("Type '%s' must end with '%s'" % (type_, _LOCAL_TRAILER)) - - if strict or has_protocol: - service_name = remaining.pop() - if not service_name: - raise BadTypeInNameException("No Service name found") - - if len(remaining) == 1 and len(remaining[0]) == 0: - raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) - - if service_name[0] != '_': - raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name) - - test_service_name = service_name[1:] - - if len(test_service_name) > 15: - raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) - - if '--' in test_service_name: - raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name) - - if '-' in (test_service_name[0], test_service_name[-1]): - raise BadTypeInNameException( - "Service name (%s) may not start or end with '-'" % test_service_name - ) - - if not _HAS_A_TO_Z.search(test_service_name): - raise BadTypeInNameException( - "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name - ) - - allowed_characters_re = ( - _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE - ) - - if not allowed_characters_re.search(test_service_name): - raise BadTypeInNameException( - "Service name (%s) must contain only these characters: " - "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')") - ) - else: - service_name = '' - - if remaining and remaining[-1] == '_sub': - remaining.pop() - if len(remaining) == 0 or len(remaining[0]) == 0: - raise BadTypeInNameException("_sub requires a subtype name") - - if len(remaining) > 1: - remaining = ['.'.join(remaining)] - - if remaining: - length = len(remaining[0].encode('utf-8')) - if length > 63: - raise BadTypeInNameException("Too long: '%s'" % remaining[0]) - - if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]): - raise BadTypeInNameException( - "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] - ) - - return service_name + trailer - - def instance_name_from_service_info(info: "ServiceInfo") -> str: """Calculate the instance name from the ServiceInfo.""" # This is kind of funky because of the subtype based tests diff --git a/zeroconf/utils/name.py b/zeroconf/utils/name.py new file mode 100644 index 00000000..65713eb0 --- /dev/null +++ b/zeroconf/utils/name.py @@ -0,0 +1,153 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from ..const import ( + _HAS_ASCII_CONTROL_CHARS, + _HAS_A_TO_Z, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN, + _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE, + _LOCAL_TRAILER, + _NONTCP_PROTOCOL_LOCAL_TRAILER, + _TCP_PROTOCOL_LOCAL_TRAILER, +) +from ..exceptions import BadTypeInNameException + + +def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches + """ + Validate a fully qualified service name, instance or subtype. [rfc6763] + + Returns fully qualified service name. + + Domain names used by mDNS-SD take the following forms: + + . <_tcp|_udp> . local. + . . <_tcp|_udp> . local. + ._sub . . <_tcp|_udp> . local. + + 1) must end with 'local.' + + This is true because we are implementing mDNS and since the 'm' means + multi-cast, the 'local.' domain is mandatory. + + 2) local is preceded with either '_udp.' or '_tcp.' unless + strict is False + + 3) service name precedes <_tcp|_udp> unless + strict is False + + The rules for Service Names [RFC6335] state that they may be no more + than fifteen characters long (not counting the mandatory underscore), + consisting of only letters, digits, and hyphens, must begin and end + with a letter or digit, must not contain consecutive hyphens, and + must contain at least one letter. + + The instance name and sub type may be up to 63 bytes. + + The portion of the Service Instance Name is a user- + friendly name consisting of arbitrary Net-Unicode text [RFC5198]. It + MUST NOT contain ASCII control characters (byte values 0x00-0x1F and + 0x7F) [RFC20] but otherwise is allowed to contain any characters, + without restriction, including spaces, uppercase, lowercase, + punctuation -- including dots -- accented characters, non-Roman text, + and anything else that may be represented using Net-Unicode. + + :param type_: Type, SubType or service name to validate + :return: fully qualified service name (eg: _http._tcp.local.) + """ + + if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): + remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_TCP_PROTOCOL_LOCAL_TRAILER) :] + has_protocol = True + elif strict: + raise BadTypeInNameException( + "Type '%s' must end with '%s' or '%s'" + % (type_, _TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER) + ) + elif type_.endswith(_LOCAL_TRAILER): + remaining = type_[: -len(_LOCAL_TRAILER)].split('.') + trailer = type_[-len(_LOCAL_TRAILER) + 1 :] + has_protocol = False + else: + raise BadTypeInNameException("Type '%s' must end with '%s'" % (type_, _LOCAL_TRAILER)) + + if strict or has_protocol: + service_name = remaining.pop() + if not service_name: + raise BadTypeInNameException("No Service name found") + + if len(remaining) == 1 and len(remaining[0]) == 0: + raise BadTypeInNameException("Type '%s' must not start with '.'" % type_) + + if service_name[0] != '_': + raise BadTypeInNameException("Service name (%s) must start with '_'" % service_name) + + test_service_name = service_name[1:] + + if len(test_service_name) > 15: + raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) + + if '--' in test_service_name: + raise BadTypeInNameException("Service name (%s) must not contain '--'" % test_service_name) + + if '-' in (test_service_name[0], test_service_name[-1]): + raise BadTypeInNameException( + "Service name (%s) may not start or end with '-'" % test_service_name + ) + + if not _HAS_A_TO_Z.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain at least one letter (eg: 'A-Z')" % test_service_name + ) + + allowed_characters_re = ( + _HAS_ONLY_A_TO_Z_NUM_HYPHEN if strict else _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE + ) + + if not allowed_characters_re.search(test_service_name): + raise BadTypeInNameException( + "Service name (%s) must contain only these characters: " + "A-Z, a-z, 0-9, hyphen ('-')%s" % (test_service_name, "" if strict else ", underscore ('_')") + ) + else: + service_name = '' + + if remaining and remaining[-1] == '_sub': + remaining.pop() + if len(remaining) == 0 or len(remaining[0]) == 0: + raise BadTypeInNameException("_sub requires a subtype name") + + if len(remaining) > 1: + remaining = ['.'.join(remaining)] + + if remaining: + length = len(remaining[0].encode('utf-8')) + if length > 63: + raise BadTypeInNameException("Too long: '%s'" % remaining[0]) + + if _HAS_ASCII_CONTROL_CHARS.search(remaining[0]): + raise BadTypeInNameException( + "Ascii control character 0x00-0x1F and 0x7F illegal in '%s'" % remaining[0] + ) + + return service_name + trailer From bdea21c0a61b6d9d0af3810f18dbc2fc2364c484 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 14:36:56 -1000 Subject: [PATCH 237/608] Breakout service classes into zeroconf.services (#544) --- tests/test_init.py | 221 ------------ tests/test_services.py | 266 ++++++++++++++ zeroconf/__init__.py | 720 +------------------------------------- zeroconf/services.py | 771 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1047 insertions(+), 931 deletions(-) create mode 100644 tests/test_services.py create mode 100644 zeroconf/services.py diff --git a/tests/test_init.py b/tests/test_init.py index f0092ea4..a740cab4 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -2210,167 +2210,6 @@ def _mock_get_expiration_time(self, percent): zeroconf.close() -def test_backoff(): - got_query = Event() - - type_ = "_http._tcp.local." - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to monkey patch the zeroconf send to check query transmission - old_send = zeroconf_browser.send - - time_offset = 0.0 - start_time = time.time() * 1000 - initial_query_interval = r._BROWSER_TIME / 1000 - - def current_time_millis(): - """Current system time in milliseconds""" - return start_time + time_offset * 1000 - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - got_query.set() - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) - - # monkey patch the zeroconf current_time_millis - r.current_time_millis = current_time_millis - - # monkey patch the backoff limit to prevent test running forever - r._BROWSER_BACKOFF_LIMIT = 10 # seconds - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - try: - # Test that queries are sent at increasing intervals - sleep_count = 0 - next_query_interval = 0.0 - expected_query_time = 0.0 - while True: - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - if time_offset == expected_query_time: - assert got_query.is_set() - got_query.clear() - if next_query_interval == r._BROWSER_BACKOFF_LIMIT: - # Only need to test up to the point where we've seen a query - # after the backoff limit has been hit - break - elif next_query_interval == 0: - next_query_interval = initial_query_interval - expected_query_time = initial_query_interval - else: - next_query_interval = min(2 * next_query_interval, r._BROWSER_BACKOFF_LIMIT) - expected_query_time += next_query_interval - else: - assert not got_query.is_set() - time_offset += initial_query_interval - - finally: - browser.cancel() - zeroconf_browser.close() - - -def test_integration(): - service_added = Event() - service_removed = Event() - unexpected_ttl = Event() - got_query = Event() - - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - - def on_service_state_change(zeroconf, service_type, state_change, name): - if name == registration_name: - if state_change is ServiceStateChange.Added: - service_added.set() - elif state_change is ServiceStateChange.Removed: - service_removed.set() - - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to monkey patch the zeroconf send to check packet sizes - old_send = zeroconf_browser.send - - time_offset = 0.0 - - def current_time_millis(): - """Current system time in milliseconds""" - return time.time() * 1000 + time_offset * 1000 - - expected_ttl = r._DNS_HOST_TTL - - nbr_answers = 0 - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - pout = r.DNSIncoming(out.packet()) - nonlocal nbr_answers - for answer in pout.answers: - nbr_answers += 1 - if not answer.ttl > expected_ttl / 2: - unexpected_ttl.set() - - got_query.set() - old_send(out, addr=addr, port=port) - - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) - - # monkey patch the zeroconf current_time_millis - r.current_time_millis = current_time_millis - - # monkey patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - r._BROWSER_BACKOFF_LIMIT = int(expected_ttl / 4) - - service_added = Event() - service_removed = Event() - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - zeroconf_registrar.register_service(info) - - try: - service_added.wait(1) - assert service_added.is_set() - - # Test that we receive queries containing answers only if the remaining TTL - # is greater than half the original TTL - sleep_count = 0 - test_iterations = 50 - while nbr_answers < test_iterations: - # Increase simulated time shift by 1/4 of the TTL in seconds - time_offset += expected_ttl / 4 - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - got_query.clear() - # Prevent the test running indefinitely in an error condition - assert sleep_count < test_iterations * 4 - assert not unexpected_ttl.is_set() - - # Don't remove service, allow close() to cleanup - - finally: - zeroconf_registrar.close() - service_removed.wait(1) - assert service_removed.is_set() - browser.cancel() - zeroconf_browser.close() - - def test_multiple_addresses(): type_ = "_http._tcp.local." registration_name = "xxxyyy.%s" % type_ @@ -2748,66 +2587,6 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.close() -def test_legacy_record_update_listener(): - """Test a RecordUpdateListener that does not implement update_records.""" - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - with pytest.raises(RuntimeError): - r.RecordUpdateListener().update_record( - zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - ) - - updates = [] - - class LegacyRecordUpdateListener(r.RecordUpdateListener): - """A RecordUpdateListener that does not implement update_records.""" - - def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: - nonlocal updates - updates.append(record) - - listener = LegacyRecordUpdateListener() - - zc.add_listener(listener, None) - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - # start a browser - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - - zc.register_service(info_service) - - zc.wait(1) - - browser.cancel() - - assert len(updates) - assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 - - zc.remove_listener(listener) - # Removing a second time should not throw - zc.remove_listener(listener) - - zc.close() - - def test_autodetect_ip_version(): """Tests for auto detecting IPVersion based on interface ips.""" assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only diff --git a/tests/test_services.py b/tests/test_services.py new file mode 100644 index 00000000..d931d5c0 --- /dev/null +++ b/tests/test_services.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf.services. """ + +import logging +import socket +import threading +import time +from threading import Event + +import pytest + +import zeroconf as r +import zeroconf.services as s +from zeroconf import ( + ServiceBrowser, + ServiceInfo, + ServiceStateChange, + Zeroconf, +) + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +def test_backoff(): + got_query = Event() + + type_ = "_http._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to monkey patch the zeroconf send to check query transmission + old_send = zeroconf_browser.send + + time_offset = 0.0 + start_time = time.time() * 1000 + initial_query_interval = s._BROWSER_TIME / 1000 + + def current_time_millis(): + """Current system time in milliseconds""" + return start_time + time_offset * 1000 + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + got_query.set() + old_send(out, addr=addr, port=port) + + # monkey patch the zeroconf send + setattr(zeroconf_browser, "send", send) + + # monkey patch the zeroconf current_time_millis + s.current_time_millis = current_time_millis + + # monkey patch the backoff limit to prevent test running forever + s._BROWSER_BACKOFF_LIMIT = 10 # seconds + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + try: + # Test that queries are sent at increasing intervals + sleep_count = 0 + next_query_interval = 0.0 + expected_query_time = 0.0 + while True: + zeroconf_browser.notify_all() + sleep_count += 1 + got_query.wait(0.1) + if time_offset == expected_query_time: + assert got_query.is_set() + got_query.clear() + if next_query_interval == s._BROWSER_BACKOFF_LIMIT: + # Only need to test up to the point where we've seen a query + # after the backoff limit has been hit + break + elif next_query_interval == 0: + next_query_interval = initial_query_interval + expected_query_time = initial_query_interval + else: + next_query_interval = min(2 * next_query_interval, s._BROWSER_BACKOFF_LIMIT) + expected_query_time += next_query_interval + else: + assert not got_query.is_set() + time_offset += initial_query_interval + + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_integration(): + service_added = Event() + service_removed = Event() + unexpected_ttl = Event() + got_query = Event() + + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + def on_service_state_change(zeroconf, service_type, state_change, name): + if name == registration_name: + if state_change is ServiceStateChange.Added: + service_added.set() + elif state_change is ServiceStateChange.Removed: + service_removed.set() + + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to monkey patch the zeroconf send to check packet sizes + old_send = zeroconf_browser.send + + time_offset = 0.0 + + def current_time_millis(): + """Current system time in milliseconds""" + return time.time() * 1000 + time_offset * 1000 + + expected_ttl = r._DNS_HOST_TTL + + nbr_answers = 0 + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + pout = r.DNSIncoming(out.packet()) + nonlocal nbr_answers + for answer in pout.answers: + nbr_answers += 1 + if not answer.ttl > expected_ttl / 2: + unexpected_ttl.set() + + got_query.set() + old_send(out, addr=addr, port=port) + + # monkey patch the zeroconf send + setattr(zeroconf_browser, "send", send) + + # monkey patch the zeroconf current_time_millis + s.current_time_millis = current_time_millis + + # monkey patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL + s._BROWSER_BACKOFF_LIMIT = int(expected_ttl / 4) + + service_added = Event() + service_removed = Event() + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zeroconf_registrar.register_service(info) + + try: + service_added.wait(1) + assert service_added.is_set() + + # Test that we receive queries containing answers only if the remaining TTL + # is greater than half the original TTL + sleep_count = 0 + test_iterations = 50 + while nbr_answers < test_iterations: + # Increase simulated time shift by 1/4 of the TTL in seconds + time_offset += expected_ttl / 4 + zeroconf_browser.notify_all() + sleep_count += 1 + got_query.wait(0.1) + got_query.clear() + # Prevent the test running indefinitely in an error condition + assert sleep_count < test_iterations * 4 + assert not unexpected_ttl.is_set() + + # Don't remove service, allow close() to cleanup + + finally: + zeroconf_registrar.close() + service_removed.wait(1) + assert service_removed.is_set() + browser.cancel() + zeroconf_browser.close() + + +def test_legacy_record_update_listener(): + """Test a RecordUpdateListener that does not implement update_records.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(RuntimeError): + r.RecordUpdateListener().update_record( + zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + ) + + updates = [] + + class LegacyRecordUpdateListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: + nonlocal updates + updates.append(record) + + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + zc.wait(1) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + + zc.close() diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 6f5d9c9c..8faba304 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -28,11 +28,9 @@ import sys import threading import time -import warnings -from collections import OrderedDict from types import TracebackType # noqa # used in type hints from typing import Dict, List, Optional, Type, Union, cast -from typing import Any, Callable, Set, Tuple # noqa # used in type hints +from typing import Set, Tuple # noqa # used in type hints from .const import ( # noqa # import needed for backwards compat _BROWSER_BACKOFF_LIMIT, @@ -108,6 +106,14 @@ ServiceNameAlreadyRegistered, ) from .logger import QuietLogger, log +from .services import ( # noqa # import needed for backwards compat + Signal, + SignalRegistrationInterface, + RecordUpdateListener, + _ServiceBrowserBase, + ServiceBrowser, + ServiceInfo, +) from .utils.name import service_type_name from .utils.net import ( # noqa # import needed for backwards compat add_multicast_member, @@ -123,7 +129,7 @@ _encode_address, get_all_addresses, ) -from .utils.struct import int2byte +from .utils.struct import int2byte # noqa # import needed for backwards compat from .utils.time import current_time_millis, millis_to_seconds __author__ = 'Paul Scott-Murphy, William McBrine' @@ -317,66 +323,6 @@ def handle_read(self, socket_: socket.socket) -> None: self.zc.handle_response(msg) -class Signal: - def __init__(self) -> None: - self._handlers = [] # type: List[Callable[..., None]] - - def fire(self, **kwargs: Any) -> None: - for h in list(self._handlers): - h(**kwargs) - - @property - def registration_interface(self) -> 'SignalRegistrationInterface': - return SignalRegistrationInterface(self._handlers) - - -class SignalRegistrationInterface: - def __init__(self, handlers: List[Callable[..., None]]) -> None: - self._handlers = handlers - - def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': - self._handlers.append(handler) - return self - - def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': - self._handlers.remove(handler) - return self - - -class RecordUpdateListener: - def update_record( # pylint: disable=no-self-use - self, zc: 'Zeroconf', now: float, record: DNSRecord - ) -> None: - """Update a single record. - - This method is deprecated and will be removed in a future version. - update_records should be implemented instead. - """ - raise RuntimeError("update_record is deprecated and will be removed in a future version.") - - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Update multiple records in one shot. - - All records that are received in a single packet are passed - to update_records. - - This implementation is a compatiblity shim to ensure older code - that uses RecordUpdateListener as a base class will continue to - get calls to update_record. This method will raise - NotImplementedError in a future version. - - At this point the cache will not have the new records - """ - for record in records: - self.update_record(zc, now, record) - - def update_records_complete(self) -> None: - """Called when a record update has completed for all handlers. - - At this point the cache will have the new records. - """ - - class ServiceListener: def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() @@ -396,652 +342,6 @@ def notify_all(self) -> None: raise NotImplementedError() -class _ServiceBrowserBase(RecordUpdateListener): - """Base class for ServiceBrowser.""" - - def __init__( - self, - zc: 'Zeroconf', - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - """Creates a browser for a specific type""" - assert handlers or listener, 'You need to specify at least one handler' - self.types = set(type_ if isinstance(type_, list) else [type_]) # type: Set[str] - for check_type_ in self.types: - if not check_type_.endswith(service_type_name(check_type_, strict=False)): - raise BadTypeInNameException - self.zc = zc - self.addr = addr - self.port = port - self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - self._services = { - check_type_: {} for check_type_ in self.types - } # type: Dict[str, Dict[str, DNSRecord]] - current_time = current_time_millis() - self._next_time = {check_type_: current_time for check_type_ in self.types} - self._delay = {check_type_: delay for check_type_ in self.types} - self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] - self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] - - self._service_state_changed = Signal() - - self.done = False - - if hasattr(handlers, 'add_service'): - listener = cast(ServiceListener, handlers) - handlers = None - - handlers = cast(List[Callable[..., None]], handlers or []) - - if listener: - - def on_change( - zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange - ) -> None: - assert listener is not None - args = (zeroconf, service_type, name) - if state_change is ServiceStateChange.Added: - listener.add_service(*args) - elif state_change is ServiceStateChange.Removed: - listener.remove_service(*args) - elif state_change is ServiceStateChange.Updated: - if hasattr(listener, 'update_service'): - listener.update_service(*args) - else: - warnings.warn( - "%r has no update_service method. Provide one (it can be empty if you " - "don't care about the updates), it'll become mandatory." % (listener,), - FutureWarning, - ) - else: - raise NotImplementedError(state_change) - - handlers.append(on_change) - - for h in handlers: - self.service_state_changed.register_handler(h) - - @property - def service_state_changed(self) -> SignalRegistrationInterface: - return self._service_state_changed.registration_interface - - def _record_matching_type(self, record: DNSRecord) -> Optional[str]: - """Return the type if the record matches one of the types we are browsing.""" - return next((type_ for type_ in self.types if record.name.endswith(type_)), None) - - def _enqueue_callback( - self, - state_change: ServiceStateChange, - type_: str, - name: str, - ) -> None: - # Code to ensure we only do a single update message - # Precedence is; Added, Remove, Update - key = (name, type_) - if ( - state_change is ServiceStateChange.Added - or ( - state_change is ServiceStateChange.Removed - and self._pending_handlers.get(key) != ServiceStateChange.Added - ) - or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) - ): - self._pending_handlers[key] = state_change - - def _process_record_update( - self, - zc: 'Zeroconf', - now: float, - record: DNSRecord, - ) -> None: - """Process a single record update from a batch of updates.""" - expired = record.is_expired(now) - - if isinstance(record, DNSPointer): - if record.name not in self.types: - return - service_key = record.alias.lower() - services_by_type = self._services[record.name] - old_record = services_by_type.get(service_key) - if old_record is None: - services_by_type[service_key] = record - self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) - elif expired: - del services_by_type[service_key] - self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) - else: - old_record.reset_ttl(record) - expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) - if expires < self._next_time[record.name]: - self._next_time[record.name] = expires - return - - # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.get(record): - return - - if isinstance(record, DNSAddress): - # Only trigger an updated event if the address is new - if record.address in set( - service.address - for service in zc.cache.entries_with_name(record.name) - if isinstance(service, DNSAddress) - ): - return - - # Iterate through the DNSCache and callback any services that use this address - for service in self.zc.cache.entries_with_server(record.name): - type_ = self._record_matching_type(service) - if type_: - self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) - break - - return - - type_ = self._record_matching_type(record) - if type_: - self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) - - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Callback invoked by Zeroconf when new information arrives. - - Updates information required by browser in the Zeroconf cache. - - Ensures that there is are no unecessary duplicates in the list. - """ - for record in records: - self._process_record_update(zc, now, record) - - def update_records_complete(self) -> None: - """Called when a record update has completed for all handlers. - - At this point the cache will have the new records. - """ - self._handlers_to_call.update(self._pending_handlers) - self._pending_handlers.clear() - - def cancel(self) -> None: - """Cancel the browser.""" - self.done = True - self.zc.remove_listener(self) - - def run(self) -> None: - """Run the browser.""" - questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] - self.zc.add_listener(self, questions) - - def generate_ready_queries(self) -> Optional[DNSOutgoing]: - """Generate the service browser query for any type that is due.""" - out = None - now = current_time_millis() - - if min(self._next_time.values()) > now: - return out - - for type_, due in self._next_time.items(): - if due > now: - continue - - if out is None: - out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) - out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) - - for record in self._services[type_].values(): - if not record.is_stale(now): - out.add_answer_at_time(record, now) - - self._next_time[type_] = now + self._delay[type_] - self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - return out - - def _seconds_to_wait(self) -> Optional[float]: - """Returns the number of seconds to wait for the next event.""" - # If there are handlers to call - # we want to process them right away - if self._handlers_to_call: - return None - - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - now = current_time_millis() - - if next_time <= now: - return None - - return millis_to_seconds(next_time - now) - - -class ServiceBrowser(_ServiceBrowserBase, threading.Thread): - """Used to browse for a service of a specific type. - - The listener object will have its add_service() and - remove_service() methods called when this browser - discovers changes in the services availability.""" - - def __init__( - self, - zc: 'Zeroconf', - type_: Union[str, list], - handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[ServiceListener] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - threading.Thread.__init__(self) - super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) - self.daemon = True - self.start() - self.name = "zeroconf-ServiceBrowser-%s-%s" % ( - '-'.join([type_[:-7] for type_ in self.types]), - getattr(self, 'native_id', self.ident), - ) - - def cancel(self) -> None: - """Cancel the browser.""" - super().cancel() - self.join() - - def run(self) -> None: - """Run the browser thread.""" - super().run() - while True: - timeout = self._seconds_to_wait() - if timeout: - with self.zc.condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - self.zc.condition.wait(timeout) - - if self.zc.done or self.done: - return - - out = self.generate_ready_queries() - if out: - self.zc.send(out, addr=self.addr, port=self.port) - - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) - - -class ServiceInfo(RecordUpdateListener): - """Service information. - - Constructor parameters are as follows: - - * `type_`: fully qualified service type name - * `name`: fully qualified service name - * `port`: port that the service runs on - * `weight`: weight of the service - * `priority`: priority of the service - * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). - converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to - value-less attributes. - * `server`: fully qualified name for service host (defaults to name) - * `host_ttl`: ttl used for A/SRV records - * `other_ttl`: ttl used for PTR/TXT records - * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, - or in parsed form as text; at most one of those parameters can be provided) - - """ - - text = b'' - - def __init__( - self, - type_: str, - name: str, - port: Optional[int] = None, - weight: int = 0, - priority: int = 0, - properties: Union[bytes, Dict] = b'', - server: Optional[str] = None, - host_ttl: int = _DNS_HOST_TTL, - other_ttl: int = _DNS_OTHER_TTL, - *, - addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None - ) -> None: - # Accept both none, or one, but not both. - if addresses is not None and parsed_addresses is not None: - raise TypeError("addresses and parsed_addresses cannot be provided together") - if not type_.endswith(service_type_name(name, strict=False)): - raise BadTypeInNameException - self.type = type_ - self.name = name - self.key = name.lower() - if addresses is not None: - self._addresses = addresses - elif parsed_addresses is not None: - self._addresses = [_encode_address(a) for a in parsed_addresses] - else: - self._addresses = [] - # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] - if invalid: - raise TypeError( - 'Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid - ) - self.port = port - self.weight = weight - self.priority = priority - if server: - self.server = server - else: - self.server = name - self.server_key = self.server.lower() - self._properties = {} # type: Dict - self._set_properties(properties) - self.host_ttl = host_ttl - self.other_ttl = other_ttl - - @property - def addresses(self) -> List[bytes]: - """IPv4 addresses of this service. - - Only IPv4 addresses are returned for backward compatibility. - Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to - include IPv6 addresses as well. - """ - return self.addresses_by_version(IPVersion.V4Only) - - @addresses.setter - def addresses(self, value: List[bytes]) -> None: - """Replace the addresses list. - - This replaces all currently stored addresses, both IPv4 and IPv6. - """ - self._addresses = value - - @property - def properties(self) -> Dict: - """If properties were set in the constructor this property returns the original dictionary - of type `Dict[Union[bytes, str], Any]`. - - If properties are coming from the network, after decoding a TXT record, the keys are always - bytes and the values are either bytes, if there was a value, even empty, or `None`, if there - was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. - """ - return self._properties - - def addresses_by_version(self, version: IPVersion) -> List[bytes]: - """List addresses matching IP version.""" - if version == IPVersion.V4Only: - return [addr for addr in self._addresses if not _is_v6_address(addr)] - if version == IPVersion.V6Only: - return list(filter(_is_v6_address, self._addresses)) - return self._addresses - - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: - """List addresses in their parsed string form.""" - result = self.addresses_by_version(version) - return [ - socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) - for addr in result - ] - - def _set_properties(self, properties: Union[bytes, Dict]) -> None: - """Sets properties and text of this info from a dictionary""" - if isinstance(properties, dict): - self._properties = properties - list_ = [] - result = b'' - for key, value in properties.items(): - if isinstance(key, str): - key = key.encode('utf-8') - - record = key - if value is not None: - if not isinstance(value, bytes): - value = str(value).encode('utf-8') - record += b'=' + value - list_.append(record) - for item in list_: - result = b''.join((result, int2byte(len(item)), item)) - self.text = result - else: - self.text = properties - - def _set_text(self, text: bytes) -> None: - """Sets properties and text given a text field""" - self.text = text - result = {} # type: Dict - end = len(text) - index = 0 - strs = [] - while index < end: - length = text[index] - index += 1 - strs.append(text[index : index + length]) - index += length - - for s in strs: - parts = s.split(b'=', 1) - try: - key, value = parts # type: Tuple[bytes, Optional[bytes]] - except ValueError: - # No equals sign at all - key = s - value = None - - # Only update non-existent properties - if key and result.get(key) is None: - result[key] = value - - self._properties = result - - def get_name(self) -> str: - """Name accessor""" - return self.name[: len(self.name) - len(self.type) - 1] - - def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record. - - This method is deprecated and will be removed in a future version. - update_records should be implemented instead. - """ - if record is not None: - self.update_records(zc, now, [record]) - - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Updates service information from a DNS record.""" - update_addresses = False - for record in records: - if isinstance(record, DNSService): - update_addresses = True - self._process_record(record, now) - - # Only update addresses if the DNSService (.server) has changed - if not update_addresses: - return - - for record in self._get_address_records_from_cache(zc): - self._process_record(record, now) - - def _process_record(self, record: DNSRecord, now: float) -> None: - if record.is_expired(now): - return - - if isinstance(record, DNSAddress): - if record.key == self.server_key and record.address not in self._addresses: - self._addresses.append(record.address) - return - - if isinstance(record, DNSService): - if record.key != self.key: - return - self.name = record.name - self.server = record.server - self.server_key = record.server.lower() - self.port = record.port - self.weight = record.weight - self.priority = record.priority - return - - if isinstance(record, DNSText): - if record.key == self.key: - self._set_text(record.text) - - def dns_addresses(self, override_ttl: Optional[int] = None) -> List[DNSAddress]: - """Return matching DNSAddress from ServiceInfo.""" - return [ - DNSAddress( - self.server, - _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - address, - ) - for address in self._addresses - ] - - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: - """Return DNSPointer from ServiceInfo.""" - return DNSPointer( - self.type, - _TYPE_PTR, - _CLASS_IN, - override_ttl if override_ttl is not None else self.other_ttl, - self.name, - ) - - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: - """Return DNSService from ServiceInfo.""" - return DNSService( - self.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - self.priority, - self.weight, - cast(int, self.port), - self.server, - ) - - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: - """Return DNSText from ServiceInfo.""" - return DNSText( - self.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.other_ttl, - self.text, - ) - - def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: - """Get the address records from the cache.""" - address_records = [] - cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) - if cached_a_record: - address_records.append(cached_a_record) - address_records.extend(zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) - return address_records - - def load_from_cache(self, zc: 'Zeroconf') -> bool: - """Populate the service info from the cache.""" - now = current_time_millis() - record_updates = [] - cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) - if cached_srv_record: - # If there is a srv record, A and AAAA will already - # be called and we do not want to do it twice - record_updates.append(cached_srv_record) - else: - record_updates.extend(self._get_address_records_from_cache(zc)) - cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) - if cached_txt_record: - record_updates.append(cached_txt_record) - self.update_records(zc, now, record_updates) - return self._is_complete - - @property - def _is_complete(self) -> bool: - """The ServiceInfo has all expected properties.""" - return not (self.text is None or not self._addresses) - - def request(self, zc: 'Zeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - if self.load_from_cache(zc): - return True - - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - try: - # Do not set a question on the listener to preload from cache - # since we just checked it above in load_from_cache - zc.add_listener(self, None) - while not self._is_complete: - if last <= now: - return False - if next_ <= now: - out = self.generate_request_query(zc, now) - if not out.questions: - return True - zc.send(out) - next_ = now + delay - delay *= 2 - - zc.wait(min(next_, last) - now) - now = current_time_millis() - finally: - zc.remove_listener(self) - - return True - - def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: - """Generate the request query.""" - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) - return out - - def __eq__(self, other: object) -> bool: - """Tests equality of service name""" - return isinstance(other, ServiceInfo) and other.name == self.name - - def __repr__(self) -> str: - """String representation""" - return '%s(%s)' % ( - type(self).__name__, - ', '.join( - '%s=%r' % (name, getattr(self, name)) - for name in ( - 'type', - 'name', - 'addresses', - 'port', - 'weight', - 'priority', - 'server', - 'properties', - ) - ), - ) - - class ZeroconfServiceTypes(ServiceListener): """ Return all of the advertised services on any local networks diff --git a/zeroconf/services.py b/zeroconf/services.py new file mode 100644 index 00000000..5c61741e --- /dev/null +++ b/zeroconf/services.py @@ -0,0 +1,771 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import socket +import threading +import warnings +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast + +from .const import ( + _BROWSER_BACKOFF_LIMIT, + _BROWSER_TIME, + _CLASS_IN, + _CLASS_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _EXPIRE_REFRESH_TIME_PERCENT, + _FLAGS_QR_QUERY, + _LISTENER_TIME, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _TYPE_A, + _TYPE_AAAA, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) +from .dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from .exceptions import BadTypeInNameException +from .utils.name import service_type_name +from .utils.net import ( + IPVersion, + ServiceStateChange, + _encode_address, + _is_v6_address, +) +from .utils.struct import int2byte +from .utils.time import current_time_millis, millis_to_seconds + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from . import ( # pylint: disable=cyclic-import + ServiceListener, + Zeroconf, + ) + + +class Signal: + def __init__(self) -> None: + self._handlers = [] # type: List[Callable[..., None]] + + def fire(self, **kwargs: Any) -> None: + for h in list(self._handlers): + h(**kwargs) + + @property + def registration_interface(self) -> 'SignalRegistrationInterface': + return SignalRegistrationInterface(self._handlers) + + +class SignalRegistrationInterface: + def __init__(self, handlers: List[Callable[..., None]]) -> None: + self._handlers = handlers + + def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.append(handler) + return self + + def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': + self._handlers.remove(handler) + return self + + +class RecordUpdateListener: + def update_record( # pylint: disable=no-self-use + self, zc: 'Zeroconf', now: float, record: DNSRecord + ) -> None: + """Update a single record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + raise RuntimeError("update_record is deprecated and will be removed in a future version.") + + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Update multiple records in one shot. + + All records that are received in a single packet are passed + to update_records. + + This implementation is a compatiblity shim to ensure older code + that uses RecordUpdateListener as a base class will continue to + get calls to update_record. This method will raise + NotImplementedError in a future version. + + At this point the cache will not have the new records + """ + for record in records: + self.update_record(zc, now, record) + + def update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + """ + + +class _ServiceBrowserBase(RecordUpdateListener): + """Base class for ServiceBrowser.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, + listener: Optional['ServiceListener'] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + """Creates a browser for a specific type""" + assert handlers or listener, 'You need to specify at least one handler' + self.types = set(type_ if isinstance(type_, list) else [type_]) # type: Set[str] + for check_type_ in self.types: + if not check_type_.endswith(service_type_name(check_type_, strict=False)): + raise BadTypeInNameException + self.zc = zc + self.addr = addr + self.port = port + self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) + self._services = { + check_type_: {} for check_type_ in self.types + } # type: Dict[str, Dict[str, DNSRecord]] + current_time = current_time_millis() + self._next_time = {check_type_: current_time for check_type_ in self.types} + self._delay = {check_type_: delay for check_type_ in self.types} + self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] + self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] + + self._service_state_changed = Signal() + + self.done = False + + if hasattr(handlers, 'add_service'): + listener = cast('ServiceListener', handlers) + handlers = None + + handlers = cast(List[Callable[..., None]], handlers or []) + + if listener: + + def on_change( + zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange + ) -> None: + assert listener is not None + args = (zeroconf, service_type, name) + if state_change is ServiceStateChange.Added: + listener.add_service(*args) + elif state_change is ServiceStateChange.Removed: + listener.remove_service(*args) + elif state_change is ServiceStateChange.Updated: + if hasattr(listener, 'update_service'): + listener.update_service(*args) + else: + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) + else: + raise NotImplementedError(state_change) + + handlers.append(on_change) + + for h in handlers: + self.service_state_changed.register_handler(h) + + @property + def service_state_changed(self) -> SignalRegistrationInterface: + return self._service_state_changed.registration_interface + + def _record_matching_type(self, record: DNSRecord) -> Optional[str]: + """Return the type if the record matches one of the types we are browsing.""" + return next((type_ for type_ in self.types if record.name.endswith(type_)), None) + + def _enqueue_callback( + self, + state_change: ServiceStateChange, + type_: str, + name: str, + ) -> None: + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + key = (name, type_) + if ( + state_change is ServiceStateChange.Added + or ( + state_change is ServiceStateChange.Removed + and self._pending_handlers.get(key) != ServiceStateChange.Added + ) + or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) + ): + self._pending_handlers[key] = state_change + + def _process_record_update( + self, + zc: 'Zeroconf', + now: float, + record: DNSRecord, + ) -> None: + """Process a single record update from a batch of updates.""" + expired = record.is_expired(now) + + if isinstance(record, DNSPointer): + if record.name not in self.types: + return + service_key = record.alias.lower() + services_by_type = self._services[record.name] + old_record = services_by_type.get(service_key) + if old_record is None: + services_by_type[service_key] = record + self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + elif expired: + del services_by_type[service_key] + self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) + else: + old_record.reset_ttl(record) + expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) + if expires < self._next_time[record.name]: + self._next_time[record.name] = expires + return + + # If its expired or already exists in the cache it cannot be updated. + if expired or self.zc.cache.get(record): + return + + if isinstance(record, DNSAddress): + # Only trigger an updated event if the address is new + if record.address in set( + service.address + for service in zc.cache.entries_with_name(record.name) + if isinstance(service, DNSAddress) + ): + return + + # Iterate through the DNSCache and callback any services that use this address + for service in self.zc.cache.entries_with_server(record.name): + type_ = self._record_matching_type(service) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) + break + + return + + type_ = self._record_matching_type(record) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) + + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Callback invoked by Zeroconf when new information arrives. + + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unecessary duplicates in the list. + """ + for record in records: + self._process_record_update(zc, now, record) + + def update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + """ + self._handlers_to_call.update(self._pending_handlers) + self._pending_handlers.clear() + + def cancel(self) -> None: + """Cancel the browser.""" + self.done = True + self.zc.remove_listener(self) + + def run(self) -> None: + """Run the browser.""" + questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] + self.zc.add_listener(self, questions) + + def generate_ready_queries(self) -> Optional[DNSOutgoing]: + """Generate the service browser query for any type that is due.""" + out = None + now = current_time_millis() + + if min(self._next_time.values()) > now: + return out + + for type_, due in self._next_time.items(): + if due > now: + continue + + if out is None: + out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) + out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) + + for record in self._services[type_].values(): + if not record.is_stale(now): + out.add_answer_at_time(record, now) + + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + return out + + def _seconds_to_wait(self) -> Optional[float]: + """Returns the number of seconds to wait for the next event.""" + # If there are handlers to call + # we want to process them right away + if self._handlers_to_call: + return None + + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + + if next_time <= now: + return None + + return millis_to_seconds(next_time - now) + + +class ServiceBrowser(_ServiceBrowserBase, threading.Thread): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, + listener: Optional['ServiceListener'] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + threading.Thread.__init__(self) + super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + self.daemon = True + self.start() + self.name = "zeroconf-ServiceBrowser-%s-%s" % ( + '-'.join([type_[:-7] for type_ in self.types]), + getattr(self, 'native_id', self.ident), + ) + + def cancel(self) -> None: + """Cancel the browser.""" + super().cancel() + self.join() + + def run(self) -> None: + """Run the browser thread.""" + super().run() + while True: + timeout = self._seconds_to_wait() + if timeout: + with self.zc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + self.zc.condition.wait(timeout) + + if self.zc.done or self.done: + return + + out = self.generate_ready_queries() + if out: + self.zc.send(out, addr=self.addr, port=self.port) + + if not self._handlers_to_call: + continue + + (name_type, state_change) = self._handlers_to_call.popitem(False) + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + +class ServiceInfo(RecordUpdateListener): + """Service information. + + Constructor parameters are as follows: + + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) + + """ + + text = b'' + + def __init__( + self, + type_: str, + name: str, + port: Optional[int] = None, + weight: int = 0, + priority: int = 0, + properties: Union[bytes, Dict] = b'', + server: Optional[str] = None, + host_ttl: int = _DNS_HOST_TTL, + other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None + ) -> None: + # Accept both none, or one, but not both. + if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") + if not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + self.type = type_ + self.name = name + self.key = name.lower() + if addresses is not None: + self._addresses = addresses + elif parsed_addresses is not None: + self._addresses = [_encode_address(a) for a in parsed_addresses] + else: + self._addresses = [] + # This results in an ugly error when registering, better check now + invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] + if invalid: + raise TypeError( + 'Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid + ) + self.port = port + self.weight = weight + self.priority = priority + if server: + self.server = server + else: + self.server = name + self.server_key = self.server.lower() + self._properties = {} # type: Dict + self._set_properties(properties) + self.host_ttl = host_ttl + self.other_ttl = other_ttl + + @property + def addresses(self) -> List[bytes]: + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value: List[bytes]) -> None: + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._addresses = value + + @property + def properties(self) -> Dict: + """If properties were set in the constructor this property returns the original dictionary + of type `Dict[Union[bytes, str], Any]`. + + If properties are coming from the network, after decoding a TXT record, the keys are always + bytes and the values are either bytes, if there was a value, even empty, or `None`, if there + was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. + """ + return self._properties + + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version.""" + if version == IPVersion.V4Only: + return [addr for addr in self._addresses if not _is_v6_address(addr)] + if version == IPVersion.V6Only: + return list(filter(_is_v6_address, self._addresses)) + return self._addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form.""" + result = self.addresses_by_version(version) + return [ + socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) + for addr in result + ] + + def _set_properties(self, properties: Union[bytes, Dict]) -> None: + """Sets properties and text of this info from a dictionary""" + if isinstance(properties, dict): + self._properties = properties + list_ = [] + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, int2byte(len(item)), item)) + self.text = result + else: + self.text = properties + + def _set_text(self, text: bytes) -> None: + """Sets properties and text given a text field""" + self.text = text + result = {} # type: Dict + end = len(text) + index = 0 + strs = [] + while index < end: + length = text[index] + index += 1 + strs.append(text[index : index + length]) + index += length + + for s in strs: + parts = s.split(b'=', 1) + try: + key, value = parts # type: Tuple[bytes, Optional[bytes]] + except ValueError: + # No equals sign at all + key = s + value = None + + # Only update non-existent properties + if key and result.get(key) is None: + result[key] = value + + self._properties = result + + def get_name(self) -> str: + """Name accessor""" + return self.name[: len(self.name) - len(self.type) - 1] + + def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + if record is not None: + self.update_records(zc, now, [record]) + + def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Updates service information from a DNS record.""" + update_addresses = False + for record in records: + if isinstance(record, DNSService): + update_addresses = True + self._process_record(record, now) + + # Only update addresses if the DNSService (.server) has changed + if not update_addresses: + return + + for record in self._get_address_records_from_cache(zc): + self._process_record(record, now) + + def _process_record(self, record: DNSRecord, now: float) -> None: + if record.is_expired(now): + return + + if isinstance(record, DNSAddress): + if record.key == self.server_key and record.address not in self._addresses: + self._addresses.append(record.address) + return + + if isinstance(record, DNSService): + if record.key != self.key: + return + self.name = record.name + self.server = record.server + self.server_key = record.server.lower() + self.port = record.port + self.weight = record.weight + self.priority = record.priority + return + + if isinstance(record, DNSText): + if record.key == self.key: + self._set_text(record.text) + + def dns_addresses(self, override_ttl: Optional[int] = None) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return [ + DNSAddress( + self.server, + _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + address, + ) + for address in self._addresses + ] + + def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self.name, + ) + + def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return DNSService( + self.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + cast(int, self.port), + self.server, + ) + + def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return DNSText( + self.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + ) + + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: + """Get the address records from the cache.""" + address_records = [] + cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) + if cached_a_record: + address_records.append(cached_a_record) + address_records.extend(zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) + return address_records + + def load_from_cache(self, zc: 'Zeroconf') -> bool: + """Populate the service info from the cache.""" + now = current_time_millis() + record_updates = [] + cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + # If there is a srv record, A and AAAA will already + # be called and we do not want to do it twice + record_updates.append(cached_srv_record) + else: + record_updates.extend(self._get_address_records_from_cache(zc)) + cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + record_updates.append(cached_txt_record) + self.update_records(zc, now, record_updates) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return not (self.text is None or not self._addresses) + + def request(self, zc: 'Zeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(zc): + return True + + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + # Do not set a question on the listener to preload from cache + # since we just checked it above in load_from_cache + zc.add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query(zc, now) + if not out.questions: + return True + zc.send(out) + next_ = now + delay + delay *= 2 + + zc.wait(min(next_, last) - now) + now = current_time_millis() + finally: + zc.remove_listener(self) + + return True + + def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) + return out + + def __eq__(self, other: object) -> bool: + """Tests equality of service name""" + return isinstance(other, ServiceInfo) and other.name == self.name + + def __repr__(self) -> str: + """String representation""" + return '%s(%s)' % ( + type(self).__name__, + ', '.join( + '%s=%r' % (name, getattr(self, name)) + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + ) + ), + ) From bf0e867ead1e48e05a27fe8db69900d9dc387ea2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 15:05:04 -1000 Subject: [PATCH 238/608] Relocate core functions into zeroconf.core (#547) --- tests/test_core.py | 63 ++++ tests/test_init.py | 23 -- zeroconf/__init__.py | 837 +---------------------------------------- zeroconf/aio.py | 11 +- zeroconf/core.py | 878 +++++++++++++++++++++++++++++++++++++++++++ zeroconf/services.py | 10 + 6 files changed, 960 insertions(+), 862 deletions(-) create mode 100644 tests/test_core.py create mode 100644 zeroconf/core.py diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 00000000..5535ab59 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf.core """ + +import itertools +import logging +import threading +import time +import unittest +import unittest.mock + + +import pytest +import zeroconf as r +from zeroconf import core + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestReaper(unittest.TestCase): + @unittest.mock.patch.object(core, "_CACHE_CLEANUP_INTERVAL", 10) + def test_reaper(self): + zeroconf = core.Zeroconf(interfaces=['127.0.0.1']) + cache = zeroconf.cache + original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + zeroconf.cache.add(record_with_10s_ttl) + zeroconf.cache.add(record_with_1s_ttl) + entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + time.sleep(1) + with zeroconf.engine.condition: + zeroconf.engine._notify() + time.sleep(0.1) + entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + zeroconf.close() + assert entries != original_entries + assert entries_with_cache != original_entries + assert record_with_10s_ttl in entries + assert record_with_1s_ttl not in entries diff --git a/tests/test_init.py b/tests/test_init.py index a740cab4..99424a54 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1235,29 +1235,6 @@ def test_cache_empty_multiple_calls_does_not_throw(self): assert 'a' not in cache.cache -class TestReaper(unittest.TestCase): - @unittest.mock.patch.object(r, "_CACHE_CLEANUP_INTERVAL", 10) - def test_reaper(self): - zeroconf = Zeroconf(interfaces=['127.0.0.1']) - cache = zeroconf.cache - original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') - record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - zeroconf.cache.add(record_with_10s_ttl) - zeroconf.cache.add(record_with_1s_ttl) - entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - time.sleep(1) - with zeroconf.engine.condition: - zeroconf.engine._notify() - time.sleep(0.1) - entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - zeroconf.close() - assert entries != original_entries - assert entries_with_cache != original_entries - assert record_with_10s_ttl in entries - assert record_with_1s_ttl not in entries - - class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 8faba304..ba473415 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -20,16 +20,9 @@ USA """ -import errno -import itertools -import platform -import select -import socket import sys -import threading import time -from types import TracebackType # noqa # used in type hints -from typing import Dict, List, Optional, Type, Union, cast +from typing import Optional, Union from typing import Set, Tuple # noqa # used in type hints from .const import ( # noqa # import needed for backwards compat @@ -83,6 +76,7 @@ _TYPE_TXT, _UNREGISTER_TIME, ) +from .core import NotifyListener, ServiceRegistry, Zeroconf # noqa # import needed for backwards compat from .dns import ( # noqa # import needed for backwards compat DNSAddress, DNSCache, @@ -105,8 +99,9 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) -from .logger import QuietLogger, log +from .logger import QuietLogger, log # noqa # import needed for backwards compat from .services import ( # noqa # import needed for backwards compat + instance_name_from_service_info, Signal, SignalRegistrationInterface, RecordUpdateListener, @@ -114,7 +109,7 @@ ServiceBrowser, ServiceInfo, ) -from .utils.name import service_type_name +from .utils.name import service_type_name # noqa # import needed for backwards compat from .utils.net import ( # noqa # import needed for backwards compat add_multicast_member, can_send_to, @@ -130,7 +125,7 @@ get_all_addresses, ) from .utils.struct import int2byte # noqa # import needed for backwards compat -from .utils.time import current_time_millis, millis_to_seconds +from .utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' @@ -160,169 +155,9 @@ ) -# utility functions - - -def instance_name_from_service_info(info: "ServiceInfo") -> str: - """Calculate the instance name from the ServiceInfo.""" - # This is kind of funky because of the subtype based tests - # need to make subtypes a first class citizen - service_name = service_type_name(info.name) - if not info.type.endswith(service_name): - raise BadTypeInNameException - return info.name[: -len(service_name) - 1] - - # implementation classes -class Engine(threading.Thread): - - """An engine wraps read access to sockets, allowing objects that - need to receive data from sockets to be called back when the - sockets are ready. - - A reader needs a handle_read() method, which is called when the socket - it is interested in is ready for reading. - - Writers are not implemented here, because we only send short - packets. - """ - - def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self) - self.daemon = True - self.zc = zc - self.readers = {} # type: Dict[socket.socket, Listener] - self.timeout = 5 - self.condition = threading.Condition() - self.socketpair = socket.socketpair() - self._last_cache_cleanup = 0.0 - self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) - - def run(self) -> None: - while not self.zc.done: - try: - rr, _wr, _er = select.select([*self.readers.keys(), self.socketpair[0]], [], [], self.timeout) - - if self.zc.done: - return - - for socket_ in rr: - reader = self.readers.get(socket_) - if reader: - reader.handle_read(socket_) - - if self.socketpair[0] in rr: - # Clear the socket's buffer - self.socketpair[0].recv(128) - - except (select.error, socket.error) as e: - # If the socket was closed by another thread, during - # shutdown, ignore it and exit - if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: - raise - - now = current_time_millis() - if now - self._last_cache_cleanup >= _CACHE_CLEANUP_INTERVAL: - self._last_cache_cleanup = now - self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) - self.zc.record_manager.updates_complete() - - self.socketpair[0].close() - self.socketpair[1].close() - - def _notify(self) -> None: - self.condition.notify() - try: - self.socketpair[1].send(b'x') - except socket.error: - # The socketpair may already be closed during shutdown, ignore it - if not self.zc.done: - raise - - def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None: - with self.condition: - self.readers[socket_] = reader - self._notify() - - def del_reader(self, socket_: socket.socket) -> None: - with self.condition: - del self.readers[socket_] - self._notify() - - -class Listener(QuietLogger): - - """A Listener is used by this module to listen on the multicast - group to which DNS messages are sent, allowing the implementation - to cache information as it arrives. - - It requires registration with an Engine object in order to have - the read() method called when a socket is available for reading.""" - - def __init__(self, zc: 'Zeroconf') -> None: - self.zc = zc - self.data = None # type: Optional[bytes] - - def handle_read(self, socket_: socket.socket) -> None: - try: - data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) - except Exception: # pylint: disable=broad-except - self.log_exception_warning('Error reading from socket %d', socket_.fileno()) - return - - if self.data == data: - log.debug( - 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', - addr, - port, - socket_.fileno(), - len(data), - data, - ) - return - - self.data = data - msg = DNSIncoming(data) - if msg.valid: - log.debug( - 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', - addr, - port, - socket_.fileno(), - msg, - len(data), - data, - ) - else: - log.debug( - 'Received from %r:%r (socket %d): (%d bytes) [%r]', - addr, - port, - socket_.fileno(), - len(data), - data, - ) - - if not msg.valid: - pass - - elif msg.is_query(): - # Always multicast responses - if port == _MDNS_PORT: - self.zc.handle_query(msg, None, _MDNS_PORT) - - # If it's not a multicast query, reply via unicast - # and multicast - elif port == _DNS_PORT: - self.zc.handle_query(msg, addr, port) - self.zc.handle_query(msg, None, _MDNS_PORT) - - else: - self.zc.handle_response(msg) - - class ServiceListener: def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() @@ -334,14 +169,6 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() -class NotifyListener: - """Receive notifications Zeroconf.notify_all is called.""" - - def notify_all(self) -> None: - """Called when Zeroconf.notify_all is called.""" - raise NotImplementedError() - - class ZeroconfServiceTypes(ServiceListener): """ Return all of the advertised services on any local networks @@ -393,655 +220,3 @@ def find( local_zc.close() return tuple(sorted(listener.found_services)) - - -class ServiceRegistry: - """A registry to keep track of services. - - This class exists to ensure services can - be safely added and removed with thread - safety. - """ - - def __init__( - self, - ) -> None: - """Create the ServiceRegistry class.""" - self.services = {} # type: Dict[str, ServiceInfo] - self.types = {} # type: Dict[str, List] - self.servers = {} # type: Dict[str, List] - self._lock = threading.Lock() # add and remove services thread safe - - def add(self, info: ServiceInfo) -> None: - """Add a new service to the registry.""" - - with self._lock: - self._add(info) - - def remove(self, info: ServiceInfo) -> None: - """Remove a new service from the registry.""" - - with self._lock: - self._remove(info) - - def update(self, info: ServiceInfo) -> None: - """Update new service in the registry.""" - - with self._lock: - self._remove(info) - self._add(info) - - def get_service_infos(self) -> List[ServiceInfo]: - """Return all ServiceInfo.""" - return list(self.services.values()) - - def get_info_name(self, name: str) -> Optional[ServiceInfo]: - """Return all ServiceInfo for the name.""" - return self.services.get(name) - - def get_types(self) -> List[str]: - """Return all types.""" - return list(self.types.keys()) - - def get_infos_type(self, type_: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching type.""" - return self._get_by_index("types", type_) - - def get_infos_server(self, server: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching server.""" - return self._get_by_index("servers", server) - - def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching the index.""" - service_infos = [] - - for name in getattr(self, attr).get(key, [])[:]: - info = self.services.get(name) - # Since we do not get under a lock since it would be - # a performance issue, its possible - # the service can be unregistered during the get - # so we must check if info is None - if info is not None: - service_infos.append(info) - - return service_infos - - def _add(self, info: ServiceInfo) -> None: - """Add a new service under the lock.""" - lower_name = info.name.lower() - if lower_name in self.services: - raise ServiceNameAlreadyRegistered - - self.services[lower_name] = info - self.types.setdefault(info.type, []).append(lower_name) - self.servers.setdefault(info.server, []).append(lower_name) - - def _remove(self, info: ServiceInfo) -> None: - """Remove a service under the lock.""" - lower_name = info.name.lower() - old_service_info = self.services[lower_name] - self.types[old_service_info.type].remove(lower_name) - self.servers[old_service_info.server].remove(lower_name) - del self.services[lower_name] - - -class QueryHandler: - """Query the ServiceRegistry.""" - - def __init__(self, registry: ServiceRegistry): - """Init the query handler.""" - self.registry = registry - - def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: - """Provide an answer to a service type enumeration query. - - https://datatracker.ietf.org/doc/html/rfc6763#section-9 - """ - for stype in self.registry.get_types(): - out.add_answer( - msg, - DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - - def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name.lower()): - out.add_answer(msg, service.dns_pointer()) - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer(service.dns_service()) - out.add_additional_answer(service.dns_text()) - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a query any query other then PTR. - - Add answer(s) for A, AAAA, SRV, or TXT queries. - """ - name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(name_to_find) # type: ignore - if service is None: - return - - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: - """Deal with incoming query packets. Provides a response if possible.""" - if unicast: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) - for question in msg.questions: - out.add_question(question) - else: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._answer_service_type_enumeration_query(msg, out) - else: - self._answer_ptr_query(msg, out, question) - continue - - self._answer_non_ptr_query(msg, out, question) - - if out is not None and out.answers: - out.id = msg.id - return out - - return None - - -class RecordManager: - """Process records into the cache and notify listeners.""" - - def __init__(self, zeroconf: 'Zeroconf') -> None: - """Init the record manager.""" - self.zc = zeroconf - self.cache = zeroconf.cache - self.listeners: List[RecordUpdateListener] = [] - - def updates(self, now: float, rec: List[DNSRecord]) -> None: - """Used to notify listeners of new information that has updated - a record. - - This method must be called before the cache is updated. - """ - for listener in self.listeners: - listener.update_records(self.zc, now, rec) - - def updates_complete(self) -> None: - """Used to notify listeners of new information that has updated - a record. - - This method must be called after the cache is updated. - """ - for listener in self.listeners: - listener.update_records_complete() - self.zc.notify_all() - - def updates_from_response(self, msg: DNSIncoming) -> None: - """Deal with incoming response packets. All answers - are held in the cache, and listeners are notified.""" - updates: List[DNSRecord] = [] - address_adds: List[DNSAddress] = [] - other_adds: List[DNSRecord] = [] - removes: List[DNSRecord] = [] - now = current_time_millis() - for record in msg.answers: - - updated = True - - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False - if record.created - entry.created > 1000 and entry not in msg.answers: - removes.append(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) - else: - if isinstance(record, DNSAddress): - address_adds.append(record) - else: - other_adds.append(record) - if updated: - updates.append(record) - elif maybe_entry is not None: - updates.append(record) - removes.append(record) - - if not updates and not address_adds and not other_adds and not removes: - return - - self.updates(now, updates) - # The cache adds must be processed AFTER we trigger - # the updates since we compare existing data - # with the new data and updating the cache - # ahead of update_record will cause listeners - # to miss changes - # - # We must process address adds before non-addresses - # otherwise a fetch of ServiceInfo may miss an address - # because it thinks the cache is complete - # - # The cache is processed under the context manager to ensure - # that any ServiceBrowser that is going to call - # zc.get_service_info will see the cached value - # but ONLY after all the record updates have been - # processsed. - self.cache.add_records(itertools.chain(address_adds, other_adds)) - # Removes are processed last since - # ServiceInfo could generate an un-needed query - # because the data was not yet populated. - self.cache.remove_records(removes) - self.updates_complete() - - def add_listener( - self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] - ) -> None: - """Adds a listener for a given question. The listener will have - its update_record method called when information is available to - answer the question(s).""" - self.listeners.append(listener) - - if question is not None: - now = current_time_millis() - records = [] - questions = [question] if isinstance(question, DNSQuestion) else question - for single_question in questions: - for record in self.cache.entries_with_name(single_question.name): - if single_question.answered_by(record) and not record.is_expired(now): - records.append(record) - if records: - listener.update_records(self.zc, now, records) - listener.update_records_complete() - - self.zc.notify_all() - - def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" - try: - self.listeners.remove(listener) - self.zc.notify_all() - except ValueError as e: - log.exception('Failed to remove listener: %r', e) - - -class Zeroconf(QuietLogger): - - """Implementation of Zeroconf Multicast DNS Service Discovery - - Supports registration, unregistration, queries and browsing. - """ - - def __init__( - self, - interfaces: InterfacesType = InterfaceChoice.All, - unicast: bool = False, - ip_version: Optional[IPVersion] = None, - apple_p2p: bool = False, - ) -> None: - """Creates an instance of the Zeroconf class, establishing - multicast communications, listening and reaping threads. - - :param interfaces: :class:`InterfaceChoice` or a list of IP addresses - (IPv4 and IPv6) and interface indexes (IPv6 only). - - IPv6 notes for non-POSIX systems: - * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` - on Python versions before 3.8. - - Also listening on loopback (``::1``) doesn't work, use a real address. - :param ip_version: IP versions to support. If `choice` is a list, the default is detected - from it. Otherwise defaults to V4 only for backward compatibility. - :param apple_p2p: use AWDL interface (only macOS) - """ - if ip_version is None: - ip_version = autodetect_ip_version(interfaces) - - # hook for threads - self._GLOBAL_DONE = False - self.unicast = unicast - - if apple_p2p and not platform.system() == 'Darwin': - raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') - - self._listen_socket, self._respond_sockets = create_sockets( - interfaces, unicast, ip_version, apple_p2p=apple_p2p - ) - log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) - self.multi_socket = unicast or interfaces is not InterfaceChoice.Default - - self._notify_listeners = [] # type: List[NotifyListener] - self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] - self.registry = ServiceRegistry() - self.query_handler = QueryHandler(self.registry) - self.cache = DNSCache() - self.record_manager = RecordManager(self) - - self.condition = threading.Condition() - - self.engine = Engine(self) - self.listener = Listener(self) - if not unicast: - self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) - if self.multi_socket: - for s in self._respond_sockets: - self.engine.add_reader(self.listener, s) - # Start the engine only after all - # the readers have been added to avoid - # missing any packets that are on the wire - self.engine.start() - - @property - def done(self) -> bool: - return self._GLOBAL_DONE - - @property - def listeners(self) -> List[RecordUpdateListener]: - return self.record_manager.listeners - - def wait(self, timeout: float) -> None: - """Calling thread waits for a given number of milliseconds or - until notified.""" - with self.condition: - self.condition.wait(millis_to_seconds(timeout)) - - def notify_all(self) -> None: - """Notifies all waiting threads""" - with self.condition: - self.condition.notify_all() - for listener in self._notify_listeners: - listener.notify_all() - - def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: - """Returns network's service information for a particular - name and type, or None if no service matches by the timeout, - which defaults to 3 seconds.""" - info = ServiceInfo(type_, name) - if info.request(self, timeout): - return info - return None - - def add_notify_listener(self, listener: NotifyListener) -> None: - """Adds a listener to receive notify_all events.""" - self._notify_listeners.append(listener) - - def remove_notify_listener(self, listener: NotifyListener) -> None: - """Removes a listener from the set that is currently listening.""" - self._notify_listeners.remove(listener) - - def add_service_listener(self, type_: str, listener: ServiceListener) -> None: - """Adds a listener for a particular service type. This object - will then have its add_service and remove_service methods called when - services of that type become available and unavailable.""" - self.remove_service_listener(listener) - self.browsers[listener] = ServiceBrowser(self, type_, listener) - - def remove_service_listener(self, listener: ServiceListener) -> None: - """Removes a listener from the set that is currently listening.""" - if listener in self.browsers: - self.browsers[listener].cancel() - del self.browsers[listener] - - def remove_all_service_listeners(self) -> None: - """Removes a listener from the set that is currently listening.""" - for listener in list(self.browsers): - self.remove_service_listener(listener) - - def register_service( - self, - info: ServiceInfo, - ttl: Optional[int] = None, - allow_name_change: bool = False, - cooperating_responders: bool = False, - ) -> None: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service. The name of the service may be changed if needed to make - it unique on the network. Additionally multiple cooperating responders - can register the same service on the network for resilience - (if you want this behavior set `cooperating_responders` to `True`).""" - if ttl is not None: - # ttl argument is used to maintain backward compatibility - # Setting TTLs via ServiceInfo is preferred - info.host_ttl = ttl - info.other_ttl = ttl - self.check_service(info, allow_name_change, cooperating_responders) - self.registry.add(info) - self._broadcast_service(info, _REGISTER_TIME, None) - - def update_service(self, info: ServiceInfo) -> None: - """Registers service information to the network with a default TTL. - Zeroconf will then respond to requests for information for that - service.""" - - self.registry.update(info) - self._broadcast_service(info, _REGISTER_TIME, None) - - def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: - """Send a broadcasts to announce a service at intervals.""" - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - - self.send_service_broadcast(info, ttl) - i += 1 - next_time += interval - - def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: - """Send a broadcast to announce a service.""" - self.send(self.generate_service_broadcast(info, ttl)) - - def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: - """Generate a broadcast to announce a service.""" - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - self._add_broadcast_answer(out, info, ttl) - return out - - def send_service_query(self, info: ServiceInfo) -> None: - """Send a query to lookup a service.""" - self.send(self.generate_service_query(info)) - - def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use - """Generate a query to lookup a service.""" - out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) - out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) - out.add_authorative_answer(info.dns_pointer()) - return out - - def _add_broadcast_answer( # pylint: disable=no-self-use - self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] - ) -> None: - """Add answers to broadcast a service.""" - other_ttl = info.other_ttl if override_ttl is None else override_ttl - host_ttl = info.host_ttl if override_ttl is None else override_ttl - out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) - out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) - out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl): - out.add_answer_at_time(dns_address, 0) - - def unregister_service(self, info: ServiceInfo) -> None: - """Unregister a service.""" - self.registry.remove(info) - self._broadcast_service(info, _UNREGISTER_TIME, 0) - - def unregister_all_services(self) -> None: - """Unregister all registered services.""" - service_infos = self.registry.get_service_infos() - if not service_infos: - return - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - for info in service_infos: - self._add_broadcast_answer(out, info, 0) - self.send(out) - i += 1 - next_time += _UNREGISTER_TIME - - def check_service( - self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False - ) -> None: - """Checks the network for a unique service name, modifying the - ServiceInfo passed in if it is not unique.""" - instance_name = instance_name_from_service_info(info) - if cooperating_responders: - return - next_instance_number = 2 - next_time = now = current_time_millis() - i = 0 - while i < 3: - # check for a name conflict - while self.cache.current_entry_with_name_and_alias(info.type, info.name): - if not allow_name_change: - raise NonUniqueNameException - - # change the name and look for a conflict - info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) - next_instance_number += 1 - service_type_name(info.name) - next_time = now - i = 0 - - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - - self.send_service_query(info) - i += 1 - next_time += _CHECK_TIME - - def add_listener( - self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] - ) -> None: - """Adds a listener for a given question. The listener will have - its update_record method called when information is available to - answer the question(s).""" - self.record_manager.add_listener(listener, question) - - def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" - self.record_manager.remove_listener(listener) - - def handle_response(self, msg: DNSIncoming) -> None: - """Deal with incoming response packets. All answers - are held in the cache, and listeners are notified.""" - self.record_manager.updates_from_response(msg) - - def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: - """Deal with incoming query packets. Provides a response if - possible.""" - out = self.query_handler.response(msg, port != _MDNS_PORT) - if out: - self.send(out, addr, port) - - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: - """Sends an outgoing packet.""" - packets = out.packets() - packet_num = 0 - for packet in packets: - packet_num += 1 - if len(packet) > _MAX_MSG_ABSOLUTE: - self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) - return - log.debug('Sending (%d bytes #%d) %r as %r...', len(packet), packet_num, out, packet) - for s in self._respond_sockets: - if self._GLOBAL_DONE: - return - try: - if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR - elif not can_send_to(s, addr): - continue - else: - real_addr = addr - bytes_sent = s.sendto(packet, 0, (real_addr, port)) - except OSError as exc: - if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: - # with IPv6 we don't have a reliable way to determine if an interface actually has - # IPV6 support, so we have to try and ignore errors. - continue - # on send errors, log the exception and keep going - self.log_exception_warning('Error sending through socket %d', s.fileno()) - except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions - # on send errors, log the exception and keep going - self.log_exception_warning('Error sending through socket %d', s.fileno()) - else: - if bytes_sent != len(packet): - self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) - - def close(self) -> None: - """Ends the background threads, and prevent this instance from - servicing further queries.""" - if self._GLOBAL_DONE: - return - # remove service listeners - self.remove_all_service_listeners() - self.unregister_all_services() - self._GLOBAL_DONE = True - - # shutdown recv socket and thread - if not self.unicast: - self.engine.del_reader(cast(socket.socket, self._listen_socket)) - cast(socket.socket, self._listen_socket).close() - if self.multi_socket: - for s in self._respond_sockets: - self.engine.del_reader(s) - self.engine.join() - # shutdown the rest - self.notify_all() - for s in self._respond_sockets: - s.close() - - def __enter__(self) -> 'Zeroconf': - return self - - def __exit__( # pylint: disable=useless-return - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - self.close() - return None diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 9a23be93..82e86199 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,16 +26,11 @@ from types import TracebackType # noqa # used in type hints from typing import Awaitable, Callable, Dict, List, Optional, Type, Union -from . import ( - DNSOutgoing, - NotifyListener, - ServiceInfo, - Zeroconf, - _ServiceBrowserBase, - instance_name_from_service_info, -) from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME +from .core import NotifyListener, Zeroconf +from .dns import DNSOutgoing from .exceptions import NonUniqueNameException +from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from .utils.aio import wait_condition_or_timeout from .utils.net import IPVersion, InterfaceChoice, InterfacesType from .utils.time import current_time_millis, millis_to_seconds diff --git a/zeroconf/core.py b/zeroconf/core.py new file mode 100644 index 00000000..8e5ab611 --- /dev/null +++ b/zeroconf/core.py @@ -0,0 +1,878 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import errno +import itertools +import platform +import select +import socket +import threading +from types import TracebackType # noqa # used in type hints +from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union, cast + +from .const import ( + _CACHE_CLEANUP_INTERVAL, + _CHECK_TIME, + _CLASS_IN, + _DNS_OTHER_TTL, + _DNS_PORT, + _FLAGS_AA, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _MAX_MSG_ABSOLUTE, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _REGISTER_TIME, + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_A, + _TYPE_ANY, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, + _UNREGISTER_TIME, +) +from .dns import DNSAddress, DNSCache, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord +from .exceptions import NonUniqueNameException, ServiceNameAlreadyRegistered +from .logger import QuietLogger, log +from .services import RecordUpdateListener, ServiceBrowser, ServiceInfo, instance_name_from_service_info +from .utils.name import service_type_name +from .utils.net import ( + IPVersion, + InterfaceChoice, + InterfacesType, + autodetect_ip_version, + can_send_to, + create_sockets, +) +from .utils.time import current_time_millis, millis_to_seconds + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from . import ServiceListener # pylint: disable=cyclic-import + + +class NotifyListener: + """Receive notifications Zeroconf.notify_all is called.""" + + def notify_all(self) -> None: + """Called when Zeroconf.notify_all is called.""" + raise NotImplementedError() + + +class Engine(threading.Thread): + + """An engine wraps read access to sockets, allowing objects that + need to receive data from sockets to be called back when the + sockets are ready. + + A reader needs a handle_read() method, which is called when the socket + it is interested in is ready for reading. + + Writers are not implemented here, because we only send short + packets. + """ + + def __init__(self, zc: 'Zeroconf') -> None: + threading.Thread.__init__(self) + self.daemon = True + self.zc = zc + self.readers = {} # type: Dict[socket.socket, Listener] + self.timeout = 5 + self.condition = threading.Condition() + self.socketpair = socket.socketpair() + self._last_cache_cleanup = 0.0 + self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) + + def run(self) -> None: + while not self.zc.done: + try: + rr, _wr, _er = select.select([*self.readers.keys(), self.socketpair[0]], [], [], self.timeout) + + if self.zc.done: + return + + for socket_ in rr: + reader = self.readers.get(socket_) + if reader: + reader.handle_read(socket_) + + if self.socketpair[0] in rr: + # Clear the socket's buffer + self.socketpair[0].recv(128) + + except (select.error, socket.error) as e: + # If the socket was closed by another thread, during + # shutdown, ignore it and exit + if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: + raise + + now = current_time_millis() + if now - self._last_cache_cleanup >= _CACHE_CLEANUP_INTERVAL: + self._last_cache_cleanup = now + self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) + self.zc.record_manager.updates_complete() + + self.socketpair[0].close() + self.socketpair[1].close() + + def _notify(self) -> None: + self.condition.notify() + try: + self.socketpair[1].send(b'x') + except socket.error: + # The socketpair may already be closed during shutdown, ignore it + if not self.zc.done: + raise + + def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None: + with self.condition: + self.readers[socket_] = reader + self._notify() + + def del_reader(self, socket_: socket.socket) -> None: + with self.condition: + del self.readers[socket_] + self._notify() + + +class Listener(QuietLogger): + + """A Listener is used by this module to listen on the multicast + group to which DNS messages are sent, allowing the implementation + to cache information as it arrives. + + It requires registration with an Engine object in order to have + the read() method called when a socket is available for reading.""" + + def __init__(self, zc: 'Zeroconf') -> None: + self.zc = zc + self.data = None # type: Optional[bytes] + + def handle_read(self, socket_: socket.socket) -> None: + try: + data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) + except Exception: # pylint: disable=broad-except + self.log_exception_warning('Error reading from socket %d', socket_.fileno()) + return + + if self.data == data: + log.debug( + 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', + addr, + port, + socket_.fileno(), + len(data), + data, + ) + return + + self.data = data + msg = DNSIncoming(data) + if msg.valid: + log.debug( + 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', + addr, + port, + socket_.fileno(), + msg, + len(data), + data, + ) + else: + log.debug( + 'Received from %r:%r (socket %d): (%d bytes) [%r]', + addr, + port, + socket_.fileno(), + len(data), + data, + ) + + if not msg.valid: + pass + + elif msg.is_query(): + # Always multicast responses + if port == _MDNS_PORT: + self.zc.handle_query(msg, None, _MDNS_PORT) + + # If it's not a multicast query, reply via unicast + # and multicast + elif port == _DNS_PORT: + self.zc.handle_query(msg, addr, port) + self.zc.handle_query(msg, None, _MDNS_PORT) + + else: + self.zc.handle_response(msg) + + +class ServiceRegistry: + """A registry to keep track of services. + + This class exists to ensure services can + be safely added and removed with thread + safety. + """ + + def __init__( + self, + ) -> None: + """Create the ServiceRegistry class.""" + self.services = {} # type: Dict[str, ServiceInfo] + self.types = {} # type: Dict[str, List] + self.servers = {} # type: Dict[str, List] + self._lock = threading.Lock() # add and remove services thread safe + + def add(self, info: ServiceInfo) -> None: + """Add a new service to the registry.""" + + with self._lock: + self._add(info) + + def remove(self, info: ServiceInfo) -> None: + """Remove a new service from the registry.""" + + with self._lock: + self._remove(info) + + def update(self, info: ServiceInfo) -> None: + """Update new service in the registry.""" + + with self._lock: + self._remove(info) + self._add(info) + + def get_service_infos(self) -> List[ServiceInfo]: + """Return all ServiceInfo.""" + return list(self.services.values()) + + def get_info_name(self, name: str) -> Optional[ServiceInfo]: + """Return all ServiceInfo for the name.""" + return self.services.get(name) + + def get_types(self) -> List[str]: + """Return all types.""" + return list(self.types.keys()) + + def get_infos_type(self, type_: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching type.""" + return self._get_by_index("types", type_) + + def get_infos_server(self, server: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching server.""" + return self._get_by_index("servers", server) + + def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching the index.""" + service_infos = [] + + for name in getattr(self, attr).get(key, [])[:]: + info = self.services.get(name) + # Since we do not get under a lock since it would be + # a performance issue, its possible + # the service can be unregistered during the get + # so we must check if info is None + if info is not None: + service_infos.append(info) + + return service_infos + + def _add(self, info: ServiceInfo) -> None: + """Add a new service under the lock.""" + lower_name = info.name.lower() + if lower_name in self.services: + raise ServiceNameAlreadyRegistered + + self.services[lower_name] = info + self.types.setdefault(info.type, []).append(lower_name) + self.servers.setdefault(info.server, []).append(lower_name) + + def _remove(self, info: ServiceInfo) -> None: + """Remove a service under the lock.""" + lower_name = info.name.lower() + old_service_info = self.services[lower_name] + self.types[old_service_info.type].remove(lower_name) + self.servers[old_service_info.server].remove(lower_name) + del self.services[lower_name] + + +class QueryHandler: + """Query the ServiceRegistry.""" + + def __init__(self, registry: ServiceRegistry): + """Init the query handler.""" + self.registry = registry + + def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in self.registry.get_types(): + out.add_answer( + msg, + DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, + ), + ) + + def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a PTR query.""" + for service in self.registry.get_infos_type(question.name.lower()): + out.add_answer(msg, service.dns_pointer()) + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + out.add_additional_answer(service.dns_service()) + out.add_additional_answer(service.dns_text()) + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a query any query other then PTR. + + Add answer(s) for A, AAAA, SRV, or TXT queries. + """ + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): + for dns_address in service.dns_addresses(): + out.add_answer(msg, dns_address) + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + return + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer(msg, service.dns_service()) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer(msg, service.dns_text()) + if question.type == _TYPE_SRV: + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: + """Deal with incoming query packets. Provides a response if possible.""" + if unicast: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) + for question in msg.questions: + out.add_question(question) + else: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + + for question in msg.questions: + if question.type == _TYPE_PTR: + if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._answer_service_type_enumeration_query(msg, out) + else: + self._answer_ptr_query(msg, out, question) + continue + + self._answer_non_ptr_query(msg, out, question) + + if out is not None and out.answers: + out.id = msg.id + return out + + return None + + +class RecordManager: + """Process records into the cache and notify listeners.""" + + def __init__(self, zeroconf: 'Zeroconf') -> None: + """Init the record manager.""" + self.zc = zeroconf + self.cache = zeroconf.cache + self.listeners: List[RecordUpdateListener] = [] + + def updates(self, now: float, rec: List[DNSRecord]) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called before the cache is updated. + """ + for listener in self.listeners: + listener.update_records(self.zc, now, rec) + + def updates_complete(self) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called after the cache is updated. + """ + for listener in self.listeners: + listener.update_records_complete() + self.zc.notify_all() + + def updates_from_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + updates: List[DNSRecord] = [] + address_adds: List[DNSAddress] = [] + other_adds: List[DNSRecord] = [] + removes: List[DNSRecord] = [] + now = current_time_millis() + for record in msg.answers: + + updated = True + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): + if entry == record: + updated = False + if record.created - entry.created > 1000 and entry not in msg.answers: + removes.append(entry) + + expired = record.is_expired(now) + maybe_entry = self.cache.get(record) + if not expired: + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + if isinstance(record, DNSAddress): + address_adds.append(record) + else: + other_adds.append(record) + if updated: + updates.append(record) + elif maybe_entry is not None: + updates.append(record) + removes.append(record) + + if not updates and not address_adds and not other_adds and not removes: + return + + self.updates(now, updates) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the context manager to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + self.cache.add_records(itertools.chain(address_adds, other_adds)) + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. + self.cache.remove_records(removes) + self.updates_complete() + + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s).""" + self.listeners.append(listener) + + if question is not None: + now = current_time_millis() + records = [] + questions = [question] if isinstance(question, DNSQuestion) else question + for single_question in questions: + for record in self.cache.entries_with_name(single_question.name): + if single_question.answered_by(record) and not record.is_expired(now): + records.append(record) + if records: + listener.update_records(self.zc, now, records) + listener.update_records_complete() + + self.zc.notify_all() + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener.""" + try: + self.listeners.remove(listener) + self.zc.notify_all() + except ValueError as e: + log.exception('Failed to remove listener: %r', e) + + +class Zeroconf(QuietLogger): + + """Implementation of Zeroconf Multicast DNS Service Discovery + + Supports registration, unregistration, queries and browsing. + """ + + def __init__( + self, + interfaces: InterfacesType = InterfaceChoice.All, + unicast: bool = False, + ip_version: Optional[IPVersion] = None, + apple_p2p: bool = False, + ) -> None: + """Creates an instance of the Zeroconf class, establishing + multicast communications, listening and reaping threads. + + :param interfaces: :class:`InterfaceChoice` or a list of IP addresses + (IPv4 and IPv6) and interface indexes (IPv6 only). + + IPv6 notes for non-POSIX systems: + * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` + on Python versions before 3.8. + + Also listening on loopback (``::1``) doesn't work, use a real address. + :param ip_version: IP versions to support. If `choice` is a list, the default is detected + from it. Otherwise defaults to V4 only for backward compatibility. + :param apple_p2p: use AWDL interface (only macOS) + """ + if ip_version is None: + ip_version = autodetect_ip_version(interfaces) + + # hook for threads + self._GLOBAL_DONE = False + self.unicast = unicast + + if apple_p2p and not platform.system() == 'Darwin': + raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') + + self._listen_socket, self._respond_sockets = create_sockets( + interfaces, unicast, ip_version, apple_p2p=apple_p2p + ) + log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) + self.multi_socket = unicast or interfaces is not InterfaceChoice.Default + + self._notify_listeners = [] # type: List[NotifyListener] + self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] + self.registry = ServiceRegistry() + self.query_handler = QueryHandler(self.registry) + self.cache = DNSCache() + self.record_manager = RecordManager(self) + + self.condition = threading.Condition() + + self.engine = Engine(self) + self.listener = Listener(self) + if not unicast: + self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) + if self.multi_socket: + for s in self._respond_sockets: + self.engine.add_reader(self.listener, s) + # Start the engine only after all + # the readers have been added to avoid + # missing any packets that are on the wire + self.engine.start() + + @property + def done(self) -> bool: + return self._GLOBAL_DONE + + @property + def listeners(self) -> List[RecordUpdateListener]: + return self.record_manager.listeners + + def wait(self, timeout: float) -> None: + """Calling thread waits for a given number of milliseconds or + until notified.""" + with self.condition: + self.condition.wait(millis_to_seconds(timeout)) + + def notify_all(self) -> None: + """Notifies all waiting threads""" + with self.condition: + self.condition.notify_all() + for listener in self._notify_listeners: + listener.notify_all() + + def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds.""" + info = ServiceInfo(type_, name) + if info.request(self, timeout): + return info + return None + + def add_notify_listener(self, listener: NotifyListener) -> None: + """Adds a listener to receive notify_all events.""" + self._notify_listeners.append(listener) + + def remove_notify_listener(self, listener: NotifyListener) -> None: + """Removes a listener from the set that is currently listening.""" + self._notify_listeners.remove(listener) + + def add_service_listener(self, type_: str, listener: 'ServiceListener') -> None: + """Adds a listener for a particular service type. This object + will then have its add_service and remove_service methods called when + services of that type become available and unavailable.""" + self.remove_service_listener(listener) + self.browsers[listener] = ServiceBrowser(self, type_, listener) + + def remove_service_listener(self, listener: 'ServiceListener') -> None: + """Removes a listener from the set that is currently listening.""" + if listener in self.browsers: + self.browsers[listener].cancel() + del self.browsers[listener] + + def remove_all_service_listeners(self) -> None: + """Removes a listener from the set that is currently listening.""" + for listener in list(self.browsers): + self.remove_service_listener(listener) + + def register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + if ttl is not None: + # ttl argument is used to maintain backward compatibility + # Setting TTLs via ServiceInfo is preferred + info.host_ttl = ttl + info.other_ttl = ttl + self.check_service(info, allow_name_change, cooperating_responders) + self.registry.add(info) + self._broadcast_service(info, _REGISTER_TIME, None) + + def update_service(self, info: ServiceInfo) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" + + self.registry.update(info) + self._broadcast_service(info, _REGISTER_TIME, None) + + def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + """Send a broadcasts to announce a service at intervals.""" + now = current_time_millis() + next_time = now + i = 0 + while i < 3: + if now < next_time: + self.wait(next_time - now) + now = current_time_millis() + continue + + self.send_service_broadcast(info, ttl) + i += 1 + next_time += interval + + def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: + """Send a broadcast to announce a service.""" + self.send(self.generate_service_broadcast(info, ttl)) + + def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: + """Generate a broadcast to announce a service.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + self._add_broadcast_answer(out, info, ttl) + return out + + def send_service_query(self, info: ServiceInfo) -> None: + """Send a query to lookup a service.""" + self.send(self.generate_service_query(info)) + + def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use + """Generate a query to lookup a service.""" + out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) + out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) + out.add_authorative_answer(info.dns_pointer()) + return out + + def _add_broadcast_answer( # pylint: disable=no-self-use + self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] + ) -> None: + """Add answers to broadcast a service.""" + other_ttl = info.other_ttl if override_ttl is None else override_ttl + host_ttl = info.host_ttl if override_ttl is None else override_ttl + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) + for dns_address in info.dns_addresses(override_ttl=host_ttl): + out.add_answer_at_time(dns_address, 0) + + def unregister_service(self, info: ServiceInfo) -> None: + """Unregister a service.""" + self.registry.remove(info) + self._broadcast_service(info, _UNREGISTER_TIME, 0) + + def unregister_all_services(self) -> None: + """Unregister all registered services.""" + service_infos = self.registry.get_service_infos() + if not service_infos: + return + now = current_time_millis() + next_time = now + i = 0 + while i < 3: + if now < next_time: + self.wait(next_time - now) + now = current_time_millis() + continue + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + for info in service_infos: + self._add_broadcast_answer(out, info, 0) + self.send(out) + i += 1 + next_time += _UNREGISTER_TIME + + def check_service( + self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False + ) -> None: + """Checks the network for a unique service name, modifying the + ServiceInfo passed in if it is not unique.""" + instance_name = instance_name_from_service_info(info) + if cooperating_responders: + return + next_instance_number = 2 + next_time = now = current_time_millis() + i = 0 + while i < 3: + # check for a name conflict + while self.cache.current_entry_with_name_and_alias(info.type, info.name): + if not allow_name_change: + raise NonUniqueNameException + + # change the name and look for a conflict + info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) + next_instance_number += 1 + service_type_name(info.name) + next_time = now + i = 0 + + if now < next_time: + self.wait(next_time - now) + now = current_time_millis() + continue + + self.send_service_query(info) + i += 1 + next_time += _CHECK_TIME + + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s).""" + self.record_manager.add_listener(listener, question) + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener.""" + self.record_manager.remove_listener(listener) + + def handle_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + self.record_manager.updates_from_response(msg) + + def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: + """Deal with incoming query packets. Provides a response if + possible.""" + out = self.query_handler.response(msg, port != _MDNS_PORT) + if out: + self.send(out, addr, port) + + def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + """Sends an outgoing packet.""" + packets = out.packets() + packet_num = 0 + for packet in packets: + packet_num += 1 + if len(packet) > _MAX_MSG_ABSOLUTE: + self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) + return + log.debug('Sending (%d bytes #%d) %r as %r...', len(packet), packet_num, out, packet) + for s in self._respond_sockets: + if self._GLOBAL_DONE: + return + try: + if addr is None: + real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR + elif not can_send_to(s, addr): + continue + else: + real_addr = addr + bytes_sent = s.sendto(packet, 0, (real_addr, port)) + except OSError as exc: + if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: + # with IPv6 we don't have a reliable way to determine if an interface actually has + # IPV6 support, so we have to try and ignore errors. + continue + # on send errors, log the exception and keep going + self.log_exception_warning('Error sending through socket %d', s.fileno()) + except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions + # on send errors, log the exception and keep going + self.log_exception_warning('Error sending through socket %d', s.fileno()) + else: + if bytes_sent != len(packet): + self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) + + def close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries.""" + if self._GLOBAL_DONE: + return + # remove service listeners + self.remove_all_service_listeners() + self.unregister_all_services() + self._GLOBAL_DONE = True + + # shutdown recv socket and thread + if not self.unicast: + self.engine.del_reader(cast(socket.socket, self._listen_socket)) + cast(socket.socket, self._listen_socket).close() + if self.multi_socket: + for s in self._respond_sockets: + self.engine.del_reader(s) + self.engine.join() + # shutdown the rest + self.notify_all() + for s in self._respond_sockets: + s.close() + + def __enter__(self) -> 'Zeroconf': + return self + + def __exit__( # pylint: disable=useless-return + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self.close() + return None diff --git a/zeroconf/services.py b/zeroconf/services.py index 5c61741e..6b47355f 100644 --- a/zeroconf/services.py +++ b/zeroconf/services.py @@ -65,6 +65,16 @@ ) +def instance_name_from_service_info(info: "ServiceInfo") -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + class Signal: def __init__(self) -> None: self._handlers = [] # type: List[Callable[..., None]] From c8a0a71c31252bbc4a242701bc786eb419e1a8e8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 15:28:08 -1000 Subject: [PATCH 239/608] Move ServiceStateChange to zeroconf.services (#548) --- zeroconf/__init__.py | 2 +- zeroconf/services.py | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ba473415..0cfd5029 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -108,6 +108,7 @@ _ServiceBrowserBase, ServiceBrowser, ServiceInfo, + ServiceStateChange, ) from .utils.name import service_type_name # noqa # import needed for backwards compat from .utils.net import ( # noqa # import needed for backwards compat @@ -118,7 +119,6 @@ get_all_addresses_v6, InterfaceChoice, InterfacesType, - ServiceStateChange, IPVersion, _is_v6_address, _encode_address, diff --git a/zeroconf/services.py b/zeroconf/services.py index 6b47355f..b0ad93d6 100644 --- a/zeroconf/services.py +++ b/zeroconf/services.py @@ -20,6 +20,7 @@ USA """ +import enum import socket import threading import warnings @@ -50,7 +51,6 @@ from .utils.name import service_type_name from .utils.net import ( IPVersion, - ServiceStateChange, _encode_address, _is_v6_address, ) @@ -65,6 +65,13 @@ ) +@enum.unique +class ServiceStateChange(enum.Enum): + Added = 1 + Removed = 2 + Updated = 3 + + def instance_name_from_service_info(info: "ServiceInfo") -> str: """Calculate the instance name from the ServiceInfo.""" # This is kind of funky because of the subtype based tests From 4086fb4304b0653153865306e46c865c90137922 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 15:35:49 -1000 Subject: [PATCH 240/608] Move the ServiceRegistry into its own module (#549) --- zeroconf/__init__.py | 3 +- zeroconf/core.py | 93 +------------- .../{services.py => services/__init__.py} | 16 +-- zeroconf/services/registry.py | 118 ++++++++++++++++++ 4 files changed, 130 insertions(+), 100 deletions(-) rename zeroconf/{services.py => services/__init__.py} (98%) create mode 100644 zeroconf/services/registry.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 0cfd5029..c4533bab 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -76,7 +76,7 @@ _TYPE_TXT, _UNREGISTER_TIME, ) -from .core import NotifyListener, ServiceRegistry, Zeroconf # noqa # import needed for backwards compat +from .core import NotifyListener, Zeroconf # noqa # import needed for backwards compat from .dns import ( # noqa # import needed for backwards compat DNSAddress, DNSCache, @@ -110,6 +110,7 @@ ServiceInfo, ServiceStateChange, ) +from .services.registry import ServiceRegistry # noqa # import needed for backwards compat from .utils.name import service_type_name # noqa # import needed for backwards compat from .utils.net import ( # noqa # import needed for backwards compat add_multicast_member, diff --git a/zeroconf/core.py b/zeroconf/core.py index 8e5ab611..2b53f4c6 100644 --- a/zeroconf/core.py +++ b/zeroconf/core.py @@ -52,9 +52,10 @@ _UNREGISTER_TIME, ) from .dns import DNSAddress, DNSCache, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord -from .exceptions import NonUniqueNameException, ServiceNameAlreadyRegistered +from .exceptions import NonUniqueNameException from .logger import QuietLogger, log from .services import RecordUpdateListener, ServiceBrowser, ServiceInfo, instance_name_from_service_info +from .services.registry import ServiceRegistry from .utils.name import service_type_name from .utils.net import ( IPVersion, @@ -226,96 +227,6 @@ def handle_read(self, socket_: socket.socket) -> None: self.zc.handle_response(msg) -class ServiceRegistry: - """A registry to keep track of services. - - This class exists to ensure services can - be safely added and removed with thread - safety. - """ - - def __init__( - self, - ) -> None: - """Create the ServiceRegistry class.""" - self.services = {} # type: Dict[str, ServiceInfo] - self.types = {} # type: Dict[str, List] - self.servers = {} # type: Dict[str, List] - self._lock = threading.Lock() # add and remove services thread safe - - def add(self, info: ServiceInfo) -> None: - """Add a new service to the registry.""" - - with self._lock: - self._add(info) - - def remove(self, info: ServiceInfo) -> None: - """Remove a new service from the registry.""" - - with self._lock: - self._remove(info) - - def update(self, info: ServiceInfo) -> None: - """Update new service in the registry.""" - - with self._lock: - self._remove(info) - self._add(info) - - def get_service_infos(self) -> List[ServiceInfo]: - """Return all ServiceInfo.""" - return list(self.services.values()) - - def get_info_name(self, name: str) -> Optional[ServiceInfo]: - """Return all ServiceInfo for the name.""" - return self.services.get(name) - - def get_types(self) -> List[str]: - """Return all types.""" - return list(self.types.keys()) - - def get_infos_type(self, type_: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching type.""" - return self._get_by_index("types", type_) - - def get_infos_server(self, server: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching server.""" - return self._get_by_index("servers", server) - - def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: - """Return all ServiceInfo matching the index.""" - service_infos = [] - - for name in getattr(self, attr).get(key, [])[:]: - info = self.services.get(name) - # Since we do not get under a lock since it would be - # a performance issue, its possible - # the service can be unregistered during the get - # so we must check if info is None - if info is not None: - service_infos.append(info) - - return service_infos - - def _add(self, info: ServiceInfo) -> None: - """Add a new service under the lock.""" - lower_name = info.name.lower() - if lower_name in self.services: - raise ServiceNameAlreadyRegistered - - self.services[lower_name] = info - self.types.setdefault(info.type, []).append(lower_name) - self.servers.setdefault(info.server, []).append(lower_name) - - def _remove(self, info: ServiceInfo) -> None: - """Remove a service under the lock.""" - lower_name = info.name.lower() - old_service_info = self.services[lower_name] - self.types[old_service_info.type].remove(lower_name) - self.servers[old_service_info.server].remove(lower_name) - del self.services[lower_name] - - class QueryHandler: """Query the ServiceRegistry.""" diff --git a/zeroconf/services.py b/zeroconf/services/__init__.py similarity index 98% rename from zeroconf/services.py rename to zeroconf/services/__init__.py index b0ad93d6..f9c5c87b 100644 --- a/zeroconf/services.py +++ b/zeroconf/services/__init__.py @@ -27,7 +27,7 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from .const import ( +from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, _CLASS_IN, @@ -46,20 +46,20 @@ _TYPE_SRV, _TYPE_TXT, ) -from .dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText -from .exceptions import BadTypeInNameException -from .utils.name import service_type_name -from .utils.net import ( +from ..dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from ..exceptions import BadTypeInNameException +from ..utils.name import service_type_name +from ..utils.net import ( IPVersion, _encode_address, _is_v6_address, ) -from .utils.struct import int2byte -from .utils.time import current_time_millis, millis_to_seconds +from ..utils.struct import int2byte +from ..utils.time import current_time_millis, millis_to_seconds if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from . import ( # pylint: disable=cyclic-import + from .. import ( # pylint: disable=cyclic-import ServiceListener, Zeroconf, ) diff --git a/zeroconf/services/registry.py b/zeroconf/services/registry.py new file mode 100644 index 00000000..19d4ba46 --- /dev/null +++ b/zeroconf/services/registry.py @@ -0,0 +1,118 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import threading +from typing import Dict, List, Optional + + +from ..exceptions import ServiceNameAlreadyRegistered +from ..services import ServiceInfo + + +class ServiceRegistry: + """A registry to keep track of services. + + This class exists to ensure services can + be safely added and removed with thread + safety. + """ + + def __init__( + self, + ) -> None: + """Create the ServiceRegistry class.""" + self.services = {} # type: Dict[str, ServiceInfo] + self.types = {} # type: Dict[str, List] + self.servers = {} # type: Dict[str, List] + self._lock = threading.Lock() # add and remove services thread safe + + def add(self, info: ServiceInfo) -> None: + """Add a new service to the registry.""" + + with self._lock: + self._add(info) + + def remove(self, info: ServiceInfo) -> None: + """Remove a new service from the registry.""" + + with self._lock: + self._remove(info) + + def update(self, info: ServiceInfo) -> None: + """Update new service in the registry.""" + + with self._lock: + self._remove(info) + self._add(info) + + def get_service_infos(self) -> List[ServiceInfo]: + """Return all ServiceInfo.""" + return list(self.services.values()) + + def get_info_name(self, name: str) -> Optional[ServiceInfo]: + """Return all ServiceInfo for the name.""" + return self.services.get(name) + + def get_types(self) -> List[str]: + """Return all types.""" + return list(self.types.keys()) + + def get_infos_type(self, type_: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching type.""" + return self._get_by_index("types", type_) + + def get_infos_server(self, server: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching server.""" + return self._get_by_index("servers", server) + + def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: + """Return all ServiceInfo matching the index.""" + service_infos = [] + + for name in getattr(self, attr).get(key, [])[:]: + info = self.services.get(name) + # Since we do not get under a lock since it would be + # a performance issue, its possible + # the service can be unregistered during the get + # so we must check if info is None + if info is not None: + service_infos.append(info) + + return service_infos + + def _add(self, info: ServiceInfo) -> None: + """Add a new service under the lock.""" + lower_name = info.name.lower() + if lower_name in self.services: + raise ServiceNameAlreadyRegistered + + self.services[lower_name] = info + self.types.setdefault(info.type, []).append(lower_name) + self.servers.setdefault(info.server, []).append(lower_name) + + def _remove(self, info: ServiceInfo) -> None: + """Remove a service under the lock.""" + lower_name = info.name.lower() + old_service_info = self.services[lower_name] + self.types[old_service_info.type].remove(lower_name) + self.servers[old_service_info.server].remove(lower_name) + del self.services[lower_name] From ffdc9887ede1f867c155743b344efc53e0ceee42 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 15:46:21 -1000 Subject: [PATCH 241/608] Move ServiceListener to zeroconf.services (#550) --- zeroconf/__init__.py | 12 +----------- zeroconf/core.py | 22 ++++++++++++---------- zeroconf/services/__init__.py | 16 ++++++++++++---- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index c4533bab..625f7f9d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -108,6 +108,7 @@ _ServiceBrowserBase, ServiceBrowser, ServiceInfo, + ServiceListener, ServiceStateChange, ) from .services.registry import ServiceRegistry # noqa # import needed for backwards compat @@ -159,17 +160,6 @@ # implementation classes -class ServiceListener: - def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - class ZeroconfServiceTypes(ServiceListener): """ Return all of the advertised services on any local networks diff --git a/zeroconf/core.py b/zeroconf/core.py index 2b53f4c6..fd1edc55 100644 --- a/zeroconf/core.py +++ b/zeroconf/core.py @@ -27,7 +27,7 @@ import socket import threading from types import TracebackType # noqa # used in type hints -from typing import Dict, List, Optional, TYPE_CHECKING, Type, Union, cast +from typing import Dict, List, Optional, Type, Union, cast from .const import ( _CACHE_CLEANUP_INTERVAL, @@ -54,7 +54,13 @@ from .dns import DNSAddress, DNSCache, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from .exceptions import NonUniqueNameException from .logger import QuietLogger, log -from .services import RecordUpdateListener, ServiceBrowser, ServiceInfo, instance_name_from_service_info +from .services import ( + RecordUpdateListener, + ServiceBrowser, + ServiceInfo, + ServiceListener, + instance_name_from_service_info, +) from .services.registry import ServiceRegistry from .utils.name import service_type_name from .utils.net import ( @@ -67,10 +73,6 @@ ) from .utils.time import current_time_millis, millis_to_seconds -if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from . import ServiceListener # pylint: disable=cyclic-import - class NotifyListener: """Receive notifications Zeroconf.notify_all is called.""" @@ -481,8 +483,8 @@ def __init__( log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) self.multi_socket = unicast or interfaces is not InterfaceChoice.Default - self._notify_listeners = [] # type: List[NotifyListener] - self.browsers = {} # type: Dict[ServiceListener, ServiceBrowser] + self._notify_listeners: List[NotifyListener] = [] + self.browsers: Dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() self.query_handler = QueryHandler(self.registry) self.cache = DNSCache() @@ -540,14 +542,14 @@ def remove_notify_listener(self, listener: NotifyListener) -> None: """Removes a listener from the set that is currently listening.""" self._notify_listeners.remove(listener) - def add_service_listener(self, type_: str, listener: 'ServiceListener') -> None: + def add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object will then have its add_service and remove_service methods called when services of that type become available and unavailable.""" self.remove_service_listener(listener) self.browsers[listener] = ServiceBrowser(self, type_, listener) - def remove_service_listener(self, listener: 'ServiceListener') -> None: + def remove_service_listener(self, listener: ServiceListener) -> None: """Removes a listener from the set that is currently listening.""" if listener in self.browsers: self.browsers[listener].cancel() diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index f9c5c87b..59526ad0 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -59,10 +59,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from .. import ( # pylint: disable=cyclic-import - ServiceListener, - Zeroconf, - ) + from .. import Zeroconf # pylint: disable=cyclic-import @enum.unique @@ -82,6 +79,17 @@ def instance_name_from_service_info(info: "ServiceInfo") -> str: return info.name[: -len(service_name) - 1] +class ServiceListener: + def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: + raise NotImplementedError() + + class Signal: def __init__(self) -> None: self._handlers = [] # type: List[Callable[..., None]] From 5b489e5b15ff89a0ffc000ccfeab2a8af346a65e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 15:58:58 -1000 Subject: [PATCH 242/608] Move QueryHandler and RecordManager handlers into zeroconf.handlers (#551) --- zeroconf/core.py | 219 +----------------------------------- zeroconf/handlers.py | 258 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 217 deletions(-) create mode 100644 zeroconf/handlers.py diff --git a/zeroconf/core.py b/zeroconf/core.py index fd1edc55..f432c411 100644 --- a/zeroconf/core.py +++ b/zeroconf/core.py @@ -21,7 +21,6 @@ """ import errno -import itertools import platform import select import socket @@ -33,7 +32,6 @@ _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, _CLASS_IN, - _DNS_OTHER_TTL, _DNS_PORT, _FLAGS_AA, _FLAGS_QR_QUERY, @@ -43,16 +41,12 @@ _MDNS_ADDR6, _MDNS_PORT, _REGISTER_TIME, - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_A, - _TYPE_ANY, _TYPE_PTR, - _TYPE_SRV, - _TYPE_TXT, _UNREGISTER_TIME, ) -from .dns import DNSAddress, DNSCache, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord +from .dns import DNSCache, DNSIncoming, DNSOutgoing, DNSQuestion from .exceptions import NonUniqueNameException +from .handlers import QueryHandler, RecordManager from .logger import QuietLogger, log from .services import ( RecordUpdateListener, @@ -229,215 +223,6 @@ def handle_read(self, socket_: socket.socket) -> None: self.zc.handle_response(msg) -class QueryHandler: - """Query the ServiceRegistry.""" - - def __init__(self, registry: ServiceRegistry): - """Init the query handler.""" - self.registry = registry - - def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: - """Provide an answer to a service type enumeration query. - - https://datatracker.ietf.org/doc/html/rfc6763#section-9 - """ - for stype in self.registry.get_types(): - out.add_answer( - msg, - DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - - def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name.lower()): - out.add_answer(msg, service.dns_pointer()) - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer(service.dns_service()) - out.add_additional_answer(service.dns_text()) - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a query any query other then PTR. - - Add answer(s) for A, AAAA, SRV, or TXT queries. - """ - name_to_find = question.name.lower() - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(name_to_find) # type: ignore - if service is None: - return - - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: - """Deal with incoming query packets. Provides a response if possible.""" - if unicast: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) - for question in msg.questions: - out.add_question(question) - else: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._answer_service_type_enumeration_query(msg, out) - else: - self._answer_ptr_query(msg, out, question) - continue - - self._answer_non_ptr_query(msg, out, question) - - if out is not None and out.answers: - out.id = msg.id - return out - - return None - - -class RecordManager: - """Process records into the cache and notify listeners.""" - - def __init__(self, zeroconf: 'Zeroconf') -> None: - """Init the record manager.""" - self.zc = zeroconf - self.cache = zeroconf.cache - self.listeners: List[RecordUpdateListener] = [] - - def updates(self, now: float, rec: List[DNSRecord]) -> None: - """Used to notify listeners of new information that has updated - a record. - - This method must be called before the cache is updated. - """ - for listener in self.listeners: - listener.update_records(self.zc, now, rec) - - def updates_complete(self) -> None: - """Used to notify listeners of new information that has updated - a record. - - This method must be called after the cache is updated. - """ - for listener in self.listeners: - listener.update_records_complete() - self.zc.notify_all() - - def updates_from_response(self, msg: DNSIncoming) -> None: - """Deal with incoming response packets. All answers - are held in the cache, and listeners are notified.""" - updates: List[DNSRecord] = [] - address_adds: List[DNSAddress] = [] - other_adds: List[DNSRecord] = [] - removes: List[DNSRecord] = [] - now = current_time_millis() - for record in msg.answers: - - updated = True - - if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False - if record.created - entry.created > 1000 and entry not in msg.answers: - removes.append(entry) - - expired = record.is_expired(now) - maybe_entry = self.cache.get(record) - if not expired: - if maybe_entry is not None: - maybe_entry.reset_ttl(record) - else: - if isinstance(record, DNSAddress): - address_adds.append(record) - else: - other_adds.append(record) - if updated: - updates.append(record) - elif maybe_entry is not None: - updates.append(record) - removes.append(record) - - if not updates and not address_adds and not other_adds and not removes: - return - - self.updates(now, updates) - # The cache adds must be processed AFTER we trigger - # the updates since we compare existing data - # with the new data and updating the cache - # ahead of update_record will cause listeners - # to miss changes - # - # We must process address adds before non-addresses - # otherwise a fetch of ServiceInfo may miss an address - # because it thinks the cache is complete - # - # The cache is processed under the context manager to ensure - # that any ServiceBrowser that is going to call - # zc.get_service_info will see the cached value - # but ONLY after all the record updates have been - # processsed. - self.cache.add_records(itertools.chain(address_adds, other_adds)) - # Removes are processed last since - # ServiceInfo could generate an un-needed query - # because the data was not yet populated. - self.cache.remove_records(removes) - self.updates_complete() - - def add_listener( - self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] - ) -> None: - """Adds a listener for a given question. The listener will have - its update_record method called when information is available to - answer the question(s).""" - self.listeners.append(listener) - - if question is not None: - now = current_time_millis() - records = [] - questions = [question] if isinstance(question, DNSQuestion) else question - for single_question in questions: - for record in self.cache.entries_with_name(single_question.name): - if single_question.answered_by(record) and not record.is_expired(now): - records.append(record) - if records: - listener.update_records(self.zc, now, records) - listener.update_records_complete() - - self.zc.notify_all() - - def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" - try: - self.listeners.remove(listener) - self.zc.notify_all() - except ValueError as e: - log.exception('Failed to remove listener: %r', e) - - class Zeroconf(QuietLogger): """Implementation of Zeroconf Multicast DNS Service Discovery diff --git a/zeroconf/handlers.py b/zeroconf/handlers.py new file mode 100644 index 00000000..000bc908 --- /dev/null +++ b/zeroconf/handlers.py @@ -0,0 +1,258 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import itertools +from typing import List, Optional, TYPE_CHECKING, Union + +from .const import ( + _CLASS_IN, + _DNS_OTHER_TTL, + _FLAGS_AA, + _FLAGS_QR_RESPONSE, + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_A, + _TYPE_ANY, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) +from .dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord +from .logger import log +from .services import ( + RecordUpdateListener, +) +from .services.registry import ServiceRegistry +from .utils.time import current_time_millis + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from .core import Zeroconf # pylint: disable=cyclic-import + + +class QueryHandler: + """Query the ServiceRegistry.""" + + def __init__(self, registry: ServiceRegistry): + """Init the query handler.""" + self.registry = registry + + def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: + """Provide an answer to a service type enumeration query. + + https://datatracker.ietf.org/doc/html/rfc6763#section-9 + """ + for stype in self.registry.get_types(): + out.add_answer( + msg, + DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, + _TYPE_PTR, + _CLASS_IN, + _DNS_OTHER_TTL, + stype, + ), + ) + + def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a PTR query.""" + for service in self.registry.get_infos_type(question.name.lower()): + out.add_answer(msg, service.dns_pointer()) + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.1. + out.add_additional_answer(service.dns_service()) + out.add_additional_answer(service.dns_text()) + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: + """Answer a query any query other then PTR. + + Add answer(s) for A, AAAA, SRV, or TXT queries. + """ + name_to_find = question.name.lower() + # Answer A record queries for any service addresses we know + if question.type in (_TYPE_A, _TYPE_ANY): + for service in self.registry.get_infos_server(name_to_find): + for dns_address in service.dns_addresses(): + out.add_answer(msg, dns_address) + + service = self.registry.get_info_name(name_to_find) # type: ignore + if service is None: + return + + if question.type in (_TYPE_SRV, _TYPE_ANY): + out.add_answer(msg, service.dns_service()) + if question.type in (_TYPE_TXT, _TYPE_ANY): + out.add_answer(msg, service.dns_text()) + if question.type == _TYPE_SRV: + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) + + def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: + """Deal with incoming query packets. Provides a response if possible.""" + if unicast: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) + for question in msg.questions: + out.add_question(question) + else: + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + + for question in msg.questions: + if question.type == _TYPE_PTR: + if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._answer_service_type_enumeration_query(msg, out) + else: + self._answer_ptr_query(msg, out, question) + continue + + self._answer_non_ptr_query(msg, out, question) + + if out is not None and out.answers: + out.id = msg.id + return out + + return None + + +class RecordManager: + """Process records into the cache and notify listeners.""" + + def __init__(self, zeroconf: 'Zeroconf') -> None: + """Init the record manager.""" + self.zc = zeroconf + self.cache = zeroconf.cache + self.listeners: List[RecordUpdateListener] = [] + + def updates(self, now: float, rec: List[DNSRecord]) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called before the cache is updated. + """ + for listener in self.listeners: + listener.update_records(self.zc, now, rec) + + def updates_complete(self) -> None: + """Used to notify listeners of new information that has updated + a record. + + This method must be called after the cache is updated. + """ + for listener in self.listeners: + listener.update_records_complete() + self.zc.notify_all() + + def updates_from_response(self, msg: DNSIncoming) -> None: + """Deal with incoming response packets. All answers + are held in the cache, and listeners are notified.""" + updates: List[DNSRecord] = [] + address_adds: List[DNSAddress] = [] + other_adds: List[DNSRecord] = [] + removes: List[DNSRecord] = [] + now = current_time_millis() + for record in msg.answers: + + updated = True + + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): + if entry == record: + updated = False + if record.created - entry.created > 1000 and entry not in msg.answers: + removes.append(entry) + + expired = record.is_expired(now) + maybe_entry = self.cache.get(record) + if not expired: + if maybe_entry is not None: + maybe_entry.reset_ttl(record) + else: + if isinstance(record, DNSAddress): + address_adds.append(record) + else: + other_adds.append(record) + if updated: + updates.append(record) + elif maybe_entry is not None: + updates.append(record) + removes.append(record) + + if not updates and not address_adds and not other_adds and not removes: + return + + self.updates(now, updates) + # The cache adds must be processed AFTER we trigger + # the updates since we compare existing data + # with the new data and updating the cache + # ahead of update_record will cause listeners + # to miss changes + # + # We must process address adds before non-addresses + # otherwise a fetch of ServiceInfo may miss an address + # because it thinks the cache is complete + # + # The cache is processed under the context manager to ensure + # that any ServiceBrowser that is going to call + # zc.get_service_info will see the cached value + # but ONLY after all the record updates have been + # processsed. + self.cache.add_records(itertools.chain(address_adds, other_adds)) + # Removes are processed last since + # ServiceInfo could generate an un-needed query + # because the data was not yet populated. + self.cache.remove_records(removes) + self.updates_complete() + + def add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s).""" + self.listeners.append(listener) + + if question is not None: + now = current_time_millis() + records = [] + questions = [question] if isinstance(question, DNSQuestion) else question + for single_question in questions: + for record in self.cache.entries_with_name(single_question.name): + if single_question.answered_by(record) and not record.is_expired(now): + records.append(record) + if records: + listener.update_records(self.zc, now, records) + listener.update_records_complete() + + self.zc.notify_all() + + def remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener.""" + try: + self.listeners.remove(listener) + self.zc.notify_all() + except ValueError as e: + log.exception('Failed to remove listener: %r', e) From e7fb4e5fb2a6b2163b143a63e2a9e8c5d1eca482 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 16:08:56 -1000 Subject: [PATCH 243/608] Add recipe for TYPE_CHECKING to .coveragerc (#552) --- .coveragerc | 4 ++++ zeroconf/services/__init__.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..56ef8a32 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index 59526ad0..92a9b0dc 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from .. import Zeroconf # pylint: disable=cyclic-import + from ..core import Zeroconf # pylint: disable=cyclic-import @enum.unique From e50b62bb633916d5b84df7bcf7a804c9e3ef7fc2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 16:15:39 -1000 Subject: [PATCH 244/608] Move ZeroconfServiceTypes to zeroconf.services.types (#553) --- zeroconf/__init__.py | 60 +--------------------------- zeroconf/services/types.py | 82 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 59 deletions(-) create mode 100644 zeroconf/services/types.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 625f7f9d..8242c4e1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -21,9 +21,6 @@ """ import sys -import time -from typing import Optional, Union -from typing import Set, Tuple # noqa # used in type hints from .const import ( # noqa # import needed for backwards compat _BROWSER_BACKOFF_LIMIT, @@ -112,6 +109,7 @@ ServiceStateChange, ) from .services.registry import ServiceRegistry # noqa # import needed for backwards compat +from .services.types import ZeroconfServiceTypes # noqa # import needed for backwards compat from .utils.name import service_type_name # noqa # import needed for backwards compat from .utils.net import ( # noqa # import needed for backwards compat add_multicast_member, @@ -155,59 +153,3 @@ If you need support for Python 3.5 please use version 0.28.0 ''' ) - - -# implementation classes - - -class ZeroconfServiceTypes(ServiceListener): - """ - Return all of the advertised services on any local networks - """ - - def __init__(self) -> None: - """Keep track of found services in a set.""" - self.found_services = set() # type: Set[str] - - def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - """Service added.""" - self.found_services.add(name) - - def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - """Service updated.""" - - def remove_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: - """Service removed.""" - - @classmethod - def find( - cls, - zc: Optional['Zeroconf'] = None, - timeout: Union[int, float] = 5, - interfaces: InterfacesType = InterfaceChoice.All, - ip_version: Optional[IPVersion] = None, - ) -> Tuple[str, ...]: - """ - Return all of the advertised services on any local networks. - - :param zc: Zeroconf() instance. Pass in if already have an - instance running or if non-default interfaces are needed - :param timeout: seconds to wait for any responses - :param interfaces: interfaces to listen on. - :param ip_version: IP protocol version to use. - :return: tuple of service type strings - """ - local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version) - listener = cls() - browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener) - - # wait for responses - time.sleep(timeout) - - browser.cancel() - - # close down anything we opened - if zc is None: - local_zc.close() - - return tuple(sorted(listener.found_services)) diff --git a/zeroconf/services/types.py b/zeroconf/services/types.py new file mode 100644 index 00000000..e27defff --- /dev/null +++ b/zeroconf/services/types.py @@ -0,0 +1,82 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import time +from typing import Optional, Set, Tuple, Union + +from ..const import _SERVICE_TYPE_ENUMERATION_NAME +from ..core import Zeroconf +from ..services import ServiceBrowser, ServiceListener +from ..utils.net import IPVersion, InterfaceChoice, InterfacesType + + +class ZeroconfServiceTypes(ServiceListener): + """ + Return all of the advertised services on any local networks + """ + + def __init__(self) -> None: + """Keep track of found services in a set.""" + self.found_services: Set[str] = set() + + def add_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service added.""" + self.found_services.add(name) + + def update_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service updated.""" + + def remove_service(self, zc: Zeroconf, type_: str, name: str) -> None: + """Service removed.""" + + @classmethod + def find( + cls, + zc: Optional[Zeroconf] = None, + timeout: Union[int, float] = 5, + interfaces: InterfacesType = InterfaceChoice.All, + ip_version: Optional[IPVersion] = None, + ) -> Tuple[str, ...]: + """ + Return all of the advertised services on any local networks. + + :param zc: Zeroconf() instance. Pass in if already have an + instance running or if non-default interfaces are needed + :param timeout: seconds to wait for any responses + :param interfaces: interfaces to listen on. + :param ip_version: IP protocol version to use. + :return: tuple of service type strings + """ + local_zc = zc or Zeroconf(interfaces=interfaces, ip_version=ip_version) + listener = cls() + browser = ServiceBrowser(local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener) + + # wait for responses + time.sleep(timeout) + + browser.cancel() + + # close down anything we opened + if zc is None: + local_zc.close() + + return tuple(sorted(listener.found_services)) From 3dfda644efef83640e80876e4fe7da10e87b5990 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 16:36:05 -1000 Subject: [PATCH 245/608] Add missing coverage for ipv6 network utils (#555) --- tests/utils/test_net.py | 58 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/utils/test_net.py diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py new file mode 100644 index 00000000..d4b829c2 --- /dev/null +++ b/tests/utils/test_net.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf.utils.net.""" +from unittest.mock import Mock, patch + +import ifaddr +import pytest + +from zeroconf.utils import net as netutils + + +def _generate_mock_adapters(): + mock_lo0 = Mock(spec=ifaddr.Adapter) + mock_lo0.nice_name = "lo0" + mock_lo0.ips = [ifaddr.IP("127.0.0.1", 8, "lo0")] + mock_lo0.index = 0 + mock_eth0 = Mock(spec=ifaddr.Adapter) + mock_eth0.nice_name = "eth0" + mock_eth0.ips = [ifaddr.IP(("2001:db8::", 1, 1), 8, "eth0")] + mock_eth0.index = 1 + mock_eth1 = Mock(spec=ifaddr.Adapter) + mock_eth1.nice_name = "eth1" + mock_eth1.ips = [ifaddr.IP("192.168.1.5", 23, "eth1")] + mock_eth1.index = 2 + mock_vtun0 = Mock(spec=ifaddr.Adapter) + mock_vtun0.nice_name = "vtun0" + mock_vtun0.ips = [ifaddr.IP("169.254.3.2", 16, "vtun0")] + mock_vtun0.index = 3 + return [mock_eth0, mock_lo0, mock_eth1, mock_vtun0] + + +def test_ip6_to_address_and_index(): + """Test we can extract from mocked adapters.""" + adapters = _generate_mock_adapters() + assert netutils.ip6_to_address_and_index(adapters, "2001:db8::") == (('2001:db8::', 1, 1), 1) + with pytest.raises(RuntimeError): + assert netutils.ip6_to_address_and_index(adapters, "2005:db8::") + + +def test_interface_index_to_ip6_address(): + """Test we can extract from mocked adapters.""" + adapters = _generate_mock_adapters() + assert netutils.interface_index_to_ip6_address(adapters, 1) == ('2001:db8::', 1, 1) + with pytest.raises(RuntimeError): + assert netutils.interface_index_to_ip6_address(adapters, 6) + + +def test_ip6_addresses_to_indexes(): + """Test we can extract from mocked adapters.""" + interfaces = [1] + with patch("zeroconf.utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] + + interfaces = ['2001:db8::'] + with patch("zeroconf.utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] From 3d69656c4e5fbd8f90d54826877a04120d5ec951 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 16:58:09 -1000 Subject: [PATCH 246/608] Fix invalid typing in ServiceInfo._set_text (#554) --- zeroconf/services/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index 92a9b0dc..cd84971d 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -586,10 +586,11 @@ def _set_text(self, text: bytes) -> None: strs.append(text[index : index + length]) index += length + key: bytes + value: Optional[bytes] for s in strs: - parts = s.split(b'=', 1) try: - key, value = parts # type: Tuple[bytes, Optional[bytes]] + key, value = s.split(b'=', 1) except ValueError: # No equals sign at all key = s From 715cd9a1d208139862e6d9d718114e1e472efd28 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:00:11 -1000 Subject: [PATCH 247/608] Relocate some of the services tests to test_services (#556) --- tests/__init__.py | 8 + tests/conftest.py | 18 ++ tests/test_aio.py | 18 ++ tests/test_asyncio.py | 19 +- tests/test_core.py | 9 - tests/test_init.py | 502 +---------------------------------------- tests/test_services.py | 501 +++++++++++++++++++++++++++++++++++++++- 7 files changed, 554 insertions(+), 521 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/__init__.py b/tests/__init__.py index 2ef4b15b..f924adf2 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,3 +19,11 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA """ + +from zeroconf.core import Zeroconf +from zeroconf.dns import DNSIncoming + + +def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: + """Inject a DNSIncoming response.""" + zc.handle_response(msg) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..c05c4b9b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" conftest for zeroconf tests. """ + +import threading + +import pytest + + +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads = frozenset(threading.enumerate()) - threads_before + assert not threads diff --git a/tests/test_aio.py b/tests/test_aio.py index b50e5bc7..48a6ccc4 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -6,6 +6,7 @@ import asyncio import socket +import threading import unittest.mock import pytest @@ -23,6 +24,23 @@ from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads_after = frozenset(threading.enumerate()) + non_executor_threads = frozenset( + [ + thread + for thread in threads_after + if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name + ] + ) + threads = non_executor_threads - threads_before + assert not threads + + @pytest.mark.asyncio async def test_async_basic_usage() -> None: """Test we can create and close the instance.""" diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index ee8f8053..bf4d887e 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -4,12 +4,29 @@ """Unit tests for asyncio.py.""" - import pytest +import threading from zeroconf.asyncio import AsyncZeroconf +@pytest.fixture(autouse=True) +def verify_threads_ended(): + """Verify that the threads are not running after the test.""" + threads_before = frozenset(threading.enumerate()) + yield + threads_after = frozenset(threading.enumerate()) + non_executor_threads = frozenset( + [ + thread + for thread in threads_after + if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name + ] + ) + threads = non_executor_threads - threads_before + assert not threads + + @pytest.mark.asyncio async def test_async_basic_usage() -> None: """Test we can create and close the instance.""" diff --git a/tests/test_core.py b/tests/test_core.py index 5535ab59..40c993b1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,15 +20,6 @@ original_logging_level = logging.NOTSET -@pytest.fixture(autouse=True) -def verify_threads_ended(): - """Verify that the threads are not running after the test.""" - threads_before = frozenset(threading.enumerate()) - yield - threads = frozenset(threading.enumerate()) - threads_before - assert not threads - - def setup_module(): global original_logging_level original_logging_level = log.level diff --git a/tests/test_init.py b/tests/test_init.py index 99424a54..b69699cc 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -26,29 +26,20 @@ import zeroconf as r from zeroconf import ( DNSHinfo, - DNSIncoming, DNSText, ServiceBrowser, ServiceInfo, - ServiceStateChange, Zeroconf, ZeroconfServiceTypes, _EXPIRE_REFRESH_TIME_PERCENT, ) +from . import _inject_response + log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET -@pytest.fixture(autouse=True) -def verify_threads_ended(): - """Verify that the threads are not running after the test.""" - threads_before = frozenset(threading.enumerate()) - yield - threads = frozenset(threading.enumerate()) - threads_before - assert not threads - - def setup_module(): global original_logging_level original_logging_level = log.level @@ -60,11 +51,6 @@ def teardown_module(): log.setLevel(original_logging_level) -def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: - """Inject a DNSIncoming response.""" - zc.handle_response(msg) - - @lru_cache(maxsize=None) def has_working_ipv6(): """Return True if if the system can bind an IPv6 address.""" @@ -1703,490 +1689,6 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi zeroconf.close() -class TestServiceInfo(unittest.TestCase): - def test_get_name(self): - """Verify the name accessor can strip the type.""" - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - assert info.get_name() == "name" - - def test_service_info_rejects_non_matching_updates(self): - """Verify records with the wrong name are rejected.""" - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - ttl = 120 - now = r.current_time_millis() - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - # Verify backwards compatiblity with calling with None - info.update_record(zc, now, None) - # Matching updates - info.update_record( - zc, - now, - r.DNSText( - service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - info.update_record( - zc, - now, - r.DNSService( - service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - 'ASH-2.local.', - ), - ) - assert info.server_key == 'ash-2.local.' - assert info.server == 'ASH-2.local.' - new_address = socket.inet_aton("10.0.1.3") - info.update_record( - zc, - now, - r.DNSAddress( - 'ASH-2.local.', - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - new_address, - ), - ) - assert new_address in info.addresses - # Non-matching updates - info.update_record( - zc, - now, - r.DNSText( - "incorrect.name.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - info.update_record( - zc, - now, - r.DNSService( - "incorrect.name.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - 'ASH-2.local.', - ), - ) - assert info.server_key == 'ash-2.local.' - assert info.server == 'ASH-2.local.' - new_address = socket.inet_aton("10.0.1.4") - info.update_record( - zc, - now, - r.DNSAddress( - "incorrect.name.", - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - new_address, - ), - ) - assert new_address not in info.addresses - zc.close() - - def test_service_info_rejects_expired_records(self): - """Verify records that are expired are rejected.""" - zc = r.Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - ttl = 120 - now = r.current_time_millis() - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - # Matching updates - info.update_record( - zc, - now, - r.DNSText( - service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - # Expired record - expired_record = r.DNSText( - service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', - ) - expired_record.created = 1000 - expired_record._expiration_time = 1000 - info.update_record(zc, now, expired_record) - assert info.properties[b"ci"] == b"2" - zc.close() - - def test_get_info_partial(self): - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - - service_info = None - send_event = Event() - service_info_event = Event() - - last_sent = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal last_sent - - last_sent = out - send_event.set() - - # monkey patch the zeroconf send - setattr(zc, "send", send) - - def mock_incoming_msg(records) -> r.DNSIncoming: - - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - for record in records: - generated.add_answer_at_time(record, 0) - - return r.DNSIncoming(generated.packet()) - - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() - - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for SRV, A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text)] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 3 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSService( - service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 2 - assert r.DNSQuestion(service_server, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_server, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions - last_sent = None - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None - - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() - - def test_get_info_single(self): - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - - service_info = None - send_event = Event() - service_info_event = Event() - - last_sent = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal last_sent - - last_sent = out - send_event.set() - - # monkey patch the zeroconf send - setattr(zc, "send", send) - - def mock_incoming_msg(records) -> r.DNSIncoming: - - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - for record in records: - generated.add_answer_at_time(record, 0) - - return r.DNSIncoming(generated.packet()) - - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() - - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSText( - service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text - ), - r.DNSService( - service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ), - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ), - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None - - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() - - -class TestServiceBrowserMultipleTypes(unittest.TestCase): - def test_update_record(self): - - service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] - service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] - - service_added_count = 0 - service_removed_count = 0 - service_add_event = Event() - service_removed_event = Event() - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal service_added_count - service_added_count += 1 - if service_added_count == 3: - service_add_event.set() - - def remove_service(self, zc, type_, name) -> None: - nonlocal service_removed_count - service_removed_count += 1 - if service_removed_count == 3: - service_removed_event.set() - - def mock_incoming_msg( - service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int - ) -> r.DNSIncoming: - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 - ) - return r.DNSIncoming(generated.packet()) - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) - - try: - wait_time = 3 - - # all three services added - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), - ) - zeroconf.wait(100) - - called_with_refresh_time_check = False - - def _mock_get_expiration_time(self, percent): - nonlocal called_with_refresh_time_check - if percent == _EXPIRE_REFRESH_TIME_PERCENT: - called_with_refresh_time_check = True - return 0 - return self.created + (percent * self.ttl * 10) - - # Set an expire time that will force a refresh - with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), - ) - # Add the last record after updating the first one - # to ensure the service_add_event only gets set - # after the update - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), - ) - service_add_event.wait(wait_time) - assert called_with_refresh_time_check is True - assert service_added_count == 3 - assert service_removed_count == 0 - - # all three services removed - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), - ) - service_removed_event.wait(wait_time) - assert service_added_count == 3 - assert service_removed_count == 3 - - finally: - assert len(zeroconf.listeners) == 1 - service_browser.cancel() - assert len(zeroconf.listeners) == 0 - zeroconf.remove_all_service_listeners() - zeroconf.close() - - def test_multiple_addresses(): type_ = "_http._tcp.local." registration_name = "xxxyyy.%s" % type_ diff --git a/tests/test_services.py b/tests/test_services.py index d931d5c0..7cf476b4 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -8,30 +8,25 @@ import socket import threading import time +import unittest from threading import Event import pytest import zeroconf as r import zeroconf.services as s -from zeroconf import ( +from zeroconf.core import Zeroconf +from zeroconf.services import ( ServiceBrowser, ServiceInfo, ServiceStateChange, - Zeroconf, ) -log = logging.getLogger('zeroconf') -original_logging_level = logging.NOTSET +from . import _inject_response -@pytest.fixture(autouse=True) -def verify_threads_ended(): - """Verify that the threads are not running after the test.""" - threads_before = frozenset(threading.enumerate()) - yield - threads = frozenset(threading.enumerate()) - threads_before - assert not threads +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET def setup_module(): @@ -45,6 +40,490 @@ def teardown_module(): log.setLevel(original_logging_level) +class TestServiceInfo(unittest.TestCase): + def test_get_name(self): + """Verify the name accessor can strip the type.""" + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + assert info.get_name() == "name" + + def test_service_info_rejects_non_matching_updates(self): + """Verify records with the wrong name are rejected.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Verify backwards compatiblity with calling with None + info.update_record(zc, now, None) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.3") + info.update_record( + zc, + now, + r.DNSAddress( + 'ASH-2.local.', + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address in info.addresses + # Non-matching updates + info.update_record( + zc, + now, + r.DNSText( + "incorrect.name.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + "incorrect.name.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.4") + info.update_record( + zc, + now, + r.DNSAddress( + "incorrect.name.", + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address not in info.addresses + zc.close() + + def test_service_info_rejects_expired_records(self): + """Verify records that are expired are rejected.""" + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + # Expired record + expired_record = r.DNSText( + service_name, + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ) + expired_record.created = 1000 + expired_record._expiration_time = 1000 + info.update_record(zc, now, expired_record) + assert info.properties[b"ci"] == b"2" + zc.close() + + def test_get_info_partial(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # monkey patch the zeroconf send + setattr(zc, "send", send) + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packet()) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for SRV, A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text)] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 3 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 2 + assert r.DNSQuestion(service_server, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + last_sent = None + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + def test_get_info_single(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # monkey patch the zeroconf send + setattr(zc, "send", send) + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packet()) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text + ), + r.DNSService( + service_name, + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + +class TestServiceBrowserMultipleTypes(unittest.TestCase): + def test_update_record(self): + + service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] + service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] + + service_added_count = 0 + service_removed_count = 0 + service_add_event = Event() + service_removed_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + if service_added_count == 3: + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + if service_removed_count == 3: + service_removed_event.set() + + def mock_incoming_msg( + service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int + ) -> r.DNSIncoming: + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + ) + return r.DNSIncoming(generated.packet()) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) + + try: + wait_time = 3 + + # all three services added + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), + ) + zeroconf.wait(100) + + called_with_refresh_time_check = False + + def _mock_get_expiration_time(self, percent): + nonlocal called_with_refresh_time_check + if percent == r._EXPIRE_REFRESH_TIME_PERCENT: + called_with_refresh_time_check = True + return 0 + return self.created + (percent * self.ttl * 10) + + # Set an expire time that will force a refresh + with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + # Add the last record after updating the first one + # to ensure the service_add_event only gets set + # after the update + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), + ) + service_add_event.wait(wait_time) + assert called_with_refresh_time_check is True + assert service_added_count == 3 + assert service_removed_count == 0 + + # all three services removed + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), + ) + service_removed_event.wait(wait_time) + assert service_added_count == 3 + assert service_removed_count == 3 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + def test_backoff(): got_query = Event() From f0d99e2e68791376a8517254338c708a3244f178 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:13:01 -1000 Subject: [PATCH 248/608] Relocate dns tests to test_dns (#557) --- tests/test_dns.py | 416 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_init.py | 387 ----------------------------------------- 2 files changed, 416 insertions(+), 387 deletions(-) create mode 100644 tests/test_dns.py diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 00000000..10ab36b2 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf.py """ + +import copy +import logging +import socket +import struct +import time +import unittest +import unittest.mock +from typing import Dict, cast # noqa # used in type hints + +import zeroconf as r +from zeroconf import ( + DNSHinfo, + DNSText, + ServiceInfo, +) + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestDunder(unittest.TestCase): + def test_dns_text_repr(self): + # There was an issue on Python 3 that prevented DNSText's repr + # from working when the text was longer than 10 bytes + text = DNSText('irrelevant', 0, 0, 0, b'12345678901') + repr(text) + + text = DNSText('irrelevant', 0, 0, 0, b'123') + repr(text) + + def test_dns_hinfo_repr_eq(self): + hinfo = DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os') + assert hinfo == hinfo + repr(hinfo) + + def test_dns_pointer_repr(self): + pointer = r.DNSPointer('irrelevant', r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, '123') + repr(pointer) + + def test_dns_address_repr(self): + address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + assert repr(address).endswith("b'a'") + + address_ipv4 = r.DNSAddress( + 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') + ) + assert repr(address_ipv4).endswith('127.0.0.1') + + address_ipv6 = r.DNSAddress( + 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') + ) + assert repr(address_ipv6).endswith('::1') + + def test_dns_question_repr(self): + question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE) + repr(question) + assert not question != question + + def test_dns_service_repr(self): + service = r.DNSService('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, 'a') + repr(service) + + def test_dns_record_abc(self): + record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + self.assertRaises(r.AbstractMethodException, record.__eq__, record) + self.assertRaises(r.AbstractMethodException, record.write, None) + + def test_dns_record_reset_ttl(self): + record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + time.sleep(1) + record2 = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + now = r.current_time_millis() + + assert record.created != record2.created + assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now) + + record.reset_ttl(record2) + + assert record.ttl == record2.ttl + assert record.created == record2.created + assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now) + + def test_service_info_dunder(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + b'', + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + assert not info != info + repr(info) + + def test_service_info_text_properties_not_given(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + info = ServiceInfo( + type_=type_, + name=registration_name, + addresses=[socket.inet_aton("10.0.1.2")], + port=80, + server="ash-2.local.", + ) + + assert isinstance(info.text, bytes) + repr(info) + + def test_dns_outgoing_repr(self): + dns_outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY) + repr(dns_outgoing) + + +class PacketGeneration(unittest.TestCase): + def test_parse_own_packet_simple(self): + generated = r.DNSOutgoing(0) + r.DNSIncoming(generated.packet()) + + def test_parse_own_packet_simple_unicast(self): + generated = r.DNSOutgoing(0, False) + r.DNSIncoming(generated.packet()) + + def test_parse_own_packet_flags(self): + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + r.DNSIncoming(generated.packet()) + + def test_parse_own_packet_question(self): + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)) + r.DNSIncoming(generated.packet()) + + def test_parse_own_packet_response(self): + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packet()) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_match_question(self): + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + parsed = r.DNSIncoming(generated.packet()) + assert len(generated.questions) == 1 + assert len(generated.questions) == len(parsed.questions) + assert question == parsed.questions[0] + + def test_suppress_answer(self): + query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + query_generated.add_question(question) + answer1 = r.DNSService( + "testname1.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + staleanswer2 = r.DNSService( + "testname2.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL / 2, + 0, + 0, + 80, + "foo.local.", + ) + answer2 = r.DNSService( + "testname2.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + query_generated.add_answer_at_time(answer1, 0) + query_generated.add_answer_at_time(staleanswer2, 0) + query = r.DNSIncoming(query_generated.packet()) + + # Should be suppressed + response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + response.add_answer(query, answer1) + assert len(response.answers) == 0 + + # Should not be suppressed, TTL in query is too short + response.add_answer(query, answer2) + assert len(response.answers) == 1 + + # Should not be suppressed, name is different + tmp = copy.copy(answer1) + tmp.key = "testname3.local." + tmp.name = "testname3.local." + response.add_answer(query, tmp) + assert len(response.answers) == 2 + + # Should not be suppressed, type is different + tmp = copy.copy(answer1) + tmp.type = r._TYPE_A + response.add_answer(query, tmp) + assert len(response.answers) == 3 + + # Should not be suppressed, class is different + tmp = copy.copy(answer1) + tmp.class_ = r._CLASS_NONE + response.add_answer(query, tmp) + assert len(response.answers) == 4 + + # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService + + def test_dns_hinfo(self): + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')) + parsed = r.DNSIncoming(generated.packet()) + answer = cast(r.DNSHinfo, parsed.answers[0]) + assert answer.cpu == u'cpu' + assert answer.os == u'os' + + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) + self.assertRaises(r.NamePartTooLongException, generated.packet) + + def test_many_questions(self): + """Test many questions get seperated into multiple packets.""" + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + questions = [] + for i in range(100): + question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 100 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) < r._MAX_MSG_TYPICAL + assert len(packets[1]) < r._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 85 + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 15 + + def test_only_one_answer_can_by_large(self): + """Test that only the first answer in each packet can be large. + + https://datatracker.ietf.org/doc/html/rfc6762#section-17 + """ + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + query = r.DNSIncoming(r.DNSOutgoing(r._FLAGS_QR_QUERY).packet()) + for i in range(3): + generated.add_answer( + query, + r.DNSText( + "zoom._hap._tcp.local.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + 1200, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, + ), + ) + generated.add_answer( + query, + r.DNSService( + "testname1.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + ) + assert len(generated.answers) == 4 + + packets = generated.packets() + assert len(packets) == 4 + assert len(packets[0]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[0]) > r._MAX_MSG_TYPICAL + + assert len(packets[1]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[1]) > r._MAX_MSG_TYPICAL + + assert len(packets[2]) <= r._MAX_MSG_ABSOLUTE + assert len(packets[2]) > r._MAX_MSG_TYPICAL + + assert len(packets[3]) <= r._MAX_MSG_TYPICAL + + for packet in packets: + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + def test_questions_do_not_end_up_every_packet(self): + """Test that questions are not sent again when multiple packets are needed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + Sometimes a Multicast DNS querier will already have too many answers + to fit in the Known-Answer Section of its query packets.... It MUST + immediately follow the packet with another query packet containing no + questions and as many more Known-Answer records as will fit. + """ + + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + for i in range(35): + question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + generated.add_question(question) + answer = r.DNSService( + f"testname{i}.local.", + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + 80, + f"foo{i}.local.", + ) + generated.add_answer_at_time(answer, 0) + + assert len(generated.questions) == 35 + assert len(generated.answers) == 35 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) <= r._MAX_MSG_TYPICAL + assert len(packets[1]) <= r._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 35 + assert len(parsed1.answers) == 33 + + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 2 + + +class PacketForm(unittest.TestCase): + def test_transaction_id(self): + """ID must be zero in a DNS-SD packet""" + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + bytes = generated.packet() + id = bytes[0] << 8 | bytes[1] + assert id == 0 + + def test_query_header_bits(self): + generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + bytes = generated.packet() + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x0 + + def test_response_header_bits(self): + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + bytes = generated.packet() + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x8000 + + def test_numbers(self): + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + bytes = generated.packet() + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 0 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 + + def test_numbers_questions(self): + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + for i in range(10): + generated.add_question(question) + bytes = generated.packet() + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 10 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 diff --git a/tests/test_init.py b/tests/test_init.py index b69699cc..6d8aca8d 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -4,14 +4,10 @@ """ Unit tests for zeroconf.py """ -import copy import errno -import itertools import logging import os import socket -import struct -import threading import time import unittest import unittest.mock @@ -25,13 +21,11 @@ import zeroconf as r from zeroconf import ( - DNSHinfo, DNSText, ServiceBrowser, ServiceInfo, Zeroconf, ZeroconfServiceTypes, - _EXPIRE_REFRESH_TIME_PERCENT, ) from . import _inject_response @@ -79,387 +73,6 @@ def _clear_cache(zc): zc.cache.remove(record) -class TestDunder(unittest.TestCase): - def test_dns_text_repr(self): - # There was an issue on Python 3 that prevented DNSText's repr - # from working when the text was longer than 10 bytes - text = DNSText('irrelevant', 0, 0, 0, b'12345678901') - repr(text) - - text = DNSText('irrelevant', 0, 0, 0, b'123') - repr(text) - - def test_dns_hinfo_repr_eq(self): - hinfo = DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os') - assert hinfo == hinfo - repr(hinfo) - - def test_dns_pointer_repr(self): - pointer = r.DNSPointer('irrelevant', r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, '123') - repr(pointer) - - def test_dns_address_repr(self): - address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - assert repr(address).endswith("b'a'") - - address_ipv4 = r.DNSAddress( - 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') - ) - assert repr(address_ipv4).endswith('127.0.0.1') - - address_ipv6 = r.DNSAddress( - 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') - ) - assert repr(address_ipv6).endswith('::1') - - def test_dns_question_repr(self): - question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE) - repr(question) - assert not question != question - - def test_dns_service_repr(self): - service = r.DNSService('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, 'a') - repr(service) - - def test_dns_record_abc(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - self.assertRaises(r.AbstractMethodException, record.__eq__, record) - self.assertRaises(r.AbstractMethodException, record.write, None) - - def test_dns_record_reset_ttl(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - time.sleep(1) - record2 = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) - now = r.current_time_millis() - - assert record.created != record2.created - assert record.get_remaining_ttl(now) != record2.get_remaining_ttl(now) - - record.reset_ttl(record2) - - assert record.ttl == record2.ttl - assert record.created == record2.created - assert record.get_remaining_ttl(now) == record2.get_remaining_ttl(now) - - def test_service_info_dunder(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - b'', - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - - assert not info != info - repr(info) - - def test_service_info_text_properties_not_given(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - info = ServiceInfo( - type_=type_, - name=registration_name, - addresses=[socket.inet_aton("10.0.1.2")], - port=80, - server="ash-2.local.", - ) - - assert isinstance(info.text, bytes) - repr(info) - - def test_dns_outgoing_repr(self): - dns_outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY) - repr(dns_outgoing) - - -class PacketGeneration(unittest.TestCase): - def test_parse_own_packet_simple(self): - generated = r.DNSOutgoing(0) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_simple_unicast(self): - generated = r.DNSOutgoing(0, False) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_flags(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - generated.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)) - r.DNSIncoming(generated.packet()) - - def test_parse_own_packet_response(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSService( - "æøå.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - 0, - ) - parsed = r.DNSIncoming(generated.packet()) - assert len(generated.answers) == 1 - assert len(generated.answers) == len(parsed.answers) - - def test_match_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - parsed = r.DNSIncoming(generated.packet()) - assert len(generated.questions) == 1 - assert len(generated.questions) == len(parsed.questions) - assert question == parsed.questions[0] - - def test_suppress_answer(self): - query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - query_generated.add_question(question) - answer1 = r.DNSService( - "testname1.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ) - staleanswer2 = r.DNSService( - "testname2.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL / 2, - 0, - 0, - 80, - "foo.local.", - ) - answer2 = r.DNSService( - "testname2.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ) - query_generated.add_answer_at_time(answer1, 0) - query_generated.add_answer_at_time(staleanswer2, 0) - query = r.DNSIncoming(query_generated.packet()) - - # Should be suppressed - response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - response.add_answer(query, answer1) - assert len(response.answers) == 0 - - # Should not be suppressed, TTL in query is too short - response.add_answer(query, answer2) - assert len(response.answers) == 1 - - # Should not be suppressed, name is different - tmp = copy.copy(answer1) - tmp.key = "testname3.local." - tmp.name = "testname3.local." - response.add_answer(query, tmp) - assert len(response.answers) == 2 - - # Should not be suppressed, type is different - tmp = copy.copy(answer1) - tmp.type = r._TYPE_A - response.add_answer(query, tmp) - assert len(response.answers) == 3 - - # Should not be suppressed, class is different - tmp = copy.copy(answer1) - tmp.class_ = r._CLASS_NONE - response.add_answer(query, tmp) - assert len(response.answers) == 4 - - # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService - - def test_dns_hinfo(self): - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')) - parsed = r.DNSIncoming(generated.packet()) - answer = cast(r.DNSHinfo, parsed.answers[0]) - assert answer.cpu == u'cpu' - assert answer.os == u'os' - - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) - self.assertRaises(r.NamePartTooLongException, generated.packet) - - def test_many_questions(self): - """Test many questions get seperated into multiple packets.""" - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - questions = [] - for i in range(100): - question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - questions.append(question) - assert len(generated.questions) == 100 - - packets = generated.packets() - assert len(packets) == 2 - assert len(packets[0]) < r._MAX_MSG_TYPICAL - assert len(packets[1]) < r._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert len(parsed1.questions) == 85 - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 15 - - def test_only_one_answer_can_by_large(self): - """Test that only the first answer in each packet can be large. - - https://datatracker.ietf.org/doc/html/rfc6762#section-17 - """ - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - query = r.DNSIncoming(r.DNSOutgoing(r._FLAGS_QR_QUERY).packet()) - for i in range(3): - generated.add_answer( - query, - r.DNSText( - "zoom._hap._tcp.local.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - 1200, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, - ), - ) - generated.add_answer( - query, - r.DNSService( - "testname1.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - ) - assert len(generated.answers) == 4 - - packets = generated.packets() - assert len(packets) == 4 - assert len(packets[0]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[0]) > r._MAX_MSG_TYPICAL - - assert len(packets[1]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[1]) > r._MAX_MSG_TYPICAL - - assert len(packets[2]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[2]) > r._MAX_MSG_TYPICAL - - assert len(packets[3]) <= r._MAX_MSG_TYPICAL - - for packet in packets: - parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 - - def test_questions_do_not_end_up_every_packet(self): - """Test that questions are not sent again when multiple packets are needed. - - https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - Sometimes a Multicast DNS querier will already have too many answers - to fit in the Known-Answer Section of its query packets.... It MUST - immediately follow the packet with another query packet containing no - questions and as many more Known-Answer records as will fit. - """ - - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - for i in range(35): - question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) - generated.add_question(question) - answer = r.DNSService( - f"testname{i}.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - 80, - f"foo{i}.local.", - ) - generated.add_answer_at_time(answer, 0) - - assert len(generated.questions) == 35 - assert len(generated.answers) == 35 - - packets = generated.packets() - assert len(packets) == 2 - assert len(packets[0]) <= r._MAX_MSG_TYPICAL - assert len(packets[1]) <= r._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert len(parsed1.questions) == 35 - assert len(parsed1.answers) == 33 - - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 0 - assert len(parsed2.answers) == 2 - - -class PacketForm(unittest.TestCase): - def test_transaction_id(self): - """ID must be zero in a DNS-SD packet""" - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - bytes = generated.packet() - id = bytes[0] << 8 | bytes[1] - assert id == 0 - - def test_query_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - bytes = generated.packet() - flags = bytes[2] << 8 | bytes[3] - assert flags == 0x0 - - def test_response_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - bytes = generated.packet() - flags = bytes[2] << 8 | bytes[3] - assert flags == 0x8000 - - def test_numbers(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - bytes = generated.packet() - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - assert num_questions == 0 - assert num_answers == 0 - assert num_authorities == 0 - assert num_additionals == 0 - - def test_numbers_questions(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) - for i in range(10): - generated.add_question(question) - bytes = generated.packet() - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - assert num_questions == 10 - assert num_answers == 0 - assert num_authorities == 0 - assert num_additionals == 0 - - class Names(unittest.TestCase): def test_long_name(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) From 18b9d0a8bd07c0a0d2923763a5f131905c31e0df Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:16:37 -1000 Subject: [PATCH 249/608] Relocate additional dns tests to test_dns (#558) --- tests/test_dns.py | 220 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_init.py | 220 --------------------------------------------- 2 files changed, 220 insertions(+), 220 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 10ab36b2..8117339d 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -414,3 +414,223 @@ def test_numbers_questions(self): assert num_answers == 0 assert num_authorities == 0 assert num_additionals == 0 + + +def test_dns_compression_rollback_for_corruption(): + """Verify rolling back does not lead to dns compression corruption.""" + out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + address = socket.inet_pton(socket.AF_INET, "192.168.208.5") + + additionals = [ + { + "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", + "address": address, + "port": 51832, + "text": b"\x13md=HASS Bridge" + b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=L0m/aQ==", + }, + { + "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", + "address": address, + "port": 51834, + "text": b"\x13md=HASS Bridge" + b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b2CnzQ==", + }, + { + "name": "Master Bed TV CEDB27._hap._tcp.local.", + "address": address, + "port": 51830, + "text": b"\x10md=Master Bed" + b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=CVj1kw==", + }, + { + "name": "Living Room TV 921B77._hap._tcp.local.", + "address": address, + "port": 51833, + "text": b"\x11md=Living Room" + b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=qU77SQ==", + }, + { + "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", + "address": address, + "port": 51829, + "text": b"\x13md=HASS Bridge" + b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b0QZlg==", + }, + { + "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", + "address": address, + "port": 51837, + "text": b"\x13md=HASS Bridge" + b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=ahAISA==", + }, + { + "name": "FrontdoorCamera 8941D1._hap._tcp.local.", + "address": address, + "port": 54898, + "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" + b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", + }, + { + "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + "address": address, + "port": 51836, + "text": b"\x13md=HASS Bridge" + b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=6fLM5A==", + }, + { + "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", + "address": address, + "port": 51838, + "text": b"\x13md=HASS Bridge" + b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=u3bdfw==", + }, + { + "name": "Snooze Room TV 6B89B0._hap._tcp.local.", + "address": address, + "port": 51835, + "text": b"\x11md=Snooze Room" + b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=xNTqsg==", + }, + { + "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", + "address": address, + "port": 54811, + "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" + b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", + }, + { + "name": "HASS Bridge OS95 39C053._hap._tcp.local.", + "address": address, + "port": 51831, + "text": b"\x13md=HASS Bridge" + b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" + b"\x04sf=0\x0bsh=Xfe5LQ==", + }, + ] + + out.add_answer_at_time( + DNSText( + "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for record in additionals: + out.add_additional_answer( + r.DNSService( + record["name"], # type: ignore + r._TYPE_SRV, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + 0, + 0, + record["port"], # type: ignore + record["name"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSText( + record["name"], # type: ignore + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + record["text"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSAddress( + record["name"], # type: ignore + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_HOST_TTL, + record["address"], # type: ignore + ) + ) + + for packet in out.packets(): + # Verify we can process the packets we created to + # ensure there is no corruption with the dns compression + incoming = r.DNSIncoming(packet) + assert incoming.valid is True + + +def test_tc_bit_in_query_packet(): + """Verify the TC bit is set when known answers exceed the packet size.""" + out = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, r._TYPE_PTR, r._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.valid is True + + +def test_tc_bit_not_set_in_answer_packet(): + """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" + out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + r._TYPE_TXT, + r._CLASS_IN | r._CLASS_UNIQUE, + r._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.flags & r._FLAGS_TC == 0 + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.flags & r._FLAGS_TC == 0 + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.valid is True diff --git a/tests/test_init.py b/tests/test_init.py index 6d8aca8d..159fd032 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1406,226 +1406,6 @@ def test_ptr_optimization(): zc.close() -def test_dns_compression_rollback_for_corruption(): - """Verify rolling back does not lead to dns compression corruption.""" - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) - address = socket.inet_pton(socket.AF_INET, "192.168.208.5") - - additionals = [ - { - "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", - "address": address, - "port": 51832, - "text": b"\x13md=HASS Bridge" - b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=L0m/aQ==", - }, - { - "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", - "address": address, - "port": 51834, - "text": b"\x13md=HASS Bridge" - b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=b2CnzQ==", - }, - { - "name": "Master Bed TV CEDB27._hap._tcp.local.", - "address": address, - "port": 51830, - "text": b"\x10md=Master Bed" - b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=CVj1kw==", - }, - { - "name": "Living Room TV 921B77._hap._tcp.local.", - "address": address, - "port": 51833, - "text": b"\x11md=Living Room" - b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=qU77SQ==", - }, - { - "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", - "address": address, - "port": 51829, - "text": b"\x13md=HASS Bridge" - b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=b0QZlg==", - }, - { - "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", - "address": address, - "port": 51837, - "text": b"\x13md=HASS Bridge" - b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=ahAISA==", - }, - { - "name": "FrontdoorCamera 8941D1._hap._tcp.local.", - "address": address, - "port": 54898, - "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" - b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", - }, - { - "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", - "address": address, - "port": 51836, - "text": b"\x13md=HASS Bridge" - b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=6fLM5A==", - }, - { - "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", - "address": address, - "port": 51838, - "text": b"\x13md=HASS Bridge" - b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=u3bdfw==", - }, - { - "name": "Snooze Room TV 6B89B0._hap._tcp.local.", - "address": address, - "port": 51835, - "text": b"\x11md=Snooze Room" - b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=xNTqsg==", - }, - { - "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", - "address": address, - "port": 54811, - "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" - b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", - }, - { - "name": "HASS Bridge OS95 39C053._hap._tcp.local.", - "address": address, - "port": 51831, - "text": b"\x13md=HASS Bridge" - b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" - b"\x04sf=0\x0bsh=Xfe5LQ==", - }, - ] - - out.add_answer_at_time( - DNSText( - "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - for record in additionals: - out.add_additional_answer( - r.DNSService( - record["name"], # type: ignore - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - 0, - 0, - record["port"], # type: ignore - record["name"], # type: ignore - ) - ) - out.add_additional_answer( - r.DNSText( - record["name"], # type: ignore - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, - record["text"], # type: ignore - ) - ) - out.add_additional_answer( - r.DNSAddress( - record["name"], # type: ignore - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, - record["address"], # type: ignore - ) - ) - - for packet in out.packets(): - # Verify we can process the packets we created to - # ensure there is no corruption with the dns compression - incoming = r.DNSIncoming(packet) - assert incoming.valid is True - - -def test_tc_bit_in_query_packet(): - """Verify the TC bit is set when known answers exceed the packet size.""" - out = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - type_ = "_hap._tcp.local." - out.add_question(r.DNSQuestion(type_, r._TYPE_PTR, r._CLASS_IN)) - - for i in range(30): - out.add_answer_at_time( - DNSText( - ("HASS Bridge W9DN %s._hap._tcp.local." % i), - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - packets = out.packets() - assert len(packets) == 3 - - first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & r._FLAGS_TC == r._FLAGS_TC - assert first_packet.valid is True - - second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & r._FLAGS_TC == r._FLAGS_TC - assert second_packet.valid is True - - third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & r._FLAGS_TC == 0 - assert third_packet.valid is True - - -def test_tc_bit_not_set_in_answer_packet(): - """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) - for i in range(30): - out.add_answer_at_time( - DNSText( - ("HASS Bridge W9DN %s._hap._tcp.local." % i), - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - packets = out.packets() - assert len(packets) == 3 - - first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & r._FLAGS_TC == 0 - assert first_packet.valid is True - - second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & r._FLAGS_TC == 0 - assert second_packet.valid is True - - third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & r._FLAGS_TC == 0 - assert third_packet.valid is True - - @pytest.mark.parametrize( "errno,expected_result", [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], From eb37f089579fdc5a405dbc2f0ce5620cf9d1b011 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:25:45 -1000 Subject: [PATCH 250/608] Move additional tests to test_core (#559) --- tests/__init__.py | 35 +++++++++ tests/test_core.py | 163 +++++++++++++++++++++++++++++++++++++- tests/test_init.py | 191 +-------------------------------------------- 3 files changed, 197 insertions(+), 192 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index f924adf2..6399dbef 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -20,6 +20,13 @@ USA """ +import socket +from functools import lru_cache + + +import ifaddr + + from zeroconf.core import Zeroconf from zeroconf.dns import DNSIncoming @@ -27,3 +34,31 @@ def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: """Inject a DNSIncoming response.""" zc.handle_response(msg) + + +@lru_cache(maxsize=None) +def has_working_ipv6(): + """Return True if if the system can bind an IPv6 address.""" + if not socket.has_ipv6: + return False + + try: + sock = socket.socket(socket.AF_INET6) + sock.bind(('::1', 0)) + except Exception: + return False + finally: + if sock: + sock.close() + + for iface in ifaddr.get_adapters(): + for addr in iface.ips: + if addr.is_IPv6 and iface.index is not None: + return True + return False + + +def _clear_cache(zc): + for name in zc.cache.names(): + for record in zc.cache.entries_with_name(name): + zc.cache.remove(record) diff --git a/tests/test_core.py b/tests/test_core.py index 40c993b1..0d63a58c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,16 +6,18 @@ import itertools import logging -import threading +import os +import socket import time import unittest import unittest.mock +from typing import cast - -import pytest import zeroconf as r from zeroconf import core +from . import has_working_ipv6, _inject_response + log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -52,3 +54,158 @@ def test_reaper(self): assert entries_with_cache != original_entries assert record_with_10s_ttl in entries assert record_with_1s_ttl not in entries + + +class Framework(unittest.TestCase): + def test_launch_and_close(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) + rv.close() + + def test_launch_and_close_context_manager(self): + with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv: + assert rv.done is False + assert rv.done is True + + with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv: + assert rv.done is False + assert rv.done is True + + def test_launch_and_close_unicast(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True) + rv.close() + + def test_close_multiple_times(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) + rv.close() + rv.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_launch_and_close_v4_v6(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All) + rv.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_launch_and_close_v6_only(self): + rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) + rv.close() + rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) + rv.close() + + def test_handle_response(self): + def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + ttl = 120 + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + + if service_state_change == r.ServiceStateChange.Updated: + generated.add_answer_at_time( + r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 + ) + return r.DNSIncoming(generated.packet()) + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + + generated.add_answer_at_time( + r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + ) + generated.add_answer_at_time( + r.DNSService( + service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + ), + 0, + ) + generated.add_answer_at_time( + r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + + return r.DNSIncoming(generated.packet()) + + def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + """Mock an incoming message for the case where the packet is split.""" + ttl = 120 + generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + r._TYPE_A, + r._CLASS_IN | r._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSService( + service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + ), + 0, + ) + return r.DNSIncoming(generated.packet()) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-2.local.' + service_text = b'path=/~paulsm/' + service_address = '10.0.1.2' + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + try: + # service added + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) + dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' + all_dns_text = zeroconf.cache.get_all_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert [dns_text] == all_dns_text + + # https://tools.ietf.org/html/rfc6762#section-10.2 + # Instead of merging this new record additively into the cache in addition + # to any previous records with the same name, rrtype, and rrclass, + # all old records with that name, rrtype, and rrclass that were received + # more than one second ago are declared invalid, + # and marked to expire from the cache in one second. + time.sleep(1.1) + + # service updated. currently only text record can be updated + service_text = b'path=/~humingchun/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' + + time.sleep(1.1) + + # The split message only has a SRV and A record. + # This should not evict TXT records from the cache + _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated)) + time.sleep(1.1) + dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert dns_text is not None + assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' + + # service removed + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) + dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + assert dns_text is None + + finally: + zeroconf.close() diff --git a/tests/test_init.py b/tests/test_init.py index 159fd032..7ec3b658 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -11,24 +11,20 @@ import time import unittest import unittest.mock -from functools import lru_cache from threading import Event -from typing import Dict, Optional, cast # noqa # used in type hints - -import ifaddr +from typing import Dict, Optional # noqa # used in type hints import pytest import zeroconf as r from zeroconf import ( - DNSText, ServiceBrowser, ServiceInfo, Zeroconf, ZeroconfServiceTypes, ) -from . import _inject_response +from . import has_working_ipv6, _clear_cache, _inject_response log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -45,34 +41,6 @@ def teardown_module(): log.setLevel(original_logging_level) -@lru_cache(maxsize=None) -def has_working_ipv6(): - """Return True if if the system can bind an IPv6 address.""" - if not socket.has_ipv6: - return False - - try: - sock = socket.socket(socket.AF_INET6) - sock.bind(('::1', 0)) - except Exception: - return False - finally: - if sock: - sock.close() - - for iface in ifaddr.get_adapters(): - for addr in iface.ips: - if addr.is_IPv6 and iface.index is not None: - return True - return False - - -def _clear_cache(zc): - for name in zc.cache.names(): - for record in zc.cache.entries_with_name(name): - zc.cache.remove(record) - - class Names(unittest.TestCase): def test_long_name(self): generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) @@ -279,161 +247,6 @@ def generate_host(zc, host_name, type_): zc.send(out) -class Framework(unittest.TestCase): - def test_launch_and_close(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) - rv.close() - - def test_launch_and_close_context_manager(self): - with r.Zeroconf(interfaces=r.InterfaceChoice.All) as rv: - assert rv.done is False - assert rv.done is True - - with r.Zeroconf(interfaces=r.InterfaceChoice.Default) as rv: - assert rv.done is False - assert rv.done is True - - def test_launch_and_close_unicast(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, unicast=True) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, unicast=True) - rv.close() - - def test_close_multiple_times(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default) - rv.close() - rv.close() - - @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') - @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') - def test_launch_and_close_v4_v6(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.All) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.All) - rv.close() - - @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') - @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') - def test_launch_and_close_v6_only(self): - rv = r.Zeroconf(interfaces=r.InterfaceChoice.All, ip_version=r.IPVersion.V6Only) - rv.close() - rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) - rv.close() - - def test_handle_response(self): - def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - - if service_state_change == r.ServiceStateChange.Updated: - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - return r.DNSIncoming(generated.packet()) - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - - generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 - ) - generated.add_answer_at_time( - r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server - ), - 0, - ) - generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - - return r.DNSIncoming(generated.packet()) - - def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - """Mock an incoming message for the case where the packet is split.""" - ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - generated.add_answer_at_time( - r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server - ), - 0, - ) - return r.DNSIncoming(generated.packet()) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-2.local.' - service_text = b'path=/~paulsm/' - service_address = '10.0.1.2' - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - - try: - # service added - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is not None - assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' - all_dns_text = zeroconf.cache.get_all_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert [dns_text] == all_dns_text - - # https://tools.ietf.org/html/rfc6762#section-10.2 - # Instead of merging this new record additively into the cache in addition - # to any previous records with the same name, rrtype, and rrclass, - # all old records with that name, rrtype, and rrclass that were received - # more than one second ago are declared invalid, - # and marked to expire from the cache in one second. - time.sleep(1.1) - - # service updated. currently only text record can be updated - service_text = b'path=/~humingchun/' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is not None - assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' - - time.sleep(1.1) - - # The split message only has a SRV and A record. - # This should not evict TXT records from the cache - _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated)) - time.sleep(1.1) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is not None - assert cast(DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' - - # service removed - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) - assert dns_text is None - - finally: - zeroconf.close() - - class Exceptions(unittest.TestCase): browser = None # type: Zeroconf From b5d848de1ed95c55f8c262bcf0811248818da901 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:33:50 -1000 Subject: [PATCH 251/608] Move exceptions tests to test_exceptions (#560) --- tests/test_exceptions.py | 156 +++++++++++++++++++++++++++++++++++++++ tests/test_init.py | 126 ------------------------------- 2 files changed, 156 insertions(+), 126 deletions(-) create mode 100644 tests/test_exceptions.py diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 00000000..c85da045 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf.exceptions """ + +import logging +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import ( + ServiceInfo, + Zeroconf, +) + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class Exceptions(unittest.TestCase): + + browser = None # type: Zeroconf + + @classmethod + def setUpClass(cls): + cls.browser = Zeroconf(interfaces=['127.0.0.1']) + + @classmethod + def tearDownClass(cls): + cls.browser.close() + del cls.browser + + def test_bad_service_info_name(self): + self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not") + + def test_bad_service_names(self): + bad_names_to_try = ( + '', + 'local', + '_tcp.local.', + '_udp.local.', + '._udp.local.', + '_@._tcp.local.', + '_A@._tcp.local.', + '_x--x._tcp.local.', + '_-x._udp.local.', + '_x-._tcp.local.', + '_22._udp.local.', + '_2-2._tcp.local.', + '_1234567890-abcde._udp.local.', + '\x00._x._udp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name) + + def test_bad_local_names_for_get_service_info(self): + bad_names_to_try = ( + 'homekitdev._nothttp._tcp.local.', + 'homekitdev._http._udp.local.', + ) + for name in bad_names_to_try: + self.assertRaises( + r.BadTypeInNameException, self.browser.get_service_info, '_http._tcp.local.', name + ) + + def test_good_instance_names(self): + assert r.service_type_name('.._x._tcp.local.') == '_x._tcp.local.' + assert r.service_type_name('x.sub._http._tcp.local.') == '_http._tcp.local.' + assert ( + r.service_type_name('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') + == '_http._tcp.local.' + ) + + def test_good_instance_names_without_protocol(self): + good_names_to_try = ( + "Rachio-C73233.local.", + 'YeelightColorBulb-3AFD.local.', + 'YeelightTunableBulb-7220.local.', + "AlexanderHomeAssistant 74651D.local.", + 'iSmartGate-152.local.', + 'MyQ-FGA.local.', + 'lutron-02c4392a.local.', + 'WICED-hap-3E2734.local.', + 'MyHost.local.', + 'MyHost.sub.local.', + ) + for name in good_names_to_try: + assert r.service_type_name(name, strict=False) == 'local.' + + for name in good_names_to_try: + # Raises without strict=False + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_bad_types(self): + bad_names_to_try = ( + '._x._tcp.local.', + 'a' * 64 + '._sub._http._tcp.local.', + 'a' * 62 + u'â._sub._http._tcp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_bad_sub_types(self): + bad_names_to_try = ( + '_sub._http._tcp.local.', + '._sub._http._tcp.local.', + '\x7f._sub._http._tcp.local.', + '\x1f._sub._http._tcp.local.', + ) + for name in bad_names_to_try: + self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) + + def test_good_service_names(self): + good_names_to_try = ( + ('_x._tcp.local.', '_x._tcp.local.'), + ('_x._udp.local.', '_x._udp.local.'), + ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'), + ('x._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 61 + u'â._sub._http._tcp.local.', '_http._tcp.local.'), + ) + + for name, result in good_names_to_try: + assert r.service_type_name(name) == result + + assert r.service_type_name('_one_two._tcp.local.', strict=False) == '_one_two._tcp.local.' + + def test_invalid_addresses(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + bad = ('127.0.0.1', '::1', 42) + for addr in bad: + self.assertRaisesRegex( + TypeError, + 'Addresses must be bytes', + ServiceInfo, + type_, + registration_name, + port=80, + addresses=[addr], + ) diff --git a/tests/test_init.py b/tests/test_init.py index 7ec3b658..f64443f0 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -247,132 +247,6 @@ def generate_host(zc, host_name, type_): zc.send(out) -class Exceptions(unittest.TestCase): - - browser = None # type: Zeroconf - - @classmethod - def setUpClass(cls): - cls.browser = Zeroconf(interfaces=['127.0.0.1']) - - @classmethod - def tearDownClass(cls): - cls.browser.close() - del cls.browser - - def test_bad_service_info_name(self): - self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, "type", "type_not") - - def test_bad_service_names(self): - bad_names_to_try = ( - '', - 'local', - '_tcp.local.', - '_udp.local.', - '._udp.local.', - '_@._tcp.local.', - '_A@._tcp.local.', - '_x--x._tcp.local.', - '_-x._udp.local.', - '_x-._tcp.local.', - '_22._udp.local.', - '_2-2._tcp.local.', - '_1234567890-abcde._udp.local.', - '\x00._x._udp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, self.browser.get_service_info, name, 'x.' + name) - - def test_bad_local_names_for_get_service_info(self): - bad_names_to_try = ( - 'homekitdev._nothttp._tcp.local.', - 'homekitdev._http._udp.local.', - ) - for name in bad_names_to_try: - self.assertRaises( - r.BadTypeInNameException, self.browser.get_service_info, '_http._tcp.local.', name - ) - - def test_good_instance_names(self): - assert r.service_type_name('.._x._tcp.local.') == '_x._tcp.local.' - assert r.service_type_name('x.sub._http._tcp.local.') == '_http._tcp.local.' - assert ( - r.service_type_name('6d86f882b90facee9170ad3439d72a4d6ee9f511._zget._http._tcp.local.') - == '_http._tcp.local.' - ) - - def test_good_instance_names_without_protocol(self): - good_names_to_try = ( - "Rachio-C73233.local.", - 'YeelightColorBulb-3AFD.local.', - 'YeelightTunableBulb-7220.local.', - "AlexanderHomeAssistant 74651D.local.", - 'iSmartGate-152.local.', - 'MyQ-FGA.local.', - 'lutron-02c4392a.local.', - 'WICED-hap-3E2734.local.', - 'MyHost.local.', - 'MyHost.sub.local.', - ) - for name in good_names_to_try: - assert r.service_type_name(name, strict=False) == 'local.' - - for name in good_names_to_try: - # Raises without strict=False - self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) - - def test_bad_types(self): - bad_names_to_try = ( - '._x._tcp.local.', - 'a' * 64 + '._sub._http._tcp.local.', - 'a' * 62 + u'â._sub._http._tcp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) - - def test_bad_sub_types(self): - bad_names_to_try = ( - '_sub._http._tcp.local.', - '._sub._http._tcp.local.', - '\x7f._sub._http._tcp.local.', - '\x1f._sub._http._tcp.local.', - ) - for name in bad_names_to_try: - self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) - - def test_good_service_names(self): - good_names_to_try = ( - ('_x._tcp.local.', '_x._tcp.local.'), - ('_x._udp.local.', '_x._udp.local.'), - ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'), - ('x._sub._http._tcp.local.', '_http._tcp.local.'), - ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'), - ('a' * 61 + u'â._sub._http._tcp.local.', '_http._tcp.local.'), - ) - - for name, result in good_names_to_try: - assert r.service_type_name(name) == result - - assert r.service_type_name('_one_two._tcp.local.', strict=False) == '_one_two._tcp.local.' - - def test_invalid_addresses(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - bad = ('127.0.0.1', '::1', 42) - for addr in bad: - self.assertRaisesRegex( - TypeError, - 'Addresses must be bytes', - ServiceInfo, - type_, - registration_name, - port=80, - addresses=[addr], - ) - - class TestDnsIncoming(unittest.TestCase): def test_incoming_exception_handling(self): generated = r.DNSOutgoing(0) From ae1ce092de7eb4797da0f56e9eb8e538c95a8cc1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:43:22 -1000 Subject: [PATCH 252/608] Move additional dns tests to test_dns (#561) --- tests/test_dns.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_init.py | 68 ---------------------------------------------- 2 files changed, 68 insertions(+), 68 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 8117339d..de0e4932 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -416,6 +416,74 @@ def test_numbers_questions(self): assert num_additionals == 0 +class TestDnsIncoming(unittest.TestCase): + def test_incoming_exception_handling(self): + generated = r.DNSOutgoing(0) + packet = generated.packet() + packet = packet[:8] + b'deadbeef' + packet[8:] + parsed = r.DNSIncoming(packet) + parsed = r.DNSIncoming(packet) + assert parsed.valid is False + + def test_incoming_unknown_type(self): + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + generated.add_additional_answer(answer) + packet = generated.packet() + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + assert parsed.is_query() != parsed.is_response() + + def test_incoming_ipv6(self): + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + packed = socket.inet_pton(socket.AF_INET6, addr) + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN | r._CLASS_UNIQUE, 1, packed) + generated.add_additional_answer(answer) + packet = generated.packet() + parsed = r.DNSIncoming(packet) + record = parsed.answers[0] + assert isinstance(record, r.DNSAddress) + assert record.address == packed + + +class TestDNSCache(unittest.TestCase): + def test_order(self): + record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN) + cached_record = cache.get(entry) + assert cached_record == record2 + + def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): + record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache + + def test_cache_empty_multiple_calls_does_not_throw(self): + record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + # Ensure multiple removes does not throw + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache + + def test_dns_compression_rollback_for_corruption(): """Verify rolling back does not lead to dns compression corruption.""" out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) diff --git a/tests/test_init.py b/tests/test_init.py index f64443f0..1ca4f0c1 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -247,37 +247,6 @@ def generate_host(zc, host_name, type_): zc.send(out) -class TestDnsIncoming(unittest.TestCase): - def test_incoming_exception_handling(self): - generated = r.DNSOutgoing(0) - packet = generated.packet() - packet = packet[:8] + b'deadbeef' + packet[8:] - parsed = r.DNSIncoming(packet) - parsed = r.DNSIncoming(packet) - assert parsed.valid is False - - def test_incoming_unknown_type(self): - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - generated.add_additional_answer(answer) - packet = generated.packet() - parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 - assert parsed.is_query() != parsed.is_response() - - def test_incoming_ipv6(self): - addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com - packed = socket.inet_pton(socket.AF_INET6, addr) - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN | r._CLASS_UNIQUE, 1, packed) - generated.add_additional_answer(answer) - packet = generated.packet() - parsed = r.DNSIncoming(packet) - record = parsed.answers[0] - assert isinstance(record, r.DNSAddress) - assert record.address == packed - - class TestRegistrar(unittest.TestCase): def test_ttl(self): @@ -484,43 +453,6 @@ def test_lookups(self): assert registry.get_types() == [type_] -class TestDNSCache(unittest.TestCase): - def test_order(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN) - cached_record = cache.get(entry) - assert cached_record == record2 - - def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - assert 'a' not in cache.cache - - def test_cache_empty_multiple_calls_does_not_throw(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - # Ensure multiple removes does not throw - cache.remove(record1) - cache.remove(record2) - assert 'a' not in cache.cache - - class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): From 7807fa0dfdab20d950c446f17b7233a8c65cbab1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 17:46:38 -1000 Subject: [PATCH 253/608] Update setup.py for utils and services (#562) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ac24ca7f..c1c0da34 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ author='Paul Scott-Murphy, William McBrine, Jakub Stasiak', url='https://github.com/jstasiak/python-zeroconf', package_data={"zeroconf": ["py.typed"]}, - packages=["zeroconf"], + packages=["zeroconf", "zeroconf.services", "zeroconf.utils"], platforms=['unix', 'linux', 'osx'], license='LGPL', zip_safe=False, From a8420cde192647486eba4da4e54df9d0fe65adba Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 23:28:23 -1000 Subject: [PATCH 254/608] Removed protected imports from zeroconf namespace (#567) - These protected items are not intended to be part of the public API --- tests/test_aio.py | 15 +-- tests/test_core.py | 61 ++++++++---- tests/test_dns.py | 211 +++++++++++++++++++++-------------------- tests/test_init.py | 130 ++++++++++++++----------- tests/test_services.py | 113 ++++++++++++---------- zeroconf/__init__.py | 54 ----------- 6 files changed, 291 insertions(+), 293 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 48a6ccc4..b1be151d 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -11,17 +11,12 @@ import pytest -from zeroconf import ( - BadTypeInNameException, - NonUniqueNameException, - ServiceInfo, - ServiceListener, - ServiceNameAlreadyRegistered, - Zeroconf, - _LISTENER_TIME, - current_time_millis, -) from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf +from zeroconf.core import Zeroconf +from zeroconf.const import _LISTENER_TIME +from zeroconf.exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered +from zeroconf.services import ServiceInfo, ServiceListener +from zeroconf.utils.time import current_time_millis @pytest.fixture(autouse=True) diff --git a/tests/test_core.py b/tests/test_core.py index 0d63a58c..0d2a2a06 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -15,6 +15,7 @@ import zeroconf as r from zeroconf import core +from zeroconf import const from . import has_working_ipv6, _inject_response @@ -39,8 +40,8 @@ def test_reaper(self): zeroconf = core.Zeroconf(interfaces=['127.0.0.1']) cache = zeroconf.cache original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - record_with_10s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 10, b'a') - record_with_1s_ttl = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') zeroconf.cache.add(record_with_10s_ttl) zeroconf.cache.add(record_with_1s_ttl) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) @@ -102,11 +103,18 @@ def test_launch_and_close_v6_only(self): def test_handle_response(self): def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) if service_state_change == r.ServiceStateChange.Updated: generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ), + 0, ) return r.DNSIncoming(generated.packet()) @@ -114,22 +122,32 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ttl = 0 generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 ) generated.add_answer_at_time( r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, ), 0, ) generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, ) generated.add_answer_at_time( r.DNSAddress( service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_aton(service_address), ), @@ -141,12 +159,12 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: """Mock an incoming message for the case where the packet is split.""" ttl = 120 - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time( r.DNSAddress( service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_aton(service_address), ), @@ -154,7 +172,14 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS ) generated.add_answer_at_time( r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, ), 0, ) @@ -171,10 +196,10 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS try: # service added _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) assert dns_text is not None assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~paulsm/' - all_dns_text = zeroconf.cache.get_all_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + all_dns_text = zeroconf.cache.get_all_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) assert [dns_text] == all_dns_text # https://tools.ietf.org/html/rfc6762#section-10.2 @@ -188,7 +213,7 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS # service updated. currently only text record can be updated service_text = b'path=/~humingchun/' _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) assert dns_text is not None assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' @@ -198,13 +223,13 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS # This should not evict TXT records from the cache _inject_response(zeroconf, mock_split_incoming_msg(r.ServiceStateChange.Updated)) time.sleep(1.1) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) assert dns_text is not None assert cast(r.DNSText, dns_text).text == service_text # service_text is b'path=/~humingchun/' # service removed _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) - dns_text = zeroconf.cache.get_by_details(service_name, r._TYPE_TXT, r._CLASS_IN) + dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) assert dns_text is None finally: diff --git a/tests/test_dns.py b/tests/test_dns.py index de0e4932..db41693d 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -14,6 +14,7 @@ from typing import Dict, cast # noqa # used in type hints import zeroconf as r +from zeroconf import const from zeroconf import ( DNSHinfo, DNSText, @@ -46,46 +47,48 @@ def test_dns_text_repr(self): repr(text) def test_dns_hinfo_repr_eq(self): - hinfo = DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os') + hinfo = DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os') assert hinfo == hinfo repr(hinfo) def test_dns_pointer_repr(self): - pointer = r.DNSPointer('irrelevant', r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, '123') + pointer = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') repr(pointer) def test_dns_address_repr(self): - address = r.DNSAddress('irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + address = r.DNSAddress('irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, b'a') assert repr(address).endswith("b'a'") address_ipv4 = r.DNSAddress( - 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') + 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET, '127.0.0.1') ) assert repr(address_ipv4).endswith('127.0.0.1') address_ipv6 = r.DNSAddress( - 'irrelevant', r._TYPE_SOA, r._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') + 'irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, socket.inet_pton(socket.AF_INET6, '::1') ) assert repr(address_ipv6).endswith('::1') def test_dns_question_repr(self): - question = r.DNSQuestion('irrelevant', r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE) + question = r.DNSQuestion('irrelevant', const._TYPE_SRV, const._CLASS_IN | const._CLASS_UNIQUE) repr(question) assert not question != question def test_dns_service_repr(self): - service = r.DNSService('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL, 0, 0, 80, 'a') + service = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a' + ) repr(service) def test_dns_record_abc(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) self.assertRaises(r.AbstractMethodException, record.__eq__, record) self.assertRaises(r.AbstractMethodException, record.write, None) def test_dns_record_reset_ttl(self): - record = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) time.sleep(1) - record2 = r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + record2 = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) now = r.current_time_millis() assert record.created != record2.created @@ -131,7 +134,7 @@ def test_service_info_text_properties_not_given(self): repr(info) def test_dns_outgoing_repr(self): - dns_outgoing = r.DNSOutgoing(r._FLAGS_QR_QUERY) + dns_outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY) repr(dns_outgoing) @@ -145,22 +148,22 @@ def test_parse_own_packet_simple_unicast(self): r.DNSIncoming(generated.packet()) def test_parse_own_packet_flags(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) r.DNSIncoming(generated.packet()) def test_parse_own_packet_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - generated.add_question(r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN)) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) r.DNSIncoming(generated.packet()) def test_parse_own_packet_response(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time( r.DNSService( "æøå.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, 80, @@ -173,8 +176,8 @@ def test_parse_own_packet_response(self): assert len(generated.answers) == len(parsed.answers) def test_match_question(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) parsed = r.DNSIncoming(generated.packet()) assert len(generated.questions) == 1 @@ -182,14 +185,14 @@ def test_match_question(self): assert question == parsed.questions[0] def test_suppress_answer(self): - query_generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + query_generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) query_generated.add_question(question) answer1 = r.DNSService( "testname1.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, 80, @@ -197,9 +200,9 @@ def test_suppress_answer(self): ) staleanswer2 = r.DNSService( "testname2.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL / 2, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL / 2, 0, 0, 80, @@ -207,9 +210,9 @@ def test_suppress_answer(self): ) answer2 = r.DNSService( "testname2.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, 80, @@ -220,7 +223,7 @@ def test_suppress_answer(self): query = r.DNSIncoming(query_generated.packet()) # Should be suppressed - response = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) response.add_answer(query, answer1) assert len(response.answers) == 0 @@ -237,13 +240,13 @@ def test_suppress_answer(self): # Should not be suppressed, type is different tmp = copy.copy(answer1) - tmp.type = r._TYPE_A + tmp.type = const._TYPE_A response.add_answer(query, tmp) assert len(response.answers) == 3 # Should not be suppressed, class is different tmp = copy.copy(answer1) - tmp.class_ = r._CLASS_NONE + tmp.class_ = const._CLASS_NONE response.add_answer(query, tmp) assert len(response.answers) == 4 @@ -251,30 +254,30 @@ def test_suppress_answer(self): def test_dns_hinfo(self): generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'os')) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) parsed = r.DNSIncoming(generated.packet()) answer = cast(r.DNSHinfo, parsed.answers[0]) assert answer.cpu == u'cpu' assert answer.os == u'os' generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', r._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) self.assertRaises(r.NamePartTooLongException, generated.packet) def test_many_questions(self): """Test many questions get seperated into multiple packets.""" - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) questions = [] for i in range(100): - question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) questions.append(question) assert len(generated.questions) == 100 packets = generated.packets() assert len(packets) == 2 - assert len(packets[0]) < r._MAX_MSG_TYPICAL - assert len(packets[1]) < r._MAX_MSG_TYPICAL + assert len(packets[0]) < const._MAX_MSG_TYPICAL + assert len(packets[1]) < const._MAX_MSG_TYPICAL parsed1 = r.DNSIncoming(packets[0]) assert len(parsed1.questions) == 85 @@ -286,15 +289,15 @@ def test_only_one_answer_can_by_large(self): https://datatracker.ietf.org/doc/html/rfc6762#section-17 """ - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - query = r.DNSIncoming(r.DNSOutgoing(r._FLAGS_QR_QUERY).packet()) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packet()) for i in range(3): generated.add_answer( query, r.DNSText( "zoom._hap._tcp.local.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, 1200, b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, ), @@ -303,9 +306,9 @@ def test_only_one_answer_can_by_large(self): query, r.DNSService( "testname1.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, 80, @@ -316,16 +319,16 @@ def test_only_one_answer_can_by_large(self): packets = generated.packets() assert len(packets) == 4 - assert len(packets[0]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[0]) > r._MAX_MSG_TYPICAL + assert len(packets[0]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[0]) > const._MAX_MSG_TYPICAL - assert len(packets[1]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[1]) > r._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[1]) > const._MAX_MSG_TYPICAL - assert len(packets[2]) <= r._MAX_MSG_ABSOLUTE - assert len(packets[2]) > r._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[2]) > const._MAX_MSG_TYPICAL - assert len(packets[3]) <= r._MAX_MSG_TYPICAL + assert len(packets[3]) <= const._MAX_MSG_TYPICAL for packet in packets: parsed = r.DNSIncoming(packet) @@ -341,15 +344,15 @@ def test_questions_do_not_end_up_every_packet(self): questions and as many more Known-Answer records as will fit. """ - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) for i in range(35): - question = r.DNSQuestion(f"testname{i}.local.", r._TYPE_SRV, r._CLASS_IN) + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) answer = r.DNSService( f"testname{i}.local.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, 80, @@ -362,8 +365,8 @@ def test_questions_do_not_end_up_every_packet(self): packets = generated.packets() assert len(packets) == 2 - assert len(packets[0]) <= r._MAX_MSG_TYPICAL - assert len(packets[1]) <= r._MAX_MSG_TYPICAL + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL parsed1 = r.DNSIncoming(packets[0]) assert len(parsed1.questions) == 35 @@ -377,25 +380,25 @@ def test_questions_do_not_end_up_every_packet(self): class PacketForm(unittest.TestCase): def test_transaction_id(self): """ID must be zero in a DNS-SD packet""" - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) bytes = generated.packet() id = bytes[0] << 8 | bytes[1] assert id == 0 def test_query_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_QUERY) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) bytes = generated.packet() flags = bytes[2] << 8 | bytes[3] assert flags == 0x0 def test_response_header_bits(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) bytes = generated.packet() flags = bytes[2] << 8 | bytes[3] assert flags == 0x8000 def test_numbers(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) bytes = generated.packet() (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) assert num_questions == 0 @@ -404,8 +407,8 @@ def test_numbers(self): assert num_additionals == 0 def test_numbers_questions(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion("testname.local.", r._TYPE_SRV, r._CLASS_IN) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) for i in range(10): generated.add_question(question) bytes = generated.packet() @@ -427,7 +430,7 @@ def test_incoming_exception_handling(self): def test_incoming_unknown_type(self): generated = r.DNSOutgoing(0) - answer = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') + answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') generated.add_additional_answer(answer) packet = generated.packet() parsed = r.DNSIncoming(packet) @@ -438,7 +441,7 @@ def test_incoming_ipv6(self): addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com packed = socket.inet_pton(socket.AF_INET6, addr) generated = r.DNSOutgoing(0) - answer = r.DNSAddress('domain', r._TYPE_AAAA, r._CLASS_IN | r._CLASS_UNIQUE, 1, packed) + answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) generated.add_additional_answer(answer) packet = generated.packet() parsed = r.DNSIncoming(packet) @@ -449,18 +452,18 @@ def test_incoming_ipv6(self): class TestDNSCache(unittest.TestCase): def test_order(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.add(record1) cache.add(record2) - entry = r.DNSEntry('a', r._TYPE_SOA, r._CLASS_IN) + entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) cached_record = cache.get(entry) assert cached_record == record2 def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.add(record1) cache.add(record2) @@ -470,8 +473,8 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): assert 'a' not in cache.cache def test_cache_empty_multiple_calls_does_not_throw(self): - record1 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', r._TYPE_SOA, r._CLASS_IN, 1, b'b') + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.add(record1) cache.add(record2) @@ -486,7 +489,7 @@ def test_cache_empty_multiple_calls_does_not_throw(self): def test_dns_compression_rollback_for_corruption(): """Verify rolling back does not lead to dns compression corruption.""" - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) address = socket.inet_pton(socket.AF_INET, "192.168.208.5") additionals = [ @@ -589,9 +592,9 @@ def test_dns_compression_rollback_for_corruption(): out.add_answer_at_time( DNSText( "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', ), @@ -602,9 +605,9 @@ def test_dns_compression_rollback_for_corruption(): out.add_additional_answer( r.DNSService( record["name"], # type: ignore - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, 0, 0, record["port"], # type: ignore @@ -614,18 +617,18 @@ def test_dns_compression_rollback_for_corruption(): out.add_additional_answer( r.DNSText( record["name"], # type: ignore - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, record["text"], # type: ignore ) ) out.add_additional_answer( r.DNSAddress( record["name"], # type: ignore - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_HOST_TTL, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, record["address"], # type: ignore ) ) @@ -639,17 +642,17 @@ def test_dns_compression_rollback_for_corruption(): def test_tc_bit_in_query_packet(): """Verify the TC bit is set when known answers exceed the packet size.""" - out = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) type_ = "_hap._tcp.local." - out.add_question(r.DNSQuestion(type_, r._TYPE_PTR, r._CLASS_IN)) + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) for i in range(30): out.add_answer_at_time( DNSText( ("HASS Bridge W9DN %s._hap._tcp.local." % i), - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', ), @@ -660,28 +663,28 @@ def test_tc_bit_in_query_packet(): assert len(packets) == 3 first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert first_packet.flags & const._FLAGS_TC == const._FLAGS_TC assert first_packet.valid is True second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & r._FLAGS_TC == r._FLAGS_TC + assert second_packet.flags & const._FLAGS_TC == const._FLAGS_TC assert second_packet.valid is True third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.flags & const._FLAGS_TC == 0 assert third_packet.valid is True def test_tc_bit_not_set_in_answer_packet(): """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) for i in range(30): out.add_answer_at_time( DNSText( ("HASS Bridge W9DN %s._hap._tcp.local." % i), - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, - r._DNS_OTHER_TTL, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', ), @@ -692,13 +695,13 @@ def test_tc_bit_not_set_in_answer_packet(): assert len(packets) == 3 first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & r._FLAGS_TC == 0 + assert first_packet.flags & const._FLAGS_TC == 0 assert first_packet.valid is True second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & r._FLAGS_TC == 0 + assert second_packet.flags & const._FLAGS_TC == 0 assert second_packet.valid is True third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & r._FLAGS_TC == 0 + assert third_packet.flags & const._FLAGS_TC == 0 assert third_packet.valid is True diff --git a/tests/test_init.py b/tests/test_init.py index 1ca4f0c1..f89f786e 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -17,12 +17,7 @@ import pytest import zeroconf as r -from zeroconf import ( - ServiceBrowser, - ServiceInfo, - Zeroconf, - ZeroconfServiceTypes, -) +from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, ZeroconfServiceTypes, const from . import has_working_ipv6, _clear_cache, _inject_response @@ -43,38 +38,38 @@ def teardown_module(): class Names(unittest.TestCase): def test_long_name(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) question = r.DNSQuestion( - "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", r._TYPE_SRV, r._CLASS_IN + "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", const._TYPE_SRV, const._CLASS_IN ) generated.add_question(question) r.DNSIncoming(generated.packet()) def test_exceedingly_long_name(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) name = "%slocal." % ("part." * 1000) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) r.DNSIncoming(generated.packet()) def test_extra_exceedingly_long_name(self): - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) name = "%slocal." % ("part." * 4000) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) r.DNSIncoming(generated.packet()) def test_exceedingly_long_name_part(self): name = "%s.local." % ("a" * 1000) - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) self.assertRaises(r.NamePartTooLongException, generated.packet) def test_same_name(self): name = "paired.local." - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) - question = r.DNSQuestion(name, r._TYPE_SRV, r._CLASS_IN) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) generated.add_question(question) r.DNSIncoming(generated.packet()) @@ -99,7 +94,7 @@ def test_lots_of_names(self): longest_packet_len = 0 longest_packet = None # type: Optional[r.DNSOutgoing] - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" for packet in out.packets(): nonlocal longest_packet_len, longest_packet @@ -123,7 +118,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # we will never get to this large of a packet given the application-layer # splitting of packets, but we still want to track the longest_packet_len # for the debug message below - while sleep_count < 100 and longest_packet_len < r._MAX_MSG_ABSOLUTE - 100: + while sleep_count < 100 and longest_packet_len < const._MAX_MSG_ABSOLUTE - 100: sleep_count += 1 time.sleep(0.1) @@ -135,8 +130,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) # now the browser has sent at least one request, verify the size - assert longest_packet_len <= r._MAX_MSG_TYPICAL - assert longest_packet_len >= r._MAX_MSG_TYPICAL - 100 + assert longest_packet_len <= const._MAX_MSG_TYPICAL + assert longest_packet_len >= const._MAX_MSG_TYPICAL - 100 # mock zeroconf's logger warning() and debug() from unittest.mock import patch @@ -167,8 +162,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # mock the zeroconf logger and check for the correct logging backoff call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count # force receive on oversized packet - s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT)) - s.sendto(packet, 0, (r._MDNS_ADDR, r._MDNS_PORT)) + s.sendto(packet, 0, (const._MDNS_ADDR, const._MDNS_PORT)) + s.sendto(packet, 0, (const._MDNS_ADDR, const._MDNS_PORT)) time.sleep(2.0) zeroconf.log.debug( 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts @@ -238,10 +233,21 @@ def generate_many_hosts(self, zc, type_, name, number_hosts): @staticmethod def generate_host(zc, host_name, type_): name = '.'.join((host_name, type_)) - out = r.DNSOutgoing(r._FLAGS_QR_RESPONSE | r._FLAGS_AA) - out.add_answer_at_time(r.DNSPointer(type_, r._TYPE_PTR, r._CLASS_IN, r._DNS_OTHER_TTL, name), 0) + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + out.add_answer_at_time( + r.DNSPointer(type_, const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, name), 0 + ) out.add_answer_at_time( - r.DNSService(type_, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, r._DNS_HOST_TTL, 0, 0, 80, name), + r.DNSService( + type_, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + name, + ), 0, ) zc.send(out) @@ -275,10 +281,10 @@ def test_ttl(self): def get_ttl(record_type): if expected_ttl is not None: return expected_ttl - elif record_type in [r._TYPE_A, r._TYPE_SRV]: - return r._DNS_HOST_TTL + elif record_type in [const._TYPE_A, const._TYPE_SRV]: + return const._DNS_HOST_TTL else: - return r._DNS_OTHER_TTL + return const._DNS_OTHER_TTL def _process_outgoing_packet(out): """Sends an outgoing packet.""" @@ -305,12 +311,12 @@ def _process_outgoing_packet(out): nbr_answers = nbr_additionals = nbr_authorities = 0 # query - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) assert query.is_query() is True - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -328,8 +334,8 @@ def _process_outgoing_packet(out): _process_outgoing_packet(zc.generate_service_query(info)) zc.registry.add(info) # register service with custom TTL - expected_ttl = r._DNS_HOST_TTL * 2 - assert expected_ttl != r._DNS_HOST_TTL + expected_ttl = const._DNS_HOST_TTL * 2 + assert expected_ttl != const._DNS_HOST_TTL for _ in range(3): _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 @@ -337,11 +343,11 @@ def _process_outgoing_packet(out): # query expected_ttl = None - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_SRV, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, r._TYPE_TXT, r._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, r._TYPE_A, r._CLASS_IN)) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -405,8 +411,8 @@ def test_register_and_lookup_type_by_uppercase_name(self): info.load_from_cache(zc) assert info.addresses == [] - out = r.DNSOutgoing(r._FLAGS_QR_QUERY) - out.add_question(r.DNSQuestion(type_.upper(), r._TYPE_PTR, r._CLASS_IN)) + out = r.DNSOutgoing(const._FLAGS_QR_QUERY) + out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN)) zc.send(out) time.sleep(0.5) info = ServiceInfo(type_, registration_name) @@ -786,7 +792,7 @@ def update_service(self, zc, type_, name) -> None: def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) assert generated.is_response() is True if service_state_change == r.ServiceStateChange.Removed: @@ -795,12 +801,22 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ttl = 120 generated.add_answer_at_time( - r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text), 0 + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, ) generated.add_answer_at_time( r.DNSService( - service_name, r._TYPE_SRV, r._CLASS_IN | r._CLASS_UNIQUE, ttl, 0, 0, 80, service_server + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, ), 0, ) @@ -812,8 +828,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi generated.add_answer_at_time( r.DNSAddress( service_server, - r._TYPE_AAAA, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_pton(socket.AF_INET6, service_v6_address), ), @@ -822,8 +838,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi generated.add_answer_at_time( r.DNSAddress( service_server, - r._TYPE_AAAA, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_pton(socket.AF_INET6, service_v6_second_address), ), @@ -832,8 +848,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi generated.add_answer_at_time( r.DNSAddress( service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_aton(service_address), ), @@ -841,7 +857,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ) generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 ) return r.DNSIncoming(generated.packet()) @@ -1003,19 +1019,19 @@ def test_ptr_optimization(): nbr_answers = nbr_additionals = nbr_authorities = 0 # query - query = r.DNSOutgoing(r._FLAGS_QR_QUERY | r._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, r._TYPE_PTR, r._CLASS_IN)) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) out = zc.query_handler.response(r.DNSIncoming(query.packet()), False) assert out is not None nbr_answers += len(out.answers) nbr_authorities += len(out.authorities) for answer in out.additionals: nbr_additionals += 1 - if answer.type == r._TYPE_SRV: + if answer.type == const._TYPE_SRV: has_srv = True - elif answer.type == r._TYPE_TXT: + elif answer.type == const._TYPE_TXT: has_txt = True - elif answer.type == r._TYPE_A: + elif answer.type == const._TYPE_A: has_a = True assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 assert has_srv and has_txt and has_a diff --git a/tests/test_services.py b/tests/test_services.py index 7cf476b4..07554eb6 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -14,6 +14,7 @@ import pytest import zeroconf as r +from zeroconf import const import zeroconf.services as s from zeroconf.core import Zeroconf from zeroconf.services import ( @@ -75,8 +76,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSText( service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', ), @@ -87,8 +88,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSService( service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, 0, 0, @@ -104,8 +105,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSAddress( 'ASH-2.local.', - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, new_address, ), @@ -117,8 +118,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSText( "incorrect.name.", - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', ), @@ -129,8 +130,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSService( "incorrect.name.", - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, 0, 0, @@ -146,8 +147,8 @@ def test_service_info_rejects_non_matching_updates(self): now, r.DNSAddress( "incorrect.name.", - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, new_address, ), @@ -174,8 +175,8 @@ def test_service_info_rejects_expired_records(self): now, r.DNSText( service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', ), @@ -184,8 +185,8 @@ def test_service_info_rejects_expired_records(self): # Expired record expired_record = r.DNSText( service_name, - r._TYPE_TXT, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', ) @@ -211,7 +212,7 @@ def test_get_info_partial(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" nonlocal last_sent @@ -223,7 +224,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) for record in records: generated.add_answer_at_time(record, 0) @@ -247,10 +248,10 @@ def get_service_info_helper(zc, type, name): send_event.wait(wait_time) assert last_sent is not None assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions assert service_info is None # Expext query for SRV, A, AAAA @@ -259,15 +260,23 @@ def get_service_info_helper(zc, type, name): _inject_response( zc, mock_incoming_msg( - [r.DNSText(service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text)] + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ) + ] ), ) send_event.wait(wait_time) assert last_sent is not None assert len(last_sent.questions) == 3 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions assert service_info is None # Expext query for A, AAAA @@ -279,8 +288,8 @@ def get_service_info_helper(zc, type, name): [ r.DNSService( service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, 0, 0, @@ -293,8 +302,8 @@ def get_service_info_helper(zc, type, name): send_event.wait(wait_time) assert last_sent is not None assert len(last_sent.questions) == 2 - assert r.DNSQuestion(service_server, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_server, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions last_sent = None assert service_info is None @@ -307,8 +316,8 @@ def get_service_info_helper(zc, type, name): [ r.DNSAddress( service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_pton(socket.AF_INET, service_address), ) @@ -340,7 +349,7 @@ def test_get_info_single(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" nonlocal last_sent @@ -352,7 +361,7 @@ def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) for record in records: generated.add_answer_at_time(record, 0) @@ -376,10 +385,10 @@ def get_service_info_helper(zc, type, name): send_event.wait(wait_time) assert last_sent is not None assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, r._TYPE_SRV, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_TXT, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_A, r._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, r._TYPE_AAAA, r._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions assert service_info is None # Expext no further queries @@ -390,12 +399,16 @@ def get_service_info_helper(zc, type, name): mock_incoming_msg( [ r.DNSText( - service_name, r._TYPE_TXT, r._CLASS_IN | r._CLASS_UNIQUE, ttl, service_text + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, ), r.DNSService( service_name, - r._TYPE_SRV, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, 0, 0, @@ -404,8 +417,8 @@ def get_service_info_helper(zc, type, name): ), r.DNSAddress( service_server, - r._TYPE_A, - r._CLASS_IN | r._CLASS_UNIQUE, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_pton(socket.AF_INET, service_address), ), @@ -449,9 +462,9 @@ def remove_service(self, zc, type_, name) -> None: def mock_incoming_msg( service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int ) -> r.DNSIncoming: - generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time( - r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0 + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 ) return r.DNSIncoming(generated.packet()) @@ -476,7 +489,7 @@ def mock_incoming_msg( def _mock_get_expiration_time(self, percent): nonlocal called_with_refresh_time_check - if percent == r._EXPIRE_REFRESH_TIME_PERCENT: + if percent == const._EXPIRE_REFRESH_TIME_PERCENT: called_with_refresh_time_check = True return 0 return self.created + (percent * self.ttl * 10) @@ -541,7 +554,7 @@ def current_time_millis(): """Current system time in milliseconds""" return start_time + time_offset * 1000 - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" got_query.set() old_send(out, addr=addr, port=port) @@ -619,11 +632,11 @@ def current_time_millis(): """Current system time in milliseconds""" return time.time() * 1000 + time_offset * 1000 - expected_ttl = r._DNS_HOST_TTL + expected_ttl = const._DNS_HOST_TTL nbr_answers = 0 - def send(out, addr=r._MDNS_ADDR, port=r._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" pout = r.DNSIncoming(out.packet()) nonlocal nbr_answers @@ -693,7 +706,7 @@ def test_legacy_record_update_listener(): with pytest.raises(RuntimeError): r.RecordUpdateListener().update_record( - zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL) + zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) ) updates = [] diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 8242c4e1..043de013 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -22,57 +22,6 @@ import sys -from .const import ( # noqa # import needed for backwards compat - _BROWSER_BACKOFF_LIMIT, - _BROWSER_TIME, - _CACHE_CLEANUP_INTERVAL, - _CHECK_TIME, - _CLASSES, - _CLASS_IN, - _CLASS_NONE, - _CLASS_MASK, - _CLASS_UNIQUE, - _DNS_HOST_TTL, - _DNS_OTHER_TTL, - _DNS_PORT, - _EXPIRE_FULL_TIME_PERCENT, - _EXPIRE_REFRESH_TIME_PERCENT, - _EXPIRE_STALE_TIME_PERCENT, - _FLAGS_AA, - _FLAGS_QR_MASK, - _FLAGS_QR_QUERY, - _FLAGS_QR_RESPONSE, - _FLAGS_TC, - _HAS_ASCII_CONTROL_CHARS, - _HAS_A_TO_Z, - _HAS_ONLY_A_TO_Z_NUM_HYPHEN, - _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE, - _IPPROTO_IPV6, - _LISTENER_TIME, - _LOCAL_TRAILER, - _MAX_MSG_ABSOLUTE, - _MAX_MSG_TYPICAL, - _MDNS_ADDR, - _MDNS_ADDR6, - _MDNS_ADDR6_BYTES, - _MDNS_ADDR_BYTES, - _MDNS_PORT, - _NONTCP_PROTOCOL_LOCAL_TRAILER, - _REGISTER_TIME, - _SERVICE_TYPE_ENUMERATION_NAME, - _TCP_PROTOCOL_LOCAL_TRAILER, - _TYPES, - _TYPE_A, - _TYPE_AAAA, - _TYPE_ANY, - _TYPE_CNAME, - _TYPE_HINFO, - _TYPE_PTR, - _TYPE_SOA, - _TYPE_SRV, - _TYPE_TXT, - _UNREGISTER_TIME, -) from .core import NotifyListener, Zeroconf # noqa # import needed for backwards compat from .dns import ( # noqa # import needed for backwards compat DNSAddress, @@ -102,7 +51,6 @@ Signal, SignalRegistrationInterface, RecordUpdateListener, - _ServiceBrowserBase, ServiceBrowser, ServiceInfo, ServiceListener, @@ -120,8 +68,6 @@ InterfaceChoice, InterfacesType, IPVersion, - _is_v6_address, - _encode_address, get_all_addresses, ) from .utils.struct import int2byte # noqa # import needed for backwards compat From 0e0bc2a901ed1d64e357c63e9fb8655f3a6e9298 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 12 Jun 2021 23:49:35 -1000 Subject: [PATCH 255/608] Breakout DNSCache into zeroconf.cache (#568) --- zeroconf/__init__.py | 2 +- zeroconf/cache.py | 117 +++++++++++++++++++++++++++++++++++++++++++ zeroconf/core.py | 3 +- zeroconf/dns.py | 98 +++--------------------------------- 4 files changed, 126 insertions(+), 94 deletions(-) create mode 100644 zeroconf/cache.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 043de013..aae68e4e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -22,10 +22,10 @@ import sys +from .cache import DNSCache # noqa # import needed for backwards compat from .core import NotifyListener, Zeroconf # noqa # import needed for backwards compat from .dns import ( # noqa # import needed for backwards compat DNSAddress, - DNSCache, DNSEntry, DNSHinfo, DNSIncoming, diff --git a/zeroconf/cache.py b/zeroconf/cache.py new file mode 100644 index 00000000..48750f5a --- /dev/null +++ b/zeroconf/cache.py @@ -0,0 +1,117 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, Iterable, List, Optional, cast + +from .const import _TYPE_PTR +from .dns import DNSEntry, DNSPointer, DNSRecord, DNSService +from .utils.time import current_time_millis + + +class DNSCache: + """A cache of DNS entries.""" + + def __init__(self) -> None: + self.cache: Dict[str, List[DNSRecord]] = {} + self.service_cache: Dict[str, List[DNSRecord]] = {} + + def add(self, entry: DNSRecord) -> None: + """Adds an entry""" + # Insert last in list, get will return newest entry + # iteration will result in last update winning + self.cache.setdefault(entry.key, []).append(entry) + if isinstance(entry, DNSService): + self.service_cache.setdefault(entry.server, []).append(entry) + + def add_records(self, entries: Iterable[DNSRecord]) -> None: + """Add multiple records.""" + for entry in entries: + self.add(entry) + + def remove(self, entry: DNSRecord) -> None: + """Removes an entry.""" + if isinstance(entry, DNSService): + DNSCache.remove_key(self.service_cache, entry.server, entry) + DNSCache.remove_key(self.cache, entry.key, entry) + + def remove_records(self, entries: Iterable[DNSRecord]) -> None: + """Remove multiple records.""" + for entry in entries: + self.remove(entry) + + @staticmethod + def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: + """Forgiving remove of a cache key.""" + try: + cache[key].remove(entry) + if not cache[key]: + del cache[key] + except (KeyError, ValueError): + pass + + def get(self, entry: DNSEntry) -> Optional[DNSRecord]: + """Gets an entry by key. Will return None if there is no + matching entry.""" + for cached_entry in reversed(self.entries_with_name(entry.key)): + if entry.__eq__(cached_entry): + return cached_entry + return None + + def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: + """Gets the first matching entry by details. Returns None if no entries match.""" + return self.get(DNSEntry(name, type_, class_)) + + def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: + """Gets all matching entries by details.""" + match_entry = DNSEntry(name, type_, class_) + return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] + + def entries_with_server(self, server: str) -> List[DNSRecord]: + """Returns a list of entries whose server matches the name.""" + return self.service_cache.get(server, [])[:] + + def entries_with_name(self, name: str) -> List[DNSRecord]: + """Returns a list of entries whose key matches the name.""" + return self.cache.get(name.lower(), [])[:] + + def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: + now = current_time_millis() + for record in reversed(self.entries_with_name(name)): + if ( + record.type == _TYPE_PTR + and not record.is_expired(now) + and cast(DNSPointer, record).alias == alias + ): + return record + return None + + def names(self) -> List[str]: + """Return a copy of the list of current cache names.""" + return list(self.cache) + + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache.""" + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record diff --git a/zeroconf/core.py b/zeroconf/core.py index f432c411..23c3583e 100644 --- a/zeroconf/core.py +++ b/zeroconf/core.py @@ -28,6 +28,7 @@ from types import TracebackType # noqa # used in type hints from typing import Dict, List, Optional, Type, Union, cast +from .cache import DNSCache from .const import ( _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, @@ -44,7 +45,7 @@ _TYPE_PTR, _UNREGISTER_TIME, ) -from .dns import DNSCache, DNSIncoming, DNSOutgoing, DNSQuestion +from .dns import DNSIncoming, DNSOutgoing, DNSQuestion from .exceptions import NonUniqueNameException from .handlers import QueryHandler, RecordManager from .logger import QuietLogger, log diff --git a/zeroconf/dns.py b/zeroconf/dns.py index 60d3c919..c5139dac 100644 --- a/zeroconf/dns.py +++ b/zeroconf/dns.py @@ -23,7 +23,7 @@ import enum import socket import struct -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast from .const import ( _CLASSES, @@ -54,6 +54,11 @@ from .utils.time import current_time_millis, millis_to_seconds +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from .cache import DNSCache # pylint: disable=cyclic-import + + class DNSEntry: """A DNS entry""" @@ -937,94 +942,3 @@ def packets(self) -> List[bytes]: break self.state = self.State.finished return self.packets_data - - -class DNSCache: - - """A cache of DNS entries""" - - def __init__(self) -> None: - self.cache = {} # type: Dict[str, List[DNSRecord]] - self.service_cache = {} # type: Dict[str, List[DNSRecord]] - - def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert last in list, get will return newest entry - # iteration will result in last update winning - self.cache.setdefault(entry.key, []).append(entry) - if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server, []).append(entry) - - def add_records(self, entries: Iterable[DNSRecord]) -> None: - """Add multiple records.""" - for entry in entries: - self.add(entry) - - def remove(self, entry: DNSRecord) -> None: - """Removes an entry.""" - if isinstance(entry, DNSService): - DNSCache.remove_key(self.service_cache, entry.server, entry) - DNSCache.remove_key(self.cache, entry.key, entry) - - def remove_records(self, entries: Iterable[DNSRecord]) -> None: - """Remove multiple records.""" - for entry in entries: - self.remove(entry) - - @staticmethod - def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: - """Forgiving remove of a cache key.""" - try: - cache[key].remove(entry) - if not cache[key]: - del cache[key] - except (KeyError, ValueError): - pass - - def get(self, entry: DNSEntry) -> Optional[DNSRecord]: - """Gets an entry by key. Will return None if there is no - matching entry.""" - for cached_entry in reversed(self.entries_with_name(entry.key)): - if entry.__eq__(cached_entry): - return cached_entry - return None - - def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets the first matching entry by details. Returns None if no entries match.""" - return self.get(DNSEntry(name, type_, class_)) - - def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: - """Gets all matching entries by details.""" - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] - - def entries_with_server(self, server: str) -> List[DNSRecord]: - """Returns a list of entries whose server matches the name.""" - return self.service_cache.get(server, [])[:] - - def entries_with_name(self, name: str) -> List[DNSRecord]: - """Returns a list of entries whose key matches the name.""" - return self.cache.get(name.lower(), [])[:] - - def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: - now = current_time_millis() - for record in reversed(self.entries_with_name(name)): - if ( - record.type == _TYPE_PTR - and not record.is_expired(now) - and cast(DNSPointer, record).alias == alias - ): - return record - return None - - def names(self) -> List[str]: - """Return a copy of the list of current cache names.""" - return list(self.cache) - - def expire(self, now: float) -> Iterable[DNSRecord]: - """Purge expired entries from the cache.""" - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self.remove(record) - yield record From 1e7c07481bb0cd08fe492dab02be888c6a1dadf2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 00:07:35 -1000 Subject: [PATCH 256/608] Remove DNSOutgoing.packet backwards compatibility (#569) - DNSOutgoing.packet only returned a partial message when the DNSOutgoing contents exceeded _MAX_MSG_ABSOLUTE or _MAX_MSG_TYPICAL This was a legacy function that was replaced with .packets() which always returns a complete payload in #248 As packet() should not be used since it will end up missing data, it has been removed --- tests/test_core.py | 6 +++--- tests/test_dns.py | 36 ++++++++++++++++++------------------ tests/test_init.py | 20 ++++++++++---------- tests/test_services.py | 8 ++++---- zeroconf/dns.py | 15 --------------- 5 files changed, 35 insertions(+), 50 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 0d2a2a06..abdac3b9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -116,7 +116,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi ), 0, ) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) if service_state_change == r.ServiceStateChange.Removed: ttl = 0 @@ -154,7 +154,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi 0, ) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: """Mock an incoming message for the case where the packet is split.""" @@ -183,7 +183,7 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS ), 0, ) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) service_name = 'name._type._tcp.local.' service_type = '_type._tcp.local.' diff --git a/tests/test_dns.py b/tests/test_dns.py index db41693d..a3c50ee2 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -141,20 +141,20 @@ def test_dns_outgoing_repr(self): class PacketGeneration(unittest.TestCase): def test_parse_own_packet_simple(self): generated = r.DNSOutgoing(0) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_parse_own_packet_simple_unicast(self): generated = r.DNSOutgoing(0, False) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_parse_own_packet_flags(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_parse_own_packet_question(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_parse_own_packet_response(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) @@ -171,7 +171,7 @@ def test_parse_own_packet_response(self): ), 0, ) - parsed = r.DNSIncoming(generated.packet()) + parsed = r.DNSIncoming(generated.packets()[0]) assert len(generated.answers) == 1 assert len(generated.answers) == len(parsed.answers) @@ -179,7 +179,7 @@ def test_match_question(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) - parsed = r.DNSIncoming(generated.packet()) + parsed = r.DNSIncoming(generated.packets()[0]) assert len(generated.questions) == 1 assert len(generated.questions) == len(parsed.questions) assert question == parsed.questions[0] @@ -220,7 +220,7 @@ def test_suppress_answer(self): ) query_generated.add_answer_at_time(answer1, 0) query_generated.add_answer_at_time(staleanswer2, 0) - query = r.DNSIncoming(query_generated.packet()) + query = r.DNSIncoming(query_generated.packets()[0]) # Should be suppressed response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) @@ -255,14 +255,14 @@ def test_suppress_answer(self): def test_dns_hinfo(self): generated = r.DNSOutgoing(0) generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) - parsed = r.DNSIncoming(generated.packet()) + parsed = r.DNSIncoming(generated.packets()[0]) answer = cast(r.DNSHinfo, parsed.answers[0]) assert answer.cpu == u'cpu' assert answer.os == u'os' generated = r.DNSOutgoing(0) generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) - self.assertRaises(r.NamePartTooLongException, generated.packet) + self.assertRaises(r.NamePartTooLongException, generated.packets) def test_many_questions(self): """Test many questions get seperated into multiple packets.""" @@ -290,7 +290,7 @@ def test_only_one_answer_can_by_large(self): https://datatracker.ietf.org/doc/html/rfc6762#section-17 """ generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packet()) + query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0]) for i in range(3): generated.add_answer( query, @@ -381,25 +381,25 @@ class PacketForm(unittest.TestCase): def test_transaction_id(self): """ID must be zero in a DNS-SD packet""" generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - bytes = generated.packet() + bytes = generated.packets()[0] id = bytes[0] << 8 | bytes[1] assert id == 0 def test_query_header_bits(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - bytes = generated.packet() + bytes = generated.packets()[0] flags = bytes[2] << 8 | bytes[3] assert flags == 0x0 def test_response_header_bits(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - bytes = generated.packet() + bytes = generated.packets()[0] flags = bytes[2] << 8 | bytes[3] assert flags == 0x8000 def test_numbers(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - bytes = generated.packet() + bytes = generated.packets()[0] (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) assert num_questions == 0 assert num_answers == 0 @@ -411,7 +411,7 @@ def test_numbers_questions(self): question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) for i in range(10): generated.add_question(question) - bytes = generated.packet() + bytes = generated.packets()[0] (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) assert num_questions == 10 assert num_answers == 0 @@ -422,7 +422,7 @@ def test_numbers_questions(self): class TestDnsIncoming(unittest.TestCase): def test_incoming_exception_handling(self): generated = r.DNSOutgoing(0) - packet = generated.packet() + packet = generated.packets()[0] packet = packet[:8] + b'deadbeef' + packet[8:] parsed = r.DNSIncoming(packet) parsed = r.DNSIncoming(packet) @@ -432,7 +432,7 @@ def test_incoming_unknown_type(self): generated = r.DNSOutgoing(0) answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') generated.add_additional_answer(answer) - packet = generated.packet() + packet = generated.packets()[0] parsed = r.DNSIncoming(packet) assert len(parsed.answers) == 0 assert parsed.is_query() != parsed.is_response() @@ -443,7 +443,7 @@ def test_incoming_ipv6(self): generated = r.DNSOutgoing(0) answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) generated.add_additional_answer(answer) - packet = generated.packet() + packet = generated.packets()[0] parsed = r.DNSIncoming(packet) record = parsed.answers[0] assert isinstance(record, r.DNSAddress) diff --git a/tests/test_init.py b/tests/test_init.py index f89f786e..87105598 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -43,28 +43,28 @@ def test_long_name(self): "this.is.a.very.long.name.with.lots.of.parts.in.it.local.", const._TYPE_SRV, const._CLASS_IN ) generated.add_question(question) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_exceedingly_long_name(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) name = "%slocal." % ("part." * 1000) question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_extra_exceedingly_long_name(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) name = "%slocal." % ("part." * 4000) question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_exceedingly_long_name_part(self): name = "%s.local." % ("a" * 1000) generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) - self.assertRaises(r.NamePartTooLongException, generated.packet) + self.assertRaises(r.NamePartTooLongException, generated.packets) def test_same_name(self): name = "paired.local." @@ -72,7 +72,7 @@ def test_same_name(self): question = r.DNSQuestion(name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) generated.add_question(question) - r.DNSIncoming(generated.packet()) + r.DNSIncoming(generated.packets()[0]) def test_lots_of_names(self): @@ -156,7 +156,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): assert mocked_log_warn.call_count == call_counts[0] # force a receive of a packet - packet = out.packet() + packet = out.packets()[0] s = zc._respond_sockets[0] # mock the zeroconf logger and check for the correct logging backoff @@ -317,7 +317,7 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -348,7 +348,7 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packet()), False)) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -860,7 +860,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 ) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) @@ -1021,7 +1021,7 @@ def test_ptr_optimization(): # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - out = zc.query_handler.response(r.DNSIncoming(query.packet()), False) + out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False) assert out is not None nbr_answers += len(out.answers) nbr_authorities += len(out.authorities) diff --git a/tests/test_services.py b/tests/test_services.py index 07554eb6..d598d166 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -229,7 +229,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming: for record in records: generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) def get_service_info_helper(zc, type, name): nonlocal service_info @@ -366,7 +366,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming: for record in records: generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) def get_service_info_helper(zc, type, name): nonlocal service_info @@ -466,7 +466,7 @@ def mock_incoming_msg( generated.add_answer_at_time( r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 ) - return r.DNSIncoming(generated.packet()) + return r.DNSIncoming(generated.packets()[0]) zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) @@ -638,7 +638,7 @@ def current_time_millis(): def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" - pout = r.DNSIncoming(out.packet()) + pout = r.DNSIncoming(out.packets()[0]) nonlocal nbr_answers for answer in pout.answers: nbr_answers += 1 diff --git a/zeroconf/dns.py b/zeroconf/dns.py index c5139dac..430369f0 100644 --- a/zeroconf/dns.py +++ b/zeroconf/dns.py @@ -799,21 +799,6 @@ def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) del self.names[name] return False - def packet(self) -> bytes: - """Returns a bytestring containing the first packet's bytes. - - Generally, you want to use packets() in case the response - does not fit in a single packet, but this exists for - backward compatibility.""" - packets = self.packets() - if len(packets) == 0: - return b'' - if len(packets[0]) > _MAX_MSG_ABSOLUTE: - QuietLogger.log_warning_once( - "Created over-sized packet (%d bytes) %r", len(packets[0]), packets[0] - ) - return packets[0] - def _write_questions_from_offset(self, questions_offset: int) -> int: questions_written = 0 for question in self.questions[questions_offset:]: From ae552e94732568fd798e1f2d0e811849edff7790 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 00:19:21 -1000 Subject: [PATCH 257/608] Relocate services tests to test_services (#570) --- tests/test_init.py | 408 ---------------------------------------- tests/test_services.py | 411 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 410 insertions(+), 409 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index 87105598..eedb7fbb 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -588,414 +588,6 @@ def test_integration_with_subtype_and_listener(self): zeroconf_registrar.close() -class ListenerTest(unittest.TestCase): - def test_integration_with_listener_class(self): - - service_added = Event() - service_removed = Event() - service_updated = Event() - service_updated2 = Event() - - subtype_name = "My special Subtype" - type_ = "_http._tcp.local." - subtype = subtype_name + "._sub." + type_ - name = "UPPERxxxyyyæøå" - registration_name = "%s.%s" % (name, subtype) - - class MyListener(r.ServiceListener): - def add_service(self, zeroconf, type, name): - zeroconf.get_service_info(type, name) - service_added.set() - - def remove_service(self, zeroconf, type, name): - service_removed.set() - - def update_service(self, zeroconf, type, name): - service_updated2.set() - - class MySubListener(r.ServiceListener): - def add_service(self, zeroconf, type, name): - pass - - def remove_service(self, zeroconf, type, name): - pass - - def update_service(self, zeroconf, type, name): - service_updated.set() - - listener = MyListener() - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - zeroconf_browser.add_service_listener(subtype, listener) - - properties = dict( - prop_none=None, - prop_string=b'a_prop', - prop_float=1.0, - prop_blank=b'a blanked string', - prop_true=1, - prop_false=0, - ) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} # type: Dict - desc.update(properties) - addresses = [socket.inet_aton("10.0.1.2")] - if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): - addresses.append(socket.inet_pton(socket.AF_INET6, "6001:db8::1")) - addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) - info_service = ServiceInfo( - subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses - ) - zeroconf_registrar.register_service(info_service) - - try: - service_added.wait(1) - assert service_added.is_set() - - # short pause to allow multicast timers to expire - time.sleep(3) - - # clear the answer cache to force query - _clear_cache(zeroconf_browser) - - cached_info = ServiceInfo(type_, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties == {} - - # get service info without answer cache - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_none'] is None - assert info.properties[b'prop_string'] == properties['prop_string'] - assert info.properties[b'prop_float'] == b'1.0' - assert info.properties[b'prop_blank'] == properties['prop_blank'] - assert info.properties[b'prop_true'] == b'1' - assert info.properties[b'prop_false'] == b'0' - assert info.addresses == addresses[:1] # no V6 by default - assert info.addresses_by_version(r.IPVersion.All) == addresses - - cached_info = ServiceInfo(type_, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - - # Populate the cache - zeroconf_browser.get_service_info(subtype, registration_name) - - # get service info with only the cache - cached_info = ServiceInfo(subtype, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_float'] == b'1.0' - - # get service info with only the cache with the lowercase name - cached_info = ServiceInfo(subtype, registration_name.lower()) - cached_info.load_from_cache(zeroconf_browser) - # Ensure uppercase output is preserved - assert cached_info.name == registration_name - assert cached_info.key == registration_name.lower() - assert cached_info.properties is not None - assert cached_info.properties[b'prop_float'] == b'1.0' - - info = zeroconf_browser.get_service_info(subtype, registration_name) - assert info is not None - assert info.properties is not None - assert info.properties[b'prop_none'] is None - - cached_info = ServiceInfo(subtype, registration_name.lower()) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_none'] is None - - # test TXT record update - sublistener = MySubListener() - zeroconf_browser.add_service_listener(registration_name, sublistener) - properties['prop_blank'] = b'an updated string' - desc.update(properties) - info_service = ServiceInfo( - subtype, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - zeroconf_registrar.update_service(info_service) - service_updated.wait(1) - assert service_updated.is_set() - - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_blank'] == properties['prop_blank'] - - cached_info = ServiceInfo(subtype, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_blank'] == properties['prop_blank'] - - zeroconf_registrar.unregister_service(info_service) - service_removed.wait(1) - assert service_removed.is_set() - - finally: - zeroconf_registrar.close() - zeroconf_browser.remove_service_listener(listener) - zeroconf_browser.close() - - -class TestServiceBrowser(unittest.TestCase): - def test_update_record(self): - enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - service_v6_address = "2001:db8::1" - service_v6_second_address = "6001:db8::1" - - service_added_count = 0 - service_removed_count = 0 - service_updated_count = 0 - service_add_event = Event() - service_removed_event = Event() - service_updated_event = Event() - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal service_added_count - service_added_count += 1 - service_add_event.set() - - def remove_service(self, zc, type_, name) -> None: - nonlocal service_removed_count - service_removed_count += 1 - service_removed_event.set() - - def update_service(self, zc, type_, name) -> None: - nonlocal service_updated_count - service_updated_count += 1 - service_info = zc.get_service_info(type_, name) - assert socket.inet_aton(service_address) in service_info.addresses - if enable_ipv6: - assert socket.inet_pton( - socket.AF_INET6, service_v6_address - ) in service_info.addresses_by_version(r.IPVersion.V6Only) - assert socket.inet_pton( - socket.AF_INET6, service_v6_second_address - ) in service_info.addresses_by_version(r.IPVersion.V6Only) - assert service_info.text == service_text - assert service_info.server == service_server - service_updated_event.set() - - def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - assert generated.is_response() is True - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - else: - ttl = 120 - - generated.add_answer_at_time( - r.DNSText( - service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text - ), - 0, - ) - - generated.add_answer_at_time( - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ), - 0, - ) - - # Send the IPv6 address first since we previously - # had a bug where the IPv4 would be missing if the - # IPv6 was seen first - if enable_ipv6: - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_AAAA, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET6, service_v6_address), - ), - 0, - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_AAAA, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET6, service_v6_second_address), - ), - 0, - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - - generated.add_answer_at_time( - r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 - ) - - return r.DNSIncoming(generated.packets()[0]) - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) - - try: - wait_time = 3 - - # service added - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) - service_add_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 0 - assert service_removed_count == 0 - - # service SRV updated - service_updated_event.clear() - service_server = 'ash-2.local.' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 1 - assert service_removed_count == 0 - - # service TXT updated - service_updated_event.clear() - service_text = b'path=/~matt2/' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 2 - assert service_removed_count == 0 - - # service TXT updated - duplicate update should not trigger another service_updated - service_updated_event.clear() - service_text = b'path=/~matt2/' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 2 - assert service_removed_count == 0 - - # service A updated - service_updated_event.clear() - service_address = '10.0.1.3' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 3 - assert service_removed_count == 0 - - # service all updated - service_updated_event.clear() - service_server = 'ash-3.local.' - service_text = b'path=/~matt3/' - service_address = '10.0.1.3' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 4 - assert service_removed_count == 0 - - # service removed - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) - service_removed_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 4 - assert service_removed_count == 1 - - finally: - assert len(zeroconf.listeners) == 1 - service_browser.cancel() - assert len(zeroconf.listeners) == 0 - zeroconf.remove_all_service_listeners() - zeroconf.close() - - -def test_multiple_addresses(): - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - - # New kwarg way - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) - - assert info.addresses == [address, address] - - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - parsed_addresses=[address_parsed, address_parsed], - ) - assert info.addresses == [address, address] - - if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): - address_v6_parsed = "2001:db8::1" - address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) - infos = [ - ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[address, address_v6], - ), - ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - parsed_addresses=[address_parsed, address_v6_parsed], - ), - ] - for info in infos: - assert info.addresses == [address] - assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] - assert info.addresses_by_version(r.IPVersion.V4Only) == [address] - assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] - assert info.parsed_addresses() == [address_parsed, address_v6_parsed] - assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] - - def test_ptr_optimization(): # instantiate a zeroconf instance diff --git a/tests/test_services.py b/tests/test_services.py index d598d166..243662fe 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -8,6 +8,7 @@ import socket import threading import time +import os import unittest from threading import Event @@ -23,7 +24,7 @@ ServiceStateChange, ) -from . import _inject_response +from . import has_working_ipv6, _clear_cache, _inject_response log = logging.getLogger('zeroconf') @@ -537,6 +538,414 @@ def _mock_get_expiration_time(self, percent): zeroconf.close() +class ListenerTest(unittest.TestCase): + def test_integration_with_listener_class(self): + + service_added = Event() + service_removed = Event() + service_updated = Event() + service_updated2 = Event() + + subtype_name = "My special Subtype" + type_ = "_http._tcp.local." + subtype = subtype_name + "._sub." + type_ + name = "UPPERxxxyyyæøå" + registration_name = "%s.%s" % (name, subtype) + + class MyListener(r.ServiceListener): + def add_service(self, zeroconf, type, name): + zeroconf.get_service_info(type, name) + service_added.set() + + def remove_service(self, zeroconf, type, name): + service_removed.set() + + def update_service(self, zeroconf, type, name): + service_updated2.set() + + class MySubListener(r.ServiceListener): + def add_service(self, zeroconf, type, name): + pass + + def remove_service(self, zeroconf, type, name): + pass + + def update_service(self, zeroconf, type, name): + service_updated.set() + + listener = MyListener() + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + zeroconf_browser.add_service_listener(subtype, listener) + + properties = dict( + prop_none=None, + prop_string=b'a_prop', + prop_float=1.0, + prop_blank=b'a blanked string', + prop_true=1, + prop_false=0, + ) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} # type: Dict + desc.update(properties) + addresses = [socket.inet_aton("10.0.1.2")] + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): + addresses.append(socket.inet_pton(socket.AF_INET6, "6001:db8::1")) + addresses.append(socket.inet_pton(socket.AF_INET6, "2001:db8::1")) + info_service = ServiceInfo( + subtype, registration_name, port=80, properties=desc, server="ash-2.local.", addresses=addresses + ) + zeroconf_registrar.register_service(info_service) + + try: + service_added.wait(1) + assert service_added.is_set() + + # short pause to allow multicast timers to expire + time.sleep(3) + + # clear the answer cache to force query + _clear_cache(zeroconf_browser) + + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties == {} + + # get service info without answer cache + info = zeroconf_browser.get_service_info(type_, registration_name) + assert info is not None + assert info.properties[b'prop_none'] is None + assert info.properties[b'prop_string'] == properties['prop_string'] + assert info.properties[b'prop_float'] == b'1.0' + assert info.properties[b'prop_blank'] == properties['prop_blank'] + assert info.properties[b'prop_true'] == b'1' + assert info.properties[b'prop_false'] == b'0' + assert info.addresses == addresses[:1] # no V6 by default + assert info.addresses_by_version(r.IPVersion.All) == addresses + + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + + # Populate the cache + zeroconf_browser.get_service_info(subtype, registration_name) + + # get service info with only the cache + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + + # get service info with only the cache with the lowercase name + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + # Ensure uppercase output is preserved + assert cached_info.name == registration_name + assert cached_info.key == registration_name.lower() + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + + info = zeroconf_browser.get_service_info(subtype, registration_name) + assert info is not None + assert info.properties is not None + assert info.properties[b'prop_none'] is None + + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_none'] is None + + # test TXT record update + sublistener = MySubListener() + zeroconf_browser.add_service_listener(registration_name, sublistener) + properties['prop_blank'] = b'an updated string' + desc.update(properties) + info_service = ServiceInfo( + subtype, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.update_service(info_service) + service_updated.wait(1) + assert service_updated.is_set() + + info = zeroconf_browser.get_service_info(type_, registration_name) + assert info is not None + assert info.properties[b'prop_blank'] == properties['prop_blank'] + + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_blank'] == properties['prop_blank'] + + zeroconf_registrar.unregister_service(info_service) + service_removed.wait(1) + assert service_removed.is_set() + + finally: + zeroconf_registrar.close() + zeroconf_browser.remove_service_listener(listener) + zeroconf_browser.close() + + +class TestServiceBrowser(unittest.TestCase): + def test_update_record(self): + enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + service_v6_address = "2001:db8::1" + service_v6_second_address = "6001:db8::1" + + service_added_count = 0 + service_removed_count = 0 + service_updated_count = 0 + service_add_event = Event() + service_removed_event = Event() + service_updated_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + service_removed_event.set() + + def update_service(self, zc, type_, name) -> None: + nonlocal service_updated_count + service_updated_count += 1 + service_info = zc.get_service_info(type_, name) + assert socket.inet_aton(service_address) in service_info.addresses + if enable_ipv6: + assert socket.inet_pton( + socket.AF_INET6, service_v6_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert socket.inet_pton( + socket.AF_INET6, service_v6_second_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert service_info.text == service_text + assert service_info.server == service_server + service_updated_event.set() + + def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + assert generated.is_response() is True + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + else: + ttl = 120 + + generated.add_answer_at_time( + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + 0, + ) + + # Send the IPv6 address first since we previously + # had a bug where the IPv4 would be missing if the + # IPv6 was seen first + if enable_ipv6: + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_second_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + + return r.DNSIncoming(generated.packets()[0]) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) + + try: + wait_time = 3 + + # service added + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) + service_add_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 0 + assert service_removed_count == 0 + + # service SRV updated + service_updated_event.clear() + service_server = 'ash-2.local.' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 1 + assert service_removed_count == 0 + + # service TXT updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service TXT updated - duplicate update should not trigger another service_updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service A updated + service_updated_event.clear() + service_address = '10.0.1.3' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 3 + assert service_removed_count == 0 + + # service all updated + service_updated_event.clear() + service_server = 'ash-3.local.' + service_text = b'path=/~matt3/' + service_address = '10.0.1.3' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 0 + + # service removed + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) + service_removed_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 1 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + +def test_multiple_addresses(): + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) + + assert info.addresses == [address, address] + + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_parsed], + ) + assert info.addresses == [address, address] + + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): + address_v6_parsed = "2001:db8::1" + address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) + infos = [ + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[address, address_v6], + ), + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_v6_parsed], + ), + ] + for info in infos: + assert info.addresses == [address] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.V4Only) == [address] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed] + assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] + + def test_backoff(): got_query = Event() From f10a562471ad89527e6eef6ba935a27177bb1417 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 00:46:35 -1000 Subject: [PATCH 258/608] Update changelog (#573) --- README.rst | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/README.rst b/README.rst index 228f0fac..dc944425 100644 --- a/README.rst +++ b/README.rst @@ -189,6 +189,65 @@ Changelog The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. +* Breaking change: Remove DNSOutgoing.packet backwards compatibility (#569) @bdraco + + DNSOutgoing.packet only returned a partial message when the + DNSOutgoing contents exceeded _MAX_MSG_ABSOLUTE or _MAX_MSG_TYPICAL + This was a legacy function that was replaced with .packets() + which always returns a complete payload in #248 As packet() + should not be used since it will end up missing data, it has + been removed + +* Breakout DNSCache into zeroconf.cache (#568) @bdraco + +* Removed protected imports from zeroconf namespace (#567) @bdraco + +* Fix invalid typing in ServiceInfo._set_text (#554) @bdraco + +* Move QueryHandler and RecordManager handlers into zeroconf.handlers (#551) @bdraco + +* Move ServiceListener to zeroconf.services (#550) @bdraco + +* Move the ServiceRegistry into its own module (#549) @bdraco + +* Move ServiceStateChange to zeroconf.services (#548) @bdraco + +* Relocate core functions into zeroconf.core (#547) @bdraco + +* Breakout service classes into zeroconf.services (#544) @bdraco + +* Move service_type_name to zeroconf.utils.name (#543) @bdraco + +* Relocate DNS classes to zeroconf.dns (#541) @bdraco + +* Update zeroconf.aio import locations (#539) @bdraco + +* Move int2byte to zeroconf.utils.struct (#540) @bdraco + +* Breakout network utils into zeroconf.utils.net (#537) @bdraco + +* Move time utility functions into zeroconf.utils.time (#536) @bdraco + +* Avoid making DNSOutgoing aware of the Zeroconf object (#535) @bdraco + +* Move logger into zeroconf.logger (#533) @bdraco + +* Move exceptions into zeroconf.exceptions (#532) @bdraco + +* Move constants into const.py (#531) @bdraco + +* Move asyncio utils into zeroconf.utils.aio (#530) @bdraco + +* Move ipversion auto detection code into its own function (#524) @bdraco + +* Breaking change: Update python compatibility as PyPy3 7.2 is required (#523) @bdraco + +* Remove broad exception catch from RecordManager.remove_listener (#517) @bdraco + +* Small cleanups to RecordManager.add_listener (#516) @bdraco + +* Move RecordUpdateListener management into RecordManager (#514) @bdraco + * Break out record updating into RecordManager (#512) @bdraco * Remove uneeded wait in the Engine thread (#511) @bdraco From 0e61b1502c7fd3412f979bc4d651ee016e712de9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 01:30:23 -1000 Subject: [PATCH 259/608] Mark zeroconf.dns as protected by renaming to zeroconf._dns (#574) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- tests/__init__.py | 2 +- zeroconf/__init__.py | 2 +- zeroconf/{dns.py => _dns.py} | 0 zeroconf/aio.py | 2 +- zeroconf/cache.py | 2 +- zeroconf/core.py | 2 +- zeroconf/handlers.py | 2 +- zeroconf/services/__init__.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename zeroconf/{dns.py => _dns.py} (100%) diff --git a/tests/__init__.py b/tests/__init__.py index 6399dbef..237bea3a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,7 +28,7 @@ from zeroconf.core import Zeroconf -from zeroconf.dns import DNSIncoming +from zeroconf._dns import DNSIncoming def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index aae68e4e..cee8d28a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -24,7 +24,7 @@ from .cache import DNSCache # noqa # import needed for backwards compat from .core import NotifyListener, Zeroconf # noqa # import needed for backwards compat -from .dns import ( # noqa # import needed for backwards compat +from ._dns import ( # noqa # import needed for backwards compat DNSAddress, DNSEntry, DNSHinfo, diff --git a/zeroconf/dns.py b/zeroconf/_dns.py similarity index 100% rename from zeroconf/dns.py rename to zeroconf/_dns.py diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 82e86199..3bfdce17 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,9 +26,9 @@ from types import TracebackType # noqa # used in type hints from typing import Awaitable, Callable, Dict, List, Optional, Type, Union +from ._dns import DNSOutgoing from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME from .core import NotifyListener, Zeroconf -from .dns import DNSOutgoing from .exceptions import NonUniqueNameException from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from .utils.aio import wait_condition_or_timeout diff --git a/zeroconf/cache.py b/zeroconf/cache.py index 48750f5a..cb54341e 100644 --- a/zeroconf/cache.py +++ b/zeroconf/cache.py @@ -22,8 +22,8 @@ from typing import Dict, Iterable, List, Optional, cast +from ._dns import DNSEntry, DNSPointer, DNSRecord, DNSService from .const import _TYPE_PTR -from .dns import DNSEntry, DNSPointer, DNSRecord, DNSService from .utils.time import current_time_millis diff --git a/zeroconf/core.py b/zeroconf/core.py index 23c3583e..4ab97178 100644 --- a/zeroconf/core.py +++ b/zeroconf/core.py @@ -28,6 +28,7 @@ from types import TracebackType # noqa # used in type hints from typing import Dict, List, Optional, Type, Union, cast +from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion from .cache import DNSCache from .const import ( _CACHE_CLEANUP_INTERVAL, @@ -45,7 +46,6 @@ _TYPE_PTR, _UNREGISTER_TIME, ) -from .dns import DNSIncoming, DNSOutgoing, DNSQuestion from .exceptions import NonUniqueNameException from .handlers import QueryHandler, RecordManager from .logger import QuietLogger, log diff --git a/zeroconf/handlers.py b/zeroconf/handlers.py index 000bc908..2e2e3736 100644 --- a/zeroconf/handlers.py +++ b/zeroconf/handlers.py @@ -23,6 +23,7 @@ import itertools from typing import List, Optional, TYPE_CHECKING, Union +from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from .const import ( _CLASS_IN, _DNS_OTHER_TTL, @@ -35,7 +36,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from .dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from .logger import log from .services import ( RecordUpdateListener, diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index cd84971d..f9ad4d25 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -27,6 +27,7 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from .._dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -46,7 +47,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from ..dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ..exceptions import BadTypeInNameException from ..utils.name import service_type_name from ..utils.net import ( From 601e8f70499638a6f24291bc0a28054fd78243c0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 01:40:06 -1000 Subject: [PATCH 260/608] Mark zeroconf.core as protected by renaming to zeroconf._core (#575) --- tests/__init__.py | 3 +-- tests/test_aio.py | 2 +- tests/test_core.py | 6 +++--- tests/test_services.py | 2 +- zeroconf/__init__.py | 2 +- zeroconf/{core.py => _core.py} | 0 zeroconf/aio.py | 2 +- zeroconf/handlers.py | 2 +- zeroconf/services/__init__.py | 2 +- zeroconf/services/types.py | 2 +- 10 files changed, 11 insertions(+), 12 deletions(-) rename zeroconf/{core.py => _core.py} (100%) diff --git a/tests/__init__.py b/tests/__init__.py index 237bea3a..420541d7 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -27,8 +27,7 @@ import ifaddr -from zeroconf.core import Zeroconf -from zeroconf._dns import DNSIncoming +from zeroconf import DNSIncoming, Zeroconf def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: diff --git a/tests/test_aio.py b/tests/test_aio.py index b1be151d..d197570c 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -12,7 +12,7 @@ import pytest from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf -from zeroconf.core import Zeroconf +from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf.exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf.services import ServiceInfo, ServiceListener diff --git a/tests/test_core.py b/tests/test_core.py index abdac3b9..a99519cb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -14,7 +14,7 @@ from typing import cast import zeroconf as r -from zeroconf import core +from zeroconf import _core from zeroconf import const from . import has_working_ipv6, _inject_response @@ -35,9 +35,9 @@ def teardown_module(): class TestReaper(unittest.TestCase): - @unittest.mock.patch.object(core, "_CACHE_CLEANUP_INTERVAL", 10) + @unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10) def test_reaper(self): - zeroconf = core.Zeroconf(interfaces=['127.0.0.1']) + zeroconf = _core.Zeroconf(interfaces=['127.0.0.1']) cache = zeroconf.cache original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') diff --git a/tests/test_services.py b/tests/test_services.py index 243662fe..86f15ae7 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -17,7 +17,7 @@ import zeroconf as r from zeroconf import const import zeroconf.services as s -from zeroconf.core import Zeroconf +from zeroconf import Zeroconf from zeroconf.services import ( ServiceBrowser, ServiceInfo, diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index cee8d28a..cdac526b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -23,7 +23,7 @@ import sys from .cache import DNSCache # noqa # import needed for backwards compat -from .core import NotifyListener, Zeroconf # noqa # import needed for backwards compat +from ._core import NotifyListener, Zeroconf # noqa # import needed for backwards compat from ._dns import ( # noqa # import needed for backwards compat DNSAddress, DNSEntry, diff --git a/zeroconf/core.py b/zeroconf/_core.py similarity index 100% rename from zeroconf/core.py rename to zeroconf/_core.py diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 3bfdce17..3c503a46 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,9 +26,9 @@ from types import TracebackType # noqa # used in type hints from typing import Awaitable, Callable, Dict, List, Optional, Type, Union +from ._core import NotifyListener, Zeroconf from ._dns import DNSOutgoing from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME -from .core import NotifyListener, Zeroconf from .exceptions import NonUniqueNameException from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from .utils.aio import wait_condition_or_timeout diff --git a/zeroconf/handlers.py b/zeroconf/handlers.py index 2e2e3736..a1de5eb0 100644 --- a/zeroconf/handlers.py +++ b/zeroconf/handlers.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from .core import Zeroconf # pylint: disable=cyclic-import + from ._core import Zeroconf # pylint: disable=cyclic-import class QueryHandler: diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index f9ad4d25..cf8b0686 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -59,7 +59,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from ..core import Zeroconf # pylint: disable=cyclic-import + from .._core import Zeroconf # pylint: disable=cyclic-import @enum.unique diff --git a/zeroconf/services/types.py b/zeroconf/services/types.py index e27defff..d6cc1e97 100644 --- a/zeroconf/services/types.py +++ b/zeroconf/services/types.py @@ -23,8 +23,8 @@ import time from typing import Optional, Set, Tuple, Union +from .._core import Zeroconf from ..const import _SERVICE_TYPE_ENUMERATION_NAME -from ..core import Zeroconf from ..services import ServiceBrowser, ServiceListener from ..utils.net import IPVersion, InterfaceChoice, InterfacesType From c29a235eb59ed3b4883305cf11f8bf9fa06284d3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 02:01:46 -1000 Subject: [PATCH 261/608] Log zeroconf.asyncio deprecation warning with the logger module (#576) --- zeroconf/asyncio.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index bdca1c0d..0a0457e5 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -20,11 +20,8 @@ USA """ -import logging - from .aio import AsyncZeroconf # pylint: disable=unused-import # noqa - -log = logging.getLogger(__name__) +from .logger import log # The asyncio module would shadow system asyncio in some import cases # to resolve this, the module has been renamed zeroconf.aio From 1a2ee6892e996c1e84ba97082e5cda609d1d55d7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 08:55:31 -1000 Subject: [PATCH 262/608] Mark zeroconf.handlers as protected by renaming to zeroconf._handlers (#577) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- zeroconf/_core.py | 2 +- zeroconf/{handlers.py => _handlers.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename zeroconf/{handlers.py => _handlers.py} (100%) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 4ab97178..14679349 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -29,6 +29,7 @@ from typing import Dict, List, Optional, Type, Union, cast from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion +from ._handlers import QueryHandler, RecordManager from .cache import DNSCache from .const import ( _CACHE_CLEANUP_INTERVAL, @@ -47,7 +48,6 @@ _UNREGISTER_TIME, ) from .exceptions import NonUniqueNameException -from .handlers import QueryHandler, RecordManager from .logger import QuietLogger, log from .services import ( RecordUpdateListener, diff --git a/zeroconf/handlers.py b/zeroconf/_handlers.py similarity index 100% rename from zeroconf/handlers.py rename to zeroconf/_handlers.py From 500066f940aa89737f343976ee0387eae97eac37 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:15:02 -1000 Subject: [PATCH 263/608] Mark zeroconf.logger as protected by renaming to zeroconf._logger (#578) --- tests/test_init.py | 4 ++-- tests/test_logger.py | 18 +++++++++--------- zeroconf/__init__.py | 2 +- zeroconf/_core.py | 2 +- zeroconf/_dns.py | 2 +- zeroconf/_handlers.py | 2 +- zeroconf/{logger.py => _logger.py} | 0 zeroconf/asyncio.py | 2 +- zeroconf/utils/net.py | 2 +- 9 files changed, 17 insertions(+), 17 deletions(-) rename zeroconf/{logger.py => _logger.py} (100%) diff --git a/tests/test_init.py b/tests/test_init.py index eedb7fbb..ce424857 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -136,8 +136,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # mock zeroconf's logger warning() and debug() from unittest.mock import patch - patch_warn = patch('zeroconf.log.warning') - patch_debug = patch('zeroconf.log.debug') + patch_warn = patch('zeroconf._logger.log.warning') + patch_debug = patch('zeroconf._logger.log.debug') mocked_log_warn = patch_warn.start() mocked_log_debug = patch_debug.start() diff --git a/tests/test_logger.py b/tests/test_logger.py index 52bf830f..2c661cf9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,22 +5,22 @@ """Unit tests for logger.py.""" from unittest.mock import patch -from zeroconf.logger import QuietLogger +from zeroconf._logger import QuietLogger def test_log_warning_once(): """Test we only log with warning level once.""" quiet_logger = QuietLogger() - with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( - "zeroconf.logger.log.debug" + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" ) as mock_log_debug: quiet_logger.log_warning_once("the warning") assert mock_log_warning.mock_calls assert not mock_log_debug.mock_calls - with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( - "zeroconf.logger.log.debug" + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" ) as mock_log_debug: quiet_logger.log_warning_once("the warning") @@ -31,16 +31,16 @@ def test_log_warning_once(): def test_log_exception_warning(): """Test we only log with warning level once.""" quiet_logger = QuietLogger() - with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( - "zeroconf.logger.log.debug" + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" ) as mock_log_debug: quiet_logger.log_exception_warning("the exception warning") assert mock_log_warning.mock_calls assert not mock_log_debug.mock_calls - with patch("zeroconf.logger.log.warning") as mock_log_warning, patch( - "zeroconf.logger.log.debug" + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" ) as mock_log_debug: quiet_logger.log_exception_warning("the exception warning") diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index cdac526b..794f95d2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -36,6 +36,7 @@ DNSService, DNSText, ) +from ._logger import QuietLogger, log # noqa # import needed for backwards compat from .exceptions import ( # noqa # import needed for backwards compat AbstractMethodException, BadTypeInNameException, @@ -45,7 +46,6 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) -from .logger import QuietLogger, log # noqa # import needed for backwards compat from .services import ( # noqa # import needed for backwards compat instance_name_from_service_info, Signal, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 14679349..64b365ef 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -30,6 +30,7 @@ from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion from ._handlers import QueryHandler, RecordManager +from ._logger import QuietLogger, log from .cache import DNSCache from .const import ( _CACHE_CLEANUP_INTERVAL, @@ -48,7 +49,6 @@ _UNREGISTER_TIME, ) from .exceptions import NonUniqueNameException -from .logger import QuietLogger, log from .services import ( RecordUpdateListener, ServiceBrowser, diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 430369f0..d9f91a33 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -25,6 +25,7 @@ import struct from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast +from ._logger import QuietLogger, log from .const import ( _CLASSES, _CLASS_MASK, @@ -48,7 +49,6 @@ _TYPE_TXT, ) from .exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException -from .logger import QuietLogger, log from .utils.net import _is_v6_address from .utils.struct import int2byte from .utils.time import current_time_millis, millis_to_seconds diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index a1de5eb0..f3397fa9 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -24,6 +24,7 @@ from typing import List, Optional, TYPE_CHECKING, Union from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord +from ._logger import log from .const import ( _CLASS_IN, _DNS_OTHER_TTL, @@ -36,7 +37,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from .logger import log from .services import ( RecordUpdateListener, ) diff --git a/zeroconf/logger.py b/zeroconf/_logger.py similarity index 100% rename from zeroconf/logger.py rename to zeroconf/_logger.py diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 0a0457e5..3de171f7 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -20,8 +20,8 @@ USA """ +from ._logger import log from .aio import AsyncZeroconf # pylint: disable=unused-import # noqa -from .logger import log # The asyncio module would shadow system asyncio in some import cases # to resolve this, the module has been renamed zeroconf.aio diff --git a/zeroconf/utils/net.py b/zeroconf/utils/net.py index 5ea49924..963faf55 100644 --- a/zeroconf/utils/net.py +++ b/zeroconf/utils/net.py @@ -30,8 +30,8 @@ import ifaddr +from .._logger import log from ..const import _IPPROTO_IPV6, _MDNS_ADDR6_BYTES, _MDNS_ADDR_BYTES, _MDNS_PORT -from ..logger import log @enum.unique From dd9ada781fdb1d5efc7c6ad194426e92550245b1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:22:45 -1000 Subject: [PATCH 264/608] Fix flakey backoff test race on startup (#579) --- tests/test_services.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 86f15ae7..090fc736 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -989,9 +989,14 @@ def on_service_state_change(zeroconf, service_type, state_change, name): next_query_interval = 0.0 expected_query_time = 0.0 while True: - zeroconf_browser.notify_all() sleep_count += 1 - got_query.wait(0.1) + for _ in range(2): + # If the browser thread is starting up + # its possible we notify before the initial sleep + # which means the test will fail so we need to d + # this twice to eliminate the race condition + zeroconf_browser.notify_all() + got_query.wait(0.05) if time_offset == expected_query_time: assert got_query.is_set() got_query.clear() From 241700a07a76a8c45afbe1bdd8325cd9f0eb0168 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:25:11 -1000 Subject: [PATCH 265/608] Mark zeroconf.exceptions as protected by renaming to zeroconf._exceptions (#580) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- tests/test_aio.py | 2 +- tests/test_exceptions.py | 2 +- zeroconf/__init__.py | 2 +- zeroconf/_core.py | 2 +- zeroconf/_dns.py | 2 +- zeroconf/{exceptions.py => _exceptions.py} | 2 -- zeroconf/aio.py | 2 +- zeroconf/services/__init__.py | 2 +- zeroconf/services/registry.py | 2 +- zeroconf/utils/name.py | 2 +- 10 files changed, 9 insertions(+), 11 deletions(-) rename zeroconf/{exceptions.py => _exceptions.py} (98%) diff --git a/tests/test_aio.py b/tests/test_aio.py index d197570c..962e0df8 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -14,7 +14,7 @@ from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME -from zeroconf.exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered +from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf.services import ServiceInfo, ServiceListener from zeroconf.utils.time import current_time_millis diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index c85da045..cfc4c19d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- -""" Unit tests for zeroconf.exceptions """ +""" Unit tests for zeroconf._exceptions """ import logging import unittest diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 794f95d2..39f0c3d7 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -37,7 +37,7 @@ DNSText, ) from ._logger import QuietLogger, log # noqa # import needed for backwards compat -from .exceptions import ( # noqa # import needed for backwards compat +from ._exceptions import ( # noqa # import needed for backwards compat AbstractMethodException, BadTypeInNameException, Error, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 64b365ef..5df9291f 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -29,6 +29,7 @@ from typing import Dict, List, Optional, Type, Union, cast from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion +from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log from .cache import DNSCache @@ -48,7 +49,6 @@ _TYPE_PTR, _UNREGISTER_TIME, ) -from .exceptions import NonUniqueNameException from .services import ( RecordUpdateListener, ServiceBrowser, diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index d9f91a33..8deed48f 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -25,6 +25,7 @@ import struct from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast +from ._exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log from .const import ( _CLASSES, @@ -48,7 +49,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from .exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException from .utils.net import _is_v6_address from .utils.struct import int2byte from .utils.time import current_time_millis, millis_to_seconds diff --git a/zeroconf/exceptions.py b/zeroconf/_exceptions.py similarity index 98% rename from zeroconf/exceptions.py rename to zeroconf/_exceptions.py index ea468659..02771140 100644 --- a/zeroconf/exceptions.py +++ b/zeroconf/_exceptions.py @@ -20,8 +20,6 @@ USA """ -# Exceptions - class Error(Exception): pass diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 3c503a46..d8ed256e 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -28,8 +28,8 @@ from ._core import NotifyListener, Zeroconf from ._dns import DNSOutgoing +from ._exceptions import NonUniqueNameException from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME -from .exceptions import NonUniqueNameException from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from .utils.aio import wait_condition_or_timeout from .utils.net import IPVersion, InterfaceChoice, InterfacesType diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index cf8b0686..4fc62c87 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -28,6 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from .._dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from .._exceptions import BadTypeInNameException from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -47,7 +48,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from ..exceptions import BadTypeInNameException from ..utils.name import service_type_name from ..utils.net import ( IPVersion, diff --git a/zeroconf/services/registry.py b/zeroconf/services/registry.py index 19d4ba46..b17c4284 100644 --- a/zeroconf/services/registry.py +++ b/zeroconf/services/registry.py @@ -24,7 +24,7 @@ from typing import Dict, List, Optional -from ..exceptions import ServiceNameAlreadyRegistered +from .._exceptions import ServiceNameAlreadyRegistered from ..services import ServiceInfo diff --git a/zeroconf/utils/name.py b/zeroconf/utils/name.py index 65713eb0..10a0ccf8 100644 --- a/zeroconf/utils/name.py +++ b/zeroconf/utils/name.py @@ -20,6 +20,7 @@ USA """ +from .._exceptions import BadTypeInNameException from ..const import ( _HAS_ASCII_CONTROL_CHARS, _HAS_A_TO_Z, @@ -29,7 +30,6 @@ _NONTCP_PROTOCOL_LOCAL_TRAILER, _TCP_PROTOCOL_LOCAL_TRAILER, ) -from ..exceptions import BadTypeInNameException def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: disable=too-many-branches From a16e85b20c2069aa9cee0510c618cb61d46dc19c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:32:26 -1000 Subject: [PATCH 266/608] Mark zeroconf.cache as protected by renaming to zeroconf._cache (#581) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- zeroconf/__init__.py | 2 +- zeroconf/{cache.py => _cache.py} | 0 zeroconf/_core.py | 2 +- zeroconf/_dns.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename zeroconf/{cache.py => _cache.py} (100%) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 39f0c3d7..dec26353 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -22,7 +22,7 @@ import sys -from .cache import DNSCache # noqa # import needed for backwards compat +from ._cache import DNSCache # noqa # import needed for backwards compat from ._core import NotifyListener, Zeroconf # noqa # import needed for backwards compat from ._dns import ( # noqa # import needed for backwards compat DNSAddress, diff --git a/zeroconf/cache.py b/zeroconf/_cache.py similarity index 100% rename from zeroconf/cache.py rename to zeroconf/_cache.py diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 5df9291f..594754ee 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -28,11 +28,11 @@ from types import TracebackType # noqa # used in type hints from typing import Dict, List, Optional, Type, Union, cast +from ._cache import DNSCache from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log -from .cache import DNSCache from .const import ( _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 8deed48f..5e1c6067 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -56,7 +56,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from .cache import DNSCache # pylint: disable=cyclic-import + from ._cache import DNSCache # pylint: disable=cyclic-import class DNSEntry: From cc5bc36f6f7597a0adb0d637147c2f93ca243ff4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:39:38 -1000 Subject: [PATCH 267/608] Mark zeroconf.utils as protected by renaming to zeroconf._utils (#582) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- setup.py | 2 +- tests/test_aio.py | 2 +- tests/utils/test_aio.py | 4 ++-- tests/utils/test_net.py | 8 ++++---- zeroconf/__init__.py | 8 ++++---- zeroconf/_cache.py | 2 +- zeroconf/_core.py | 20 ++++++++++---------- zeroconf/_dns.py | 6 +++--- zeroconf/_handlers.py | 2 +- zeroconf/{utils => _utils}/__init__.py | 0 zeroconf/{utils => _utils}/aio.py | 0 zeroconf/{utils => _utils}/name.py | 0 zeroconf/{utils => _utils}/net.py | 0 zeroconf/{utils => _utils}/struct.py | 0 zeroconf/{utils => _utils}/time.py | 0 zeroconf/aio.py | 6 +++--- zeroconf/services/__init__.py | 17 +++++++++-------- zeroconf/services/types.py | 2 +- 18 files changed, 40 insertions(+), 39 deletions(-) rename zeroconf/{utils => _utils}/__init__.py (100%) rename zeroconf/{utils => _utils}/aio.py (100%) rename zeroconf/{utils => _utils}/name.py (100%) rename zeroconf/{utils => _utils}/net.py (100%) rename zeroconf/{utils => _utils}/struct.py (100%) rename zeroconf/{utils => _utils}/time.py (100%) diff --git a/setup.py b/setup.py index c1c0da34..f6f8582b 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ author='Paul Scott-Murphy, William McBrine, Jakub Stasiak', url='https://github.com/jstasiak/python-zeroconf', package_data={"zeroconf": ["py.typed"]}, - packages=["zeroconf", "zeroconf.services", "zeroconf.utils"], + packages=["zeroconf", "zeroconf.services", "zeroconf._utils"], platforms=['unix', 'linux', 'osx'], license='LGPL', zip_safe=False, diff --git a/tests/test_aio.py b/tests/test_aio.py index 962e0df8..5d04e019 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -16,7 +16,7 @@ from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf.services import ServiceInfo, ServiceListener -from zeroconf.utils.time import current_time_millis +from zeroconf._utils.time import current_time_millis @pytest.fixture(autouse=True) diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index e38eb583..a74d991d 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -2,13 +2,13 @@ # -*- coding: utf-8 -*- -"""Unit tests for zeroconf.utils.aio.""" +"""Unit tests for zeroconf._utils.aio.""" import asyncio import pytest -from zeroconf.utils import aio as aioutils +from zeroconf._utils import aio as aioutils @pytest.mark.asyncio diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index d4b829c2..1360f936 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -2,13 +2,13 @@ # -*- coding: utf-8 -*- -"""Unit tests for zeroconf.utils.net.""" +"""Unit tests for zeroconf._utils.net.""" from unittest.mock import Mock, patch import ifaddr import pytest -from zeroconf.utils import net as netutils +from zeroconf._utils import net as netutils def _generate_mock_adapters(): @@ -50,9 +50,9 @@ def test_interface_index_to_ip6_address(): def test_ip6_addresses_to_indexes(): """Test we can extract from mocked adapters.""" interfaces = [1] - with patch("zeroconf.utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] interfaces = ['2001:db8::'] - with patch("zeroconf.utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): + with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index dec26353..e424118d 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -58,8 +58,8 @@ ) from .services.registry import ServiceRegistry # noqa # import needed for backwards compat from .services.types import ZeroconfServiceTypes # noqa # import needed for backwards compat -from .utils.name import service_type_name # noqa # import needed for backwards compat -from .utils.net import ( # noqa # import needed for backwards compat +from ._utils.name import service_type_name # noqa # import needed for backwards compat +from ._utils.net import ( # noqa # import needed for backwards compat add_multicast_member, can_send_to, autodetect_ip_version, @@ -70,8 +70,8 @@ IPVersion, get_all_addresses, ) -from .utils.struct import int2byte # noqa # import needed for backwards compat -from .utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat +from ._utils.struct import int2byte # noqa # import needed for backwards compat +from ._utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index cb54341e..135b1884 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -23,8 +23,8 @@ from typing import Dict, Iterable, List, Optional, cast from ._dns import DNSEntry, DNSPointer, DNSRecord, DNSService +from ._utils.time import current_time_millis from .const import _TYPE_PTR -from .utils.time import current_time_millis class DNSCache: diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 594754ee..a2547305 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -33,6 +33,16 @@ from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log +from ._utils.name import service_type_name +from ._utils.net import ( + IPVersion, + InterfaceChoice, + InterfacesType, + autodetect_ip_version, + can_send_to, + create_sockets, +) +from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, @@ -57,16 +67,6 @@ instance_name_from_service_info, ) from .services.registry import ServiceRegistry -from .utils.name import service_type_name -from .utils.net import ( - IPVersion, - InterfaceChoice, - InterfacesType, - autodetect_ip_version, - can_send_to, - create_sockets, -) -from .utils.time import current_time_millis, millis_to_seconds class NotifyListener: diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 5e1c6067..aa2b4983 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -27,6 +27,9 @@ from ._exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log +from ._utils.net import _is_v6_address +from ._utils.struct import int2byte +from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CLASSES, _CLASS_MASK, @@ -49,9 +52,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from .utils.net import _is_v6_address -from .utils.struct import int2byte -from .utils.time import current_time_millis, millis_to_seconds if TYPE_CHECKING: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f3397fa9..83f2524a 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -25,6 +25,7 @@ from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from ._logger import log +from ._utils.time import current_time_millis from .const import ( _CLASS_IN, _DNS_OTHER_TTL, @@ -41,7 +42,6 @@ RecordUpdateListener, ) from .services.registry import ServiceRegistry -from .utils.time import current_time_millis if TYPE_CHECKING: diff --git a/zeroconf/utils/__init__.py b/zeroconf/_utils/__init__.py similarity index 100% rename from zeroconf/utils/__init__.py rename to zeroconf/_utils/__init__.py diff --git a/zeroconf/utils/aio.py b/zeroconf/_utils/aio.py similarity index 100% rename from zeroconf/utils/aio.py rename to zeroconf/_utils/aio.py diff --git a/zeroconf/utils/name.py b/zeroconf/_utils/name.py similarity index 100% rename from zeroconf/utils/name.py rename to zeroconf/_utils/name.py diff --git a/zeroconf/utils/net.py b/zeroconf/_utils/net.py similarity index 100% rename from zeroconf/utils/net.py rename to zeroconf/_utils/net.py diff --git a/zeroconf/utils/struct.py b/zeroconf/_utils/struct.py similarity index 100% rename from zeroconf/utils/struct.py rename to zeroconf/_utils/struct.py diff --git a/zeroconf/utils/time.py b/zeroconf/_utils/time.py similarity index 100% rename from zeroconf/utils/time.py rename to zeroconf/_utils/time.py diff --git a/zeroconf/aio.py b/zeroconf/aio.py index d8ed256e..d745957a 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -29,11 +29,11 @@ from ._core import NotifyListener, Zeroconf from ._dns import DNSOutgoing from ._exceptions import NonUniqueNameException +from ._utils.aio import wait_condition_or_timeout +from ._utils.net import IPVersion, InterfaceChoice, InterfacesType +from ._utils.time import current_time_millis, millis_to_seconds from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info -from .utils.aio import wait_condition_or_timeout -from .utils.net import IPVersion, InterfaceChoice, InterfacesType -from .utils.time import current_time_millis, millis_to_seconds def _get_best_available_queue() -> queue.Queue: diff --git a/zeroconf/services/__init__.py b/zeroconf/services/__init__.py index 4fc62c87..09aa4c73 100644 --- a/zeroconf/services/__init__.py +++ b/zeroconf/services/__init__.py @@ -29,6 +29,14 @@ from .._dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException +from .._utils.name import service_type_name +from .._utils.net import ( + IPVersion, + _encode_address, + _is_v6_address, +) +from .._utils.struct import int2byte +from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -48,14 +56,7 @@ _TYPE_SRV, _TYPE_TXT, ) -from ..utils.name import service_type_name -from ..utils.net import ( - IPVersion, - _encode_address, - _is_v6_address, -) -from ..utils.struct import int2byte -from ..utils.time import current_time_millis, millis_to_seconds + if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 diff --git a/zeroconf/services/types.py b/zeroconf/services/types.py index d6cc1e97..6b454e65 100644 --- a/zeroconf/services/types.py +++ b/zeroconf/services/types.py @@ -24,9 +24,9 @@ from typing import Optional, Set, Tuple, Union from .._core import Zeroconf +from .._utils.net import IPVersion, InterfaceChoice, InterfacesType from ..const import _SERVICE_TYPE_ENUMERATION_NAME from ..services import ServiceBrowser, ServiceListener -from ..utils.net import IPVersion, InterfaceChoice, InterfacesType class ZeroconfServiceTypes(ServiceListener): From 4a88066d66b2f2a00ebc388c5cda478c52cb9e6c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:49:00 -1000 Subject: [PATCH 268/608] Mark zeroconf.services as protected by renaming to zeroconf._services (#583) - The public API should only access zeroconf and zeroconf.aio as internals may be relocated between releases --- setup.py | 2 +- tests/test_aio.py | 2 +- tests/test_services.py | 6 +++--- zeroconf/__init__.py | 6 +++--- zeroconf/_core.py | 16 ++++++++-------- zeroconf/_handlers.py | 6 ++---- zeroconf/{services => _services}/__init__.py | 0 zeroconf/{services => _services}/registry.py | 18 +++++++++--------- zeroconf/{services => _services}/types.py | 2 +- zeroconf/aio.py | 2 +- 10 files changed, 29 insertions(+), 31 deletions(-) rename zeroconf/{services => _services}/__init__.py (100%) rename zeroconf/{services => _services}/registry.py (90%) rename zeroconf/{services => _services}/types.py (98%) diff --git a/setup.py b/setup.py index f6f8582b..0ad299fb 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ author='Paul Scott-Murphy, William McBrine, Jakub Stasiak', url='https://github.com/jstasiak/python-zeroconf', package_data={"zeroconf": ["py.typed"]}, - packages=["zeroconf", "zeroconf.services", "zeroconf._utils"], + packages=["zeroconf", "zeroconf._services", "zeroconf._utils"], platforms=['unix', 'linux', 'osx'], license='LGPL', zip_safe=False, diff --git a/tests/test_aio.py b/tests/test_aio.py index 5d04e019..2b222242 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -15,7 +15,7 @@ from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered -from zeroconf.services import ServiceInfo, ServiceListener +from zeroconf._services import ServiceInfo, ServiceListener from zeroconf._utils.time import current_time_millis diff --git a/tests/test_services.py b/tests/test_services.py index 090fc736..c78b9f9c 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- -""" Unit tests for zeroconf.services. """ +""" Unit tests for zeroconf._services. """ import logging import socket @@ -16,9 +16,9 @@ import zeroconf as r from zeroconf import const -import zeroconf.services as s +import zeroconf._services as s from zeroconf import Zeroconf -from zeroconf.services import ( +from zeroconf._services import ( ServiceBrowser, ServiceInfo, ServiceStateChange, diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e424118d..8277c32b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -46,7 +46,7 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) -from .services import ( # noqa # import needed for backwards compat +from ._services import ( # noqa # import needed for backwards compat instance_name_from_service_info, Signal, SignalRegistrationInterface, @@ -56,8 +56,8 @@ ServiceListener, ServiceStateChange, ) -from .services.registry import ServiceRegistry # noqa # import needed for backwards compat -from .services.types import ZeroconfServiceTypes # noqa # import needed for backwards compat +from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat +from ._services.types import ZeroconfServiceTypes # noqa # import needed for backwards compat from ._utils.name import service_type_name # noqa # import needed for backwards compat from ._utils.net import ( # noqa # import needed for backwards compat add_multicast_member, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a2547305..1b3a4c1f 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -33,6 +33,14 @@ from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log +from ._services import ( + RecordUpdateListener, + ServiceBrowser, + ServiceInfo, + ServiceListener, + instance_name_from_service_info, +) +from ._services.registry import ServiceRegistry from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -59,14 +67,6 @@ _TYPE_PTR, _UNREGISTER_TIME, ) -from .services import ( - RecordUpdateListener, - ServiceBrowser, - ServiceInfo, - ServiceListener, - instance_name_from_service_info, -) -from .services.registry import ServiceRegistry class NotifyListener: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 83f2524a..7e8e734b 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -25,6 +25,8 @@ from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from ._logger import log +from ._services import RecordUpdateListener +from ._services.registry import ServiceRegistry from ._utils.time import current_time_millis from .const import ( _CLASS_IN, @@ -38,10 +40,6 @@ _TYPE_SRV, _TYPE_TXT, ) -from .services import ( - RecordUpdateListener, -) -from .services.registry import ServiceRegistry if TYPE_CHECKING: diff --git a/zeroconf/services/__init__.py b/zeroconf/_services/__init__.py similarity index 100% rename from zeroconf/services/__init__.py rename to zeroconf/_services/__init__.py diff --git a/zeroconf/services/registry.py b/zeroconf/_services/registry.py similarity index 90% rename from zeroconf/services/registry.py rename to zeroconf/_services/registry.py index b17c4284..6d1baa8e 100644 --- a/zeroconf/services/registry.py +++ b/zeroconf/_services/registry.py @@ -25,7 +25,7 @@ from .._exceptions import ServiceNameAlreadyRegistered -from ..services import ServiceInfo +from .._services import ServiceInfo class ServiceRegistry: @@ -40,7 +40,7 @@ def __init__( self, ) -> None: """Create the ServiceRegistry class.""" - self.services = {} # type: Dict[str, ServiceInfo] + self._services = {} # type: Dict[str, ServiceInfo] self.types = {} # type: Dict[str, List] self.servers = {} # type: Dict[str, List] self._lock = threading.Lock() # add and remove services thread safe @@ -66,11 +66,11 @@ def update(self, info: ServiceInfo) -> None: def get_service_infos(self) -> List[ServiceInfo]: """Return all ServiceInfo.""" - return list(self.services.values()) + return list(self._services.values()) def get_info_name(self, name: str) -> Optional[ServiceInfo]: """Return all ServiceInfo for the name.""" - return self.services.get(name) + return self._services.get(name) def get_types(self) -> List[str]: """Return all types.""" @@ -89,7 +89,7 @@ def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: service_infos = [] for name in getattr(self, attr).get(key, [])[:]: - info = self.services.get(name) + info = self._services.get(name) # Since we do not get under a lock since it would be # a performance issue, its possible # the service can be unregistered during the get @@ -102,17 +102,17 @@ def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" lower_name = info.name.lower() - if lower_name in self.services: + if lower_name in self._services: raise ServiceNameAlreadyRegistered - self.services[lower_name] = info + self._services[lower_name] = info self.types.setdefault(info.type, []).append(lower_name) self.servers.setdefault(info.server, []).append(lower_name) def _remove(self, info: ServiceInfo) -> None: """Remove a service under the lock.""" lower_name = info.name.lower() - old_service_info = self.services[lower_name] + old_service_info = self._services[lower_name] self.types[old_service_info.type].remove(lower_name) self.servers[old_service_info.server].remove(lower_name) - del self.services[lower_name] + del self._services[lower_name] diff --git a/zeroconf/services/types.py b/zeroconf/_services/types.py similarity index 98% rename from zeroconf/services/types.py rename to zeroconf/_services/types.py index 6b454e65..f611fc4c 100644 --- a/zeroconf/services/types.py +++ b/zeroconf/_services/types.py @@ -24,9 +24,9 @@ from typing import Optional, Set, Tuple, Union from .._core import Zeroconf +from .._services import ServiceBrowser, ServiceListener from .._utils.net import IPVersion, InterfaceChoice, InterfacesType from ..const import _SERVICE_TYPE_ENUMERATION_NAME -from ..services import ServiceBrowser, ServiceListener class ZeroconfServiceTypes(ServiceListener): diff --git a/zeroconf/aio.py b/zeroconf/aio.py index d745957a..3df58eae 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -29,11 +29,11 @@ from ._core import NotifyListener, Zeroconf from ._dns import DNSOutgoing from ._exceptions import NonUniqueNameException +from ._services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType from ._utils.time import current_time_millis, millis_to_seconds from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME -from .services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info def _get_best_available_queue() -> queue.Queue: From 1fe282ba246505d172356cc8672307c7d125820d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 09:58:38 -1000 Subject: [PATCH 269/608] Relocate ServiceTypesQuery tests to tests/services/test_types (#584) --- tests/services/__init__.py | 21 +++++ tests/services/test_types.py | 143 +++++++++++++++++++++++++++++++++++ tests/test_init.py | 135 +-------------------------------- 3 files changed, 166 insertions(+), 133 deletions(-) create mode 100644 tests/services/__init__.py create mode 100644 tests/services/test_types.py diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 00000000..2ef4b15b --- /dev/null +++ b/tests/services/__init__.py @@ -0,0 +1,21 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" diff --git a/tests/services/test_types.py b/tests/services/test_types.py new file mode 100644 index 00000000..845e20f8 --- /dev/null +++ b/tests/services/test_types.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._services.types.""" + +import os +import unittest +import socket + +import zeroconf as r +from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes + +from .. import _clear_cache, has_working_ipv6 + + +class ServiceTypesQuery(unittest.TestCase): + def test_integration_with_listener(self): + + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.register_service(info) + + try: + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_integration_with_listener_v6_records(self): + + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_pton(socket.AF_INET6, addr)], + ) + zeroconf_registrar.register_service(info) + + try: + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') + def test_integration_with_listener_ipv6(self): + + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.register_service(info) + + try: + service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types + + finally: + zeroconf_registrar.close() + + def test_integration_with_subtype_and_listener(self): + subtype_ = "_subtype._sub" + type_ = "_type._tcp.local." + name = "xxxyyy" + # Note: discovery returns only DNS-SD type not subtype + discovery_type = "%s.%s" % (subtype_, type_) + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + discovery_type, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.register_service(info) + + try: + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert discovery_type in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert discovery_type in service_types + + finally: + zeroconf_registrar.close() diff --git a/tests/test_init.py b/tests/test_init.py index ce424857..ab45b75f 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -6,20 +6,18 @@ import errno import logging -import os import socket import time import unittest import unittest.mock -from threading import Event from typing import Dict, Optional # noqa # used in type hints import pytest import zeroconf as r -from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, ZeroconfServiceTypes, const +from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, const -from . import has_working_ipv6, _clear_cache, _inject_response +from . import _clear_cache log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -459,135 +457,6 @@ def test_lookups(self): assert registry.get_types() == [type_] -class ServiceTypesQuery(unittest.TestCase): - def test_integration_with_listener(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types - - finally: - zeroconf_registrar.close() - - @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') - @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') - def test_integration_with_listener_v6_records(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_pton(socket.AF_INET6, addr)], - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types - - finally: - zeroconf_registrar.close() - - @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') - @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') - def test_integration_with_listener_ipv6(self): - - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types - - finally: - zeroconf_registrar.close() - - def test_integration_with_subtype_and_listener(self): - subtype_ = "_subtype._sub" - type_ = "_type._tcp.local." - name = "xxxyyy" - # Note: discovery returns only DNS-SD type not subtype - discovery_type = "%s.%s" % (subtype_, type_) - registration_name = "%s.%s" % (name, type_) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - discovery_type, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - zeroconf_registrar.register_service(info) - - try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert discovery_type in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert discovery_type in service_types - - finally: - zeroconf_registrar.close() - - def test_ptr_optimization(): # instantiate a zeroconf instance From 12f567695b5364c9c5c5af0a7017d877de84274d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:04:46 -1000 Subject: [PATCH 270/608] Relocate network utils tests to tests/utils/test_net (#585) --- tests/test_init.py | 24 ------------------------ tests/utils/test_net.py | 26 ++++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index ab45b75f..9dffe12f 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -4,7 +4,6 @@ """ Unit tests for zeroconf.py """ -import errno import logging import socket import time @@ -502,21 +501,6 @@ def test_ptr_optimization(): zc.close() -@pytest.mark.parametrize( - "errno,expected_result", - [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], -) -def test_add_multicast_member_socket_errors(errno, expected_result): - """Test we handle socket errors when adding multicast members.""" - if errno: - setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno))) - else: - setsockopt_mock = unittest.mock.Mock() - fileno_mock = unittest.mock.PropertyMock(return_value=10) - socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) - assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result - - def test_notify_listeners(): """Test adding and removing notify listeners.""" # instantiate a zeroconf instance @@ -553,11 +537,3 @@ def on_service_state_change(zeroconf, service_type, state_change, name): assert not notify_called zc.close() - - -def test_autodetect_ip_version(): - """Tests for auto detecting IPVersion based on interface ips.""" - assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only - assert r.autodetect_ip_version([]) is r.IPVersion.V4Only - assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All - assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 1360f936..1a8beebe 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -5,10 +5,13 @@ """Unit tests for zeroconf._utils.net.""" from unittest.mock import Mock, patch +import errno import ifaddr import pytest +import unittest from zeroconf._utils import net as netutils +import zeroconf as r def _generate_mock_adapters(): @@ -56,3 +59,26 @@ def test_ip6_addresses_to_indexes(): interfaces = ['2001:db8::'] with patch("zeroconf._utils.net.ifaddr.get_adapters", return_value=_generate_mock_adapters()): assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] + + +@pytest.mark.parametrize( + "errno,expected_result", + [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], +) +def test_add_multicast_member_socket_errors(errno, expected_result): + """Test we handle socket errors when adding multicast members.""" + if errno: + setsockopt_mock = unittest.mock.Mock(side_effect=OSError(errno, "Error: {}".format(errno))) + else: + setsockopt_mock = unittest.mock.Mock() + fileno_mock = unittest.mock.PropertyMock(return_value=10) + socket_mock = unittest.mock.Mock(setsockopt=setsockopt_mock, fileno=fileno_mock) + assert r.add_multicast_member(socket_mock, "0.0.0.0") == expected_result + + +def test_autodetect_ip_version(): + """Tests for auto detecting IPVersion based on interface ips.""" + assert r.autodetect_ip_version(["1.3.4.5"]) is r.IPVersion.V4Only + assert r.autodetect_ip_version([]) is r.IPVersion.V4Only + assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All + assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only From 5cb5702fca2845e99b457e4427428497c3cd9b31 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:11:22 -1000 Subject: [PATCH 271/608] Disable flakey ServiceTypesQuery ipv6 win32 test (#586) --- tests/services/test_types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 845e20f8..e8e9911f 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -7,6 +7,7 @@ import os import unittest import socket +import sys import zeroconf as r from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes @@ -78,7 +79,7 @@ def test_integration_with_listener_v6_records(self): finally: zeroconf_registrar.close() - @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(not has_working_ipv6() or sys.platform == 'win32', 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_ipv6(self): From ae6530a59e2d8ddb9a7367243c29c5e00665a82f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:15:38 -1000 Subject: [PATCH 272/608] Relocate ServiceRegistry tests to tests/services/test_registry (#587) --- tests/services/test_registry.py | 48 +++++++++++++++++++++++++++++++++ tests/test_init.py | 37 ------------------------- 2 files changed, 48 insertions(+), 37 deletions(-) create mode 100644 tests/services/test_registry.py diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py new file mode 100644 index 00000000..74af1e65 --- /dev/null +++ b/tests/services/test_registry.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._services.registry.""" + +import unittest +import socket + +import zeroconf as r +from zeroconf import ServiceInfo + + +class TestServiceRegistry(unittest.TestCase): + def test_only_register_once(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) + registry.remove(info) + registry.add(info) + + def test_lookups(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + + assert registry.get_service_infos() == [info] + assert registry.get_info_name(registration_name) == info + assert registry.get_infos_type(type_) == [info] + assert registry.get_infos_server("ash-2.local.") == [info] + assert registry.get_types() == [type_] diff --git a/tests/test_init.py b/tests/test_init.py index 9dffe12f..33924ecb 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -419,43 +419,6 @@ def test_register_and_lookup_type_by_uppercase_name(self): zc.close() -class TestServiceRegistry(unittest.TestCase): - def test_only_register_once(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - - registry = r.ServiceRegistry() - registry.add(info) - self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) - registry.remove(info) - registry.add(info) - - def test_lookups(self): - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - - registry = r.ServiceRegistry() - registry.add(info) - - assert registry.get_service_infos() == [info] - assert registry.get_info_name(registration_name) == info - assert registry.get_infos_type(type_) == [info] - assert registry.get_infos_server("ash-2.local.") == [info] - assert registry.get_types() == [type_] - - def test_ptr_optimization(): # instantiate a zeroconf instance From 8aa14d33849c057c91a00e1093606081ade488e7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:29:10 -1000 Subject: [PATCH 273/608] Relocate handlers tests to tests/test_handlers (#588) --- tests/test_handlers.py | 246 +++++++++++++++++++++++++++++++++++++++++ tests/test_init.py | 218 +----------------------------------- 2 files changed, 247 insertions(+), 217 deletions(-) create mode 100644 tests/test_handlers.py diff --git a/tests/test_handlers.py b/tests/test_handlers.py new file mode 100644 index 00000000..53d9b9b7 --- /dev/null +++ b/tests/test_handlers.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._handlers """ + +import logging +import pytest +import socket +import time +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import ServiceInfo, Zeroconf +from zeroconf import const + +from . import _clear_cache + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestRegistrar(unittest.TestCase): + def test_ttl(self): + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + nbr_answers = nbr_additionals = nbr_authorities = 0 + + def get_ttl(record_type): + if expected_ttl is not None: + return expected_ttl + elif record_type in [const._TYPE_A, const._TYPE_SRV]: + return const._DNS_HOST_TTL + else: + return const._DNS_OTHER_TTL + + def _process_outgoing_packet(out): + """Sends an outgoing packet.""" + nonlocal nbr_answers, nbr_additionals, nbr_authorities + + for answer, time_ in out.answers: + nbr_answers += 1 + assert answer.ttl == get_ttl(answer.type) + for answer in out.additionals: + nbr_additionals += 1 + assert answer.ttl == get_ttl(answer.type) + for answer in out.authorities: + nbr_authorities += 1 + assert answer.ttl == get_ttl(answer.type) + + # register service with default TTL + expected_ttl = None + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.add(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, None)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # query + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + assert query.is_query() is True + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) + assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # unregister + expected_ttl = 0 + zc.registry.remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + expected_ttl = None + for _ in range(3): + _process_outgoing_packet(zc.generate_service_query(info)) + zc.registry.add(info) + # register service with custom TTL + expected_ttl = const._DNS_HOST_TTL * 2 + assert expected_ttl != const._DNS_HOST_TTL + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # query + expected_ttl = None + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) + query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) + _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) + assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # unregister + expected_ttl = 0 + zc.registry.remove(info) + for _ in range(3): + _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) + assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 + nbr_answers = nbr_additionals = nbr_authorities = 0 + zc.close() + + def test_name_conflicts(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_homeassistant._tcp.local." + name = "Home" + registration_name = "%s.%s" % (name, type_) + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + + conflicting_info = ServiceInfo( + type_, + name=registration_name, + server="random456.local.", + addresses=[socket.inet_pton(socket.AF_INET, "4.5.6.7")], + port=80, + properties={"version": "1.0"}, + ) + with pytest.raises(r.NonUniqueNameException): + zc.register_service(conflicting_info) + zc.close() + + def test_register_and_lookup_type_by_uppercase_name(self): + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_mylowertype._tcp.local." + name = "Home" + registration_name = "%s.%s" % (name, type_) + + info = ServiceInfo( + type_, + name=registration_name, + server="random123.local.", + addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], + port=80, + properties={"version": "1.0"}, + ) + zc.register_service(info) + _clear_cache(zc) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [] + + out = r.DNSOutgoing(const._FLAGS_QR_QUERY) + out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN)) + zc.send(out) + time.sleep(0.5) + info = ServiceInfo(type_, registration_name) + info.load_from_cache(zc) + assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")] + assert info.properties == {b"version": b"1.0"} + zc.close() + + +def test_ptr_optimization(): + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + nbr_answers = nbr_additionals = nbr_authorities = 0 + has_srv = has_txt = has_a = False + + # register + zc.register_service(info) + nbr_answers = nbr_additionals = nbr_authorities = 0 + + # query + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False) + assert out is not None + nbr_answers += len(out.answers) + nbr_authorities += len(out.authorities) + for answer in out.additionals: + nbr_additionals += 1 + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert has_srv and has_txt and has_a + + # unregister + zc.unregister_service(info) + zc.close() diff --git a/tests/test_init.py b/tests/test_init.py index 33924ecb..48fff674 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -9,15 +9,13 @@ import time import unittest import unittest.mock -from typing import Dict, Optional # noqa # used in type hints +from typing import Optional # noqa # used in type hints import pytest import zeroconf as r from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, const -from . import _clear_cache - log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -250,220 +248,6 @@ def generate_host(zc, host_name, type_): zc.send(out) -class TestRegistrar(unittest.TestCase): - def test_ttl(self): - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - # service definition - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - - nbr_answers = nbr_additionals = nbr_authorities = 0 - - def get_ttl(record_type): - if expected_ttl is not None: - return expected_ttl - elif record_type in [const._TYPE_A, const._TYPE_SRV]: - return const._DNS_HOST_TTL - else: - return const._DNS_OTHER_TTL - - def _process_outgoing_packet(out): - """Sends an outgoing packet.""" - nonlocal nbr_answers, nbr_additionals, nbr_authorities - - for answer, time_ in out.answers: - nbr_answers += 1 - assert answer.ttl == get_ttl(answer.type) - for answer in out.additionals: - nbr_additionals += 1 - assert answer.ttl == get_ttl(answer.type) - for answer in out.authorities: - nbr_authorities += 1 - assert answer.ttl == get_ttl(answer.type) - - # register service with default TTL - expected_ttl = None - for _ in range(3): - _process_outgoing_packet(zc.generate_service_query(info)) - zc.registry.add(info) - for _ in range(3): - _process_outgoing_packet(zc.generate_service_broadcast(info, None)) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) - assert query.is_query() is True - query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # unregister - expected_ttl = 0 - zc.registry.remove(info) - for _ in range(3): - _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - expected_ttl = None - for _ in range(3): - _process_outgoing_packet(zc.generate_service_query(info)) - zc.registry.add(info) - # register service with custom TTL - expected_ttl = const._DNS_HOST_TTL * 2 - assert expected_ttl != const._DNS_HOST_TTL - for _ in range(3): - _process_outgoing_packet(zc.generate_service_broadcast(info, expected_ttl)) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - expected_ttl = None - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) - query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # unregister - expected_ttl = 0 - zc.registry.remove(info) - for _ in range(3): - _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) - assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 - nbr_answers = nbr_additionals = nbr_authorities = 0 - zc.close() - - def test_name_conflicts(self): - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_homeassistant._tcp.local." - name = "Home" - registration_name = "%s.%s" % (name, type_) - - info = ServiceInfo( - type_, - name=registration_name, - server="random123.local.", - addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], - port=80, - properties={"version": "1.0"}, - ) - zc.register_service(info) - - conflicting_info = ServiceInfo( - type_, - name=registration_name, - server="random456.local.", - addresses=[socket.inet_pton(socket.AF_INET, "4.5.6.7")], - port=80, - properties={"version": "1.0"}, - ) - with pytest.raises(r.NonUniqueNameException): - zc.register_service(conflicting_info) - zc.close() - - def test_register_and_lookup_type_by_uppercase_name(self): - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_mylowertype._tcp.local." - name = "Home" - registration_name = "%s.%s" % (name, type_) - - info = ServiceInfo( - type_, - name=registration_name, - server="random123.local.", - addresses=[socket.inet_pton(socket.AF_INET, "1.2.3.4")], - port=80, - properties={"version": "1.0"}, - ) - zc.register_service(info) - _clear_cache(zc) - info = ServiceInfo(type_, registration_name) - info.load_from_cache(zc) - assert info.addresses == [] - - out = r.DNSOutgoing(const._FLAGS_QR_QUERY) - out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN)) - zc.send(out) - time.sleep(0.5) - info = ServiceInfo(type_, registration_name) - info.load_from_cache(zc) - assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")] - assert info.properties == {b"version": b"1.0"} - zc.close() - - -def test_ptr_optimization(): - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - # service definition - type_ = "_test-srvc-type._tcp.local." - name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - - nbr_answers = nbr_additionals = nbr_authorities = 0 - has_srv = has_txt = has_a = False - - # register - zc.register_service(info) - nbr_answers = nbr_additionals = nbr_authorities = 0 - - # query - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) - query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False) - assert out is not None - nbr_answers += len(out.answers) - nbr_authorities += len(out.authorities) - for answer in out.additionals: - nbr_additionals += 1 - if answer.type == const._TYPE_SRV: - has_srv = True - elif answer.type == const._TYPE_TXT: - has_txt = True - elif answer.type == const._TYPE_A: - has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 - assert has_srv and has_txt and has_a - - # unregister - zc.unregister_service(info) - zc.close() - - def test_notify_listeners(): """Test adding and removing notify listeners.""" # instantiate a zeroconf instance From fd70ac1b6bdded992f8fbbb723ca92f5395abf23 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:32:44 -1000 Subject: [PATCH 274/608] Set mypy follow_imports to skip as ignore is not a valid option (#590) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a9dddb26..e9dc052f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,7 @@ ignore=E203,W503 [mypy] ignore_missing_imports = true -follow_imports = ignore +follow_imports = skip check_untyped_defs = true no_implicit_optional = true warn_incomplete_stub = true From 72032d6dde2ee7388b8cb4545554519d3ffa8508 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 10:58:03 -1000 Subject: [PATCH 275/608] Move notify listener tests to test_core (#591) --- tests/test_core.py | 44 +++++++++++++++++++++++++++++++++++++++++--- tests/test_init.py | 40 ---------------------------------------- 2 files changed, 41 insertions(+), 43 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index a99519cb..6d7467ab 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,11 +2,12 @@ # -*- coding: utf-8 -*- -""" Unit tests for zeroconf.core """ +""" Unit tests for zeroconf._core """ import itertools import logging import os +import pytest import socket import time import unittest @@ -14,8 +15,7 @@ from typing import cast import zeroconf as r -from zeroconf import _core -from zeroconf import const +from zeroconf import _core, const, ServiceBrowser, Zeroconf from . import has_working_ipv6, _inject_response @@ -234,3 +234,41 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS finally: zeroconf.close() + + +def test_notify_listeners(): + """Test adding and removing notify listeners.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + notify_called = 0 + + class TestNotifyListener(r.NotifyListener): + def notify_all(self): + nonlocal notify_called + notify_called += 1 + + with pytest.raises(NotImplementedError): + r.NotifyListener().notify_all() + + notify_listener = TestNotifyListener() + + zc.add_notify_listener(notify_listener) + + def on_service_state_change(zeroconf, service_type, state_change, name): + """Dummy service callback.""" + + # start a browser + browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) + browser.cancel() + + assert notify_called + zc.remove_notify_listener(notify_listener) + + notify_called = 0 + # start a browser + browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) + browser.cancel() + + assert not notify_called + + zc.close() diff --git a/tests/test_init.py b/tests/test_init.py index 48fff674..4710a994 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -11,8 +11,6 @@ import unittest.mock from typing import Optional # noqa # used in type hints -import pytest - import zeroconf as r from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, const @@ -246,41 +244,3 @@ def generate_host(zc, host_name, type_): 0, ) zc.send(out) - - -def test_notify_listeners(): - """Test adding and removing notify listeners.""" - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - notify_called = 0 - - class TestNotifyListener(r.NotifyListener): - def notify_all(self): - nonlocal notify_called - notify_called += 1 - - with pytest.raises(NotImplementedError): - r.NotifyListener().notify_all() - - notify_listener = TestNotifyListener() - - zc.add_notify_listener(notify_listener) - - def on_service_state_change(zeroconf, service_type, state_change, name): - """Dummy service callback.""" - - # start a browser - browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) - browser.cancel() - - assert notify_called - zc.remove_notify_listener(notify_listener) - - notify_called = 0 - # start a browser - browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) - browser.cancel() - - assert not notify_called - - zc.close() From 35e25fd46f8d3689b723dd845eba9862a5dc8a22 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 11:36:30 -1000 Subject: [PATCH 276/608] Reduce branching in DNSOutgoing.add_answer_at_time (#592) --- tests/test_dns.py | 44 +++++++++++++++++++++++++++++++++++++++++++- zeroconf/_dns.py | 5 ++--- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index a3c50ee2..f24f2212 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -14,7 +14,7 @@ from typing import Dict, cast # noqa # used in type hints import zeroconf as r -from zeroconf import const +from zeroconf import const, current_time_millis from zeroconf import ( DNSHinfo, DNSText, @@ -175,6 +175,48 @@ def test_parse_own_packet_response(self): assert len(generated.answers) == 1 assert len(generated.answers) == len(parsed.answers) + def test_adding_empty_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + None, + 0, + ) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_adding_expired_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + current_time_millis() + 1000000, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 0 + assert len(generated.answers) == len(parsed.answers) + def test_match_question(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index aa2b4983..b8f571ed 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -579,9 +579,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: """Adds an answer if it does not expire by a certain time""" - if record is not None: - if now == 0 or not record.is_expired(now): - self.answers.append((record, now)) + if record is not None and (now == 0 or not record.is_expired(now)): + self.answers.append((record, now)) def add_authorative_answer(self, record: DNSPointer) -> None: """Adds an authoritative answer""" From d2d826220bd4f287835ebb4304450cc2311d1db6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 12:57:15 -1000 Subject: [PATCH 277/608] Add unicast property to DNSQuestion to determine if the QU bit is set (#593) --- tests/test_dns.py | 22 +++++++++++++++++++++- zeroconf/_dns.py | 35 +++++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index f24f2212..99a6b172 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -14,7 +14,7 @@ from typing import Dict, cast # noqa # used in type hints import zeroconf as r -from zeroconf import const, current_time_millis +from zeroconf import DNSIncoming, const, current_time_millis from zeroconf import ( DNSHinfo, DNSText, @@ -747,3 +747,23 @@ def test_tc_bit_not_set_in_answer_packet(): third_packet = r.DNSIncoming(packets[2]) assert third_packet.flags & const._FLAGS_TC == 0 assert third_packet.valid is True + + +# 4003 15.973052 192.168.107.68 224.0.0.251 MDNS 76 Standard query 0xffc4 PTR _raop._tcp.local, "QM" question +def test_qm_packet_parser(): + """Test we can parse a query packet with the QM bit.""" + qm_packet = ( + b'\xff\xc4\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x05_raop\x04_tcp\x05local\x00\x00\x0c\x00\x01' + ) + parsed = DNSIncoming(qm_packet) + assert parsed.questions[0].unicast is False + assert ",QM," in str(parsed.questions[0]) + + +# 389951 1450.577370 192.168.107.111 224.0.0.251 MDNS 115 Standard query 0x0000 PTR _companion-link._tcp.local, "QU" question OPT +def test_qu_packet_parser(): + """Test we can parse a query packet with the QU bit.""" + qu_packet = b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x0f_companion-link\x04_tcp\x05local\x00\x00\x0c\x80\x01\x00\x00)\x05\xa0\x00\x00\x11\x94\x00\x12\x00\x04\x00\x0e\x00dz{\x8a6\x9czF\x84,\xcaQ\xff' + parsed = DNSIncoming(qu_packet) + assert parsed.questions[0].unicast is True + assert ",QU," in str(parsed.questions[0]) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index b8f571ed..54225808 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -91,17 +91,14 @@ def get_type(t: int) -> str: def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: """String representation with additional information""" - result = "%s[%s,%s" % (hdr, self.get_type(self.type), self.get_class_(self.class_)) - if self.unique: - result += "-unique," - else: - result += "," - result += self.name - if other is not None: - result += "]=%s" % cast(Any, other) - else: - result += "]" - return result + return "%s[%s,%s%s,%s]%s" % ( + hdr, + self.get_type(self.type), + self.get_class_(self.class_), + "-unique" if self.unique else "", + self.name, + "=%s" % cast(Any, other) if other is not None else "", + ) class DNSQuestion(DNSEntry): @@ -119,9 +116,23 @@ def answered_by(self, rec: 'DNSRecord') -> bool: and self.name == rec.name ) + @property + def unicast(self) -> bool: + """Returns true if the QU (not QM) is set. + + unique shares the same mask as the one + used for unicast. + """ + return self.unique + def __repr__(self) -> str: """String representation""" - return DNSEntry.entry_to_string(self, "question", None) + return "%s[question,%s,%s,%s]" % ( + self.get_type(self.type), + "QU" if self.unicast else "QM", + self.get_class_(self.class_), + self.name, + ) class DNSRecord(DNSEntry): From fe72524dbaf934ca63ebce053e34f3e838743460 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 14:07:09 -1000 Subject: [PATCH 278/608] Fix lookup of uppercase names in registry (#597) - If the ServiceInfo was registered with an uppercase name and the query was for a lowercase name, it would not be found and vice-versa. --- tests/services/test_registry.py | 38 +++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 7 +++--- zeroconf/_services/registry.py | 12 +++++------ 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 74af1e65..52726a04 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -46,3 +46,41 @@ def test_lookups(self): assert registry.get_infos_type(type_) == [info] assert registry.get_infos_server("ash-2.local.") == [info] assert registry.get_types() == [type_] + + def test_lookups_upper_case_by_lower_case(self): + type_ = "_test-SRVC-type._tcp.local." + name = "Xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ASH-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + + assert registry.get_service_infos() == [info] + assert registry.get_info_name(registration_name.lower()) == info + assert registry.get_infos_type(type_.lower()) == [info] + assert registry.get_infos_server("ash-2.local.") == [info] + assert registry.get_types() == [type_.lower()] + + def test_lookups_lower_case_by_upper_case(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + + assert registry.get_service_infos() == [info] + assert registry.get_info_name(registration_name.upper()) == info + assert registry.get_infos_type(type_.upper()) == [info] + assert registry.get_infos_server("ASH-2.local.") == [info] + assert registry.get_types() == [type_] diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 7e8e734b..8d824ec5 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -73,7 +73,7 @@ def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgo def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name.lower()): + for service in self.registry.get_infos_type(question.name): out.add_answer(msg, service.dns_pointer()) # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. @@ -87,14 +87,13 @@ def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DN Add answer(s) for A, AAAA, SRV, or TXT queries. """ - name_to_find = question.name.lower() # Answer A record queries for any service addresses we know if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(name_to_find): + for service in self.registry.get_infos_server(question.name): for dns_address in service.dns_addresses(): out.add_answer(msg, dns_address) - service = self.registry.get_info_name(name_to_find) # type: ignore + service = self.registry.get_info_name(question.name) # type: ignore if service is None: return diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 6d1baa8e..4c4c1706 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -70,7 +70,7 @@ def get_service_infos(self) -> List[ServiceInfo]: def get_info_name(self, name: str) -> Optional[ServiceInfo]: """Return all ServiceInfo for the name.""" - return self._services.get(name) + return self._services.get(name.lower()) def get_types(self) -> List[str]: """Return all types.""" @@ -88,7 +88,7 @@ def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" service_infos = [] - for name in getattr(self, attr).get(key, [])[:]: + for name in getattr(self, attr).get(key.lower(), [])[:]: info = self._services.get(name) # Since we do not get under a lock since it would be # a performance issue, its possible @@ -106,13 +106,13 @@ def _add(self, info: ServiceInfo) -> None: raise ServiceNameAlreadyRegistered self._services[lower_name] = info - self.types.setdefault(info.type, []).append(lower_name) - self.servers.setdefault(info.server, []).append(lower_name) + self.types.setdefault(info.type.lower(), []).append(lower_name) + self.servers.setdefault(info.server.lower(), []).append(lower_name) def _remove(self, info: ServiceInfo) -> None: """Remove a service under the lock.""" lower_name = info.name.lower() old_service_info = self._services[lower_name] - self.types[old_service_info.type].remove(lower_name) - self.servers[old_service_info.server].remove(lower_name) + self.types[old_service_info.type.lower()].remove(lower_name) + self.servers[old_service_info.server.lower()].remove(lower_name) del self._services[lower_name] From cb64e0dd5d1c621f61d0d0f92ea282d287a9c242 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 15:06:28 -1000 Subject: [PATCH 279/608] Add id_ param to allow setting the id in the DNSOutgoing constructor (#599) --- tests/test_dns.py | 5 +++++ zeroconf/_dns.py | 22 +++++++++++----------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 99a6b172..0e51aa4c 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -427,6 +427,11 @@ def test_transaction_id(self): id = bytes[0] << 8 | bytes[1] assert id == 0 + def test_setting_id(self): + """Test setting id in the constructor""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY, id_=4444) + assert generated.id == 4444 + def test_query_header_bits(self): generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) bytes = generated.packets()[0] diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 54225808..91bb14dc 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -537,25 +537,25 @@ class DNSOutgoing(DNSMessage): """Object representation of an outgoing packet""" - def __init__(self, flags: int, multicast: bool = True) -> None: + def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: super().__init__(flags) self.finished = False - self.id = 0 + self.id = id_ self.multicast = multicast - self.packets_data = [] # type: List[bytes] + self.packets_data: List[bytes] = [] # these 3 are per-packet -- see also reset_for_next_packet() - self.names = {} # type: Dict[str, int] - self.data = [] # type: List[bytes] - self.size = 12 - self.allow_long = True + self.names: Dict[str, int] = {} + self.data: List[bytes] = [] + self.size: int = 12 + self.allow_long: bool = True self.state = self.State.init - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[Tuple[DNSRecord, float]] - self.authorities = [] # type: List[DNSPointer] - self.additionals = [] # type: List[DNSRecord] + self.questions: List[DNSQuestion] = [] + self.answers: List[Tuple[DNSRecord, float]] = [] + self.authorities: List[DNSPointer] = [] + self.additionals: List[DNSRecord] = [] def reset_for_next_packet(self) -> None: self.names = {} From 3556c22aacc72e62c318955c084533b70311bcc9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 15:22:51 -1000 Subject: [PATCH 280/608] Ensure unicast responses can be sent to any source port (#598) - Unicast responses were only being sent if the source port was 53, this prevented responses when testing with dig: dig -p 5353 @224.0.0.251 media-12.local The above query will now see a response --- tests/test_handlers.py | 70 +++++++++++++++++++++++++++++++++++------- zeroconf/_core.py | 25 ++++++--------- zeroconf/_handlers.py | 43 ++++++++++++++------------ 3 files changed, 91 insertions(+), 47 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 53d9b9b7..ea7ab589 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -96,7 +96,9 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) + _process_outgoing_packet( + zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] + ) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -127,7 +129,9 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet(zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False)) + _process_outgoing_packet( + zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] + ) assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -216,21 +220,23 @@ def test_ptr_optimization(): type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) - nbr_answers = nbr_additionals = nbr_authorities = 0 - has_srv = has_txt = has_a = False - # register zc.register_service(info) - nbr_answers = nbr_additionals = nbr_authorities = 0 # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), False) - assert out is not None - nbr_answers += len(out.answers) - nbr_authorities += len(out.authorities) - for answer in out.additionals: + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT + ) + assert multicast_out.id == query.id + assert unicast_out is None + assert multicast_out is not None + has_srv = has_txt = has_a = False + nbr_additionals = 0 + nbr_answers = len(multicast_out.answers) + nbr_authorities = len(multicast_out.authorities) + for answer in multicast_out.additionals: nbr_additionals += 1 if answer.type == const._TYPE_SRV: has_srv = True @@ -244,3 +250,45 @@ def test_ptr_optimization(): # unregister zc.unregister_service(info) zc.close() + + +def test_unicast_response(): + """Ensure we send a unicast response when the source port is not the MDNS port.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + # register + zc.register_service(info) + + # query + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + unicast_out, multicast_out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), "1.2.3.4", 1234) + for out in (unicast_out, multicast_out): + assert out.id == query.id + has_srv = has_txt = has_a = False + nbr_additionals = 0 + nbr_answers = len(multicast_out.answers) + nbr_authorities = len(multicast_out.authorities) + for answer in out.additionals: + nbr_additionals += 1 + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert has_srv and has_txt and has_a + + # unregister + zc.unregister_service(info) + zc.close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1b3a4c1f..783f81fa 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -55,7 +55,6 @@ _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, _CLASS_IN, - _DNS_PORT, _FLAGS_AA, _FLAGS_QR_QUERY, _FLAGS_QR_RESPONSE, @@ -209,19 +208,11 @@ def handle_read(self, socket_: socket.socket) -> None: if not msg.valid: pass - elif msg.is_query(): - # Always multicast responses - if port == _MDNS_PORT: - self.zc.handle_query(msg, None, _MDNS_PORT) - - # If it's not a multicast query, reply via unicast - # and multicast - elif port == _DNS_PORT: - self.zc.handle_query(msg, addr, port) - self.zc.handle_query(msg, None, _MDNS_PORT) - - else: + elif not msg.is_query(): self.zc.handle_response(msg) + return + + self.zc.handle_query(msg, addr, port) class Zeroconf(QuietLogger): @@ -502,9 +493,11 @@ def handle_response(self, msg: DNSIncoming) -> None: def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: """Deal with incoming query packets. Provides a response if possible.""" - out = self.query_handler.response(msg, port != _MDNS_PORT) - if out: - self.send(out, addr, port) + unicast_out, multicast_out = self.query_handler.response(msg, addr, port) + if unicast_out and unicast_out.answers: + self.send(unicast_out, addr, port) + if multicast_out and multicast_out.answers: + self.send(multicast_out, None, _MDNS_PORT) def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: """Sends an outgoing packet.""" diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 8d824ec5..65eb472b 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,7 +21,7 @@ """ import itertools -from typing import List, Optional, TYPE_CHECKING, Union +from typing import List, Optional, TYPE_CHECKING, Tuple, Union from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from ._logger import log @@ -33,6 +33,7 @@ _DNS_OTHER_TTL, _FLAGS_AA, _FLAGS_QR_RESPONSE, + _MDNS_PORT, _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_A, _TYPE_ANY, @@ -105,30 +106,32 @@ def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DN for dns_address in service.dns_addresses(): out.add_additional_answer(dns_address) - def response(self, msg: DNSIncoming, unicast: bool) -> Optional[DNSOutgoing]: + def response( # pylint: disable=unused-argument + self, msg: DNSIncoming, addr: Optional[str], port: int + ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: """Deal with incoming query packets. Provides a response if possible.""" - if unicast: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False) + unicast_out = None + multicast_out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, id_=msg.id) + outputs = [multicast_out] + + if port != _MDNS_PORT: + unicast_out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=msg.id) + outputs.append(unicast_out) for question in msg.questions: - out.add_question(question) - else: - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._answer_service_type_enumeration_query(msg, out) - else: - self._answer_ptr_query(msg, out, question) - continue + unicast_out.add_question(question) - self._answer_non_ptr_query(msg, out, question) + for out in outputs: + for question in msg.questions: + if question.type == _TYPE_PTR: + if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._answer_service_type_enumeration_query(msg, out) + else: + self._answer_ptr_query(msg, out, question) + continue - if out is not None and out.answers: - out.id = msg.id - return out + self._answer_non_ptr_query(msg, out, question) - return None + return unicast_out, multicast_out class RecordManager: From f6cd8f6d23459f9ed48ad06ff6702e606d620eaf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 15:40:23 -1000 Subject: [PATCH 281/608] Add ZeroconfServiceTypes to zeroconf.__all__ (#601) - This class is in the readme, but is not exported by default --- zeroconf/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 8277c32b..02d2afa1 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -57,7 +57,7 @@ ServiceStateChange, ) from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat -from ._services.types import ZeroconfServiceTypes # noqa # import needed for backwards compat +from ._services.types import ZeroconfServiceTypes from ._utils.name import service_type_name # noqa # import needed for backwards compat from ._utils.net import ( # noqa # import needed for backwards compat add_multicast_member, @@ -89,6 +89,7 @@ "InterfaceChoice", "ServiceStateChange", "IPVersion", + "ZeroconfServiceTypes", ] if sys.version_info <= (3, 6): From 809b6df376205e6ab5ce8fb5fe3a92e77662fe2d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 15:40:41 -1000 Subject: [PATCH 282/608] Fix docs version to match readme (cpython 3.6+) (#602) --- docs/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.rst b/docs/index.rst index de5ba41a..8929f417 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,7 +16,7 @@ PyPI (installable, stable distributions): https://pypi.org/project/zeroconf. You pip install zeroconf -python-zeroconf works with CPython 3.5+ and PyPy 3 implementing Python 3.5+. +python-zeroconf works with CPython 3.6+ and PyPy 3 implementing Python 3.6+. Contents -------- From 850e2115aa79c10765dfc45a290a68193397de6c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 13 Jun 2021 23:58:34 -1000 Subject: [PATCH 283/608] Log destination when sending packets (#606) --- zeroconf/_core.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 783f81fa..3532318b 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -501,14 +501,19 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: """Sends an outgoing packet.""" - packets = out.packets() - packet_num = 0 - for packet in packets: - packet_num += 1 + for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return - log.debug('Sending (%d bytes #%d) %r as %r...', len(packet), packet_num, out, packet) + log.debug( + 'Sending to (%s, %d) (%d bytes #%d) %r as %r...', + addr, + port, + len(packet), + packet_num + 1, + out, + packet, + ) for s in self._respond_sockets: if self._GLOBAL_DONE: return From 22bd1475fb58c7c421c0009cd0c5c791cedb225d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 00:43:07 -1000 Subject: [PATCH 284/608] Ensure the QU bit is set for probe queries (#609) - The bit should be set per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 --- tests/test_core.py | 15 +++++++++ tests/test_dns.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++ zeroconf/_core.py | 11 ++++++- zeroconf/_dns.py | 14 +++++--- 4 files changed, 113 insertions(+), 6 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 6d7467ab..b8a5499b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -272,3 +272,18 @@ def on_service_state_change(zeroconf, service_type, state_change, name): assert not notify_called zc.close() + + +def test_generate_service_query_set_qu_bit(): + """Test generate_service_query sets the QU bit.""" + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + type_ = "._hap._tcp.local." + registration_name = "this-host-is-not-used._hap._tcp.local." + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + out = zeroconf_registrar.generate_service_query(info) + assert out.questions[0].unicast is True + zeroconf_registrar.close() diff --git a/tests/test_dns.py b/tests/test_dns.py index 0e51aa4c..4cd72046 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -326,6 +326,85 @@ def test_many_questions(self): parsed2 = r.DNSIncoming(packets[1]) assert len(parsed2.questions) == 15 + def test_many_questions_with_many_known_answers(self): + """Test many questions and known answers get seperated into multiple packets.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + questions = [] + for _ in range(30): + question = r.DNSQuestion(f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + known_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_answer_at_time(known_answer, now) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 30 + assert len(parsed1.answers) == 88 + assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 101 + assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert len(parsed3.answers) == 11 + assert parsed3.flags & const._FLAGS_TC == 0 + + def test_massive_probe_packet_split(self): + """Test probe with many authorative answers.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + questions = [] + for _ in range(30): + question = r.DNSQuestion( + f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE + ) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + authorative_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_authorative_answer(authorative_answer) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert parsed1.questions[0].unicast is True + assert len(parsed1.questions) == 30 + assert parsed1.num_authorities == 88 + assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert parsed2.num_authorities == 101 + assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert parsed3.num_authorities == 11 + assert parsed3.flags & const._FLAGS_TC == 0 + def test_only_one_answer_can_by_large(self): """Test that only the first answer in each packet can be large. diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 3532318b..37985cf3 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -55,6 +55,7 @@ _CACHE_CLEANUP_INTERVAL, _CHECK_TIME, _CLASS_IN, + _CLASS_UNIQUE, _FLAGS_AA, _FLAGS_QR_QUERY, _FLAGS_QR_RESPONSE, @@ -399,7 +400,15 @@ def send_service_query(self, info: ServiceInfo) -> None: def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) - out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN)) + # https://datatracker.ietf.org/doc/html/rfc6762#section-8.1 + # Because of the mDNS multicast rate-limiting + # rules, the probes SHOULD be sent as "QU" questions with the unicast- + # response bit set, to allow a defending host to respond immediately + # via unicast, instead of potentially having to wait before replying + # via multicast. + # + # _CLASS_UNIQUE is the "QU" bit + out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE)) out.add_authorative_answer(info.dns_pointer()) return out diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 91bb14dc..dcc809a1 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -757,9 +757,16 @@ def write_question(self, question: DNSQuestion) -> bool: start_data_length, start_size = len(self.data), self.size self.write_name(question.name) self.write_short(question.type) - self.write_short(question.class_) + self.write_record_class(question) return self._check_data_limit_or_rollback(start_data_length, start_size) + def write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: + """Write out the record class including the unique/unicast (QU) bit.""" + if record.unique and self.multicast: + self.write_short(record.class_ | _CLASS_UNIQUE) + else: + self.write_short(record.class_) + def write_record(self, record: DNSRecord, now: float) -> bool: """Writes a record (answer, authoritative answer, additional) to the packet. Returns True on success, or False if we did not (either @@ -771,10 +778,7 @@ def write_record(self, record: DNSRecord, now: float) -> bool: start_data_length, start_size = len(self.data), self.size self.write_name(record.name) self.write_short(record.type) - if record.unique and self.multicast: - self.write_short(record.class_ | _CLASS_UNIQUE) - else: - self.write_short(record.class_) + self.write_record_class(record) if now == 0: self.write_int(record.ttl) else: From b7d867878153fa600053869265260992e5462b2d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 09:35:24 -1000 Subject: [PATCH 285/608] Make DNSRecords hashable (#611) - Allows storing them in a set for de-duplication - Needed to be able to check for duplicates to solve https://github.com/jstasiak/python-zeroconf/issues/604 --- tests/test_dns.py | 179 ++++++++++++++++++++++++++++++++++++++++++++++ zeroconf/_dns.py | 39 +++++++--- 2 files changed, 209 insertions(+), 9 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 4cd72046..3f46171c 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -851,3 +851,182 @@ def test_qu_packet_parser(): parsed = DNSIncoming(qu_packet) assert parsed.questions[0].unicast is True assert ",QU," in str(parsed.questions[0]) + + +def test_dns_record_hashablity_does_not_consider_ttl(): + """Test DNSRecord are hashable.""" + + # Verify the TTL is not considered in the hash + record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same') + record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') + + record_set = set([record1, record2]) + assert len(record_set) == 1 + + record_set.add(record1) + assert len(record_set) == 1 + + record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') + + record_set.add(record3_dupe) + assert len(record_set) == 1 + + +def test_dns_address_record_hashablity(): + """Test DNSAddress are hashable.""" + address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a') + address2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'b') + address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c') + address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c') + + record_set = set([address1, address2, address3, address4]) + assert len(record_set) == 4 + + record_set.add(address1) + assert len(record_set) == 4 + + address3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c') + + record_set.add(address3_dupe) + assert len(record_set) == 4 + + # Verify we can remove records + additional_set = set([address1, address2]) + record_set -= additional_set + assert record_set == set([address3, address4]) + + +def test_dns_hinfo_record_hashablity(): + """Test DNSHinfo are hashable.""" + hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os') + hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') + + record_set = set([hinfo1, hinfo2]) + assert len(record_set) == 2 + + record_set.add(hinfo1) + assert len(record_set) == 2 + + hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') + + record_set.add(hinfo2_dupe) + assert len(record_set) == 2 + + +def test_dns_pointer_record_hashablity(): + """Test DNSPointer are hashable.""" + ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') + ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') + + record_set = set([ptr1, ptr2]) + assert len(record_set) == 2 + + record_set.add(ptr1) + assert len(record_set) == 2 + + ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') + + record_set.add(ptr2_dupe) + assert len(record_set) == 2 + + +def test_dns_text_record_hashablity(): + """Test DNSText are hashable.""" + text1 = r.DNSText('irrelevant', 0, 0, 0, b'12345678901') + text2 = r.DNSText('irrelevant', 1, 0, 0, b'12345678901') + text3 = r.DNSText('irrelevant', 0, 1, 0, b'12345678901') + text4 = r.DNSText('irrelevant', 0, 0, 1, b'12345678901') + text5 = r.DNSText('irrelevant', 0, 0, 0, b'ABCDEFGHIJK') + + record_set = set([text1, text2, text3, text4, text5]) + assert len(record_set) == 5 + + record_set.add(text1) + assert len(record_set) == 5 + + text1_dupe = r.DNSText('irrelevant', 0, 0, 0, b'12345678901') + + record_set.add(text1_dupe) + assert len(record_set) == 5 + + +def test_dns_text_record_hashablity(): + """Test DNSText are hashable.""" + text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') + text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') + text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') + + record_set = set([text1, text2, text3, text4]) + + assert len(record_set) == 4 + + record_set.add(text1) + assert len(record_set) == 4 + + text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + + record_set.add(text1_dupe) + assert len(record_set) == 4 + + +def test_dns_text_record_hashablity(): + """Test DNSText are hashable.""" + text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') + text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') + text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') + + record_set = set([text1, text2, text3, text4]) + + assert len(record_set) == 4 + + record_set.add(text1) + assert len(record_set) == 4 + + text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + + record_set.add(text1_dupe) + assert len(record_set) == 4 + + +def test_dns_text_record_hashablity(): + """Test DNSText are hashable.""" + text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') + text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') + text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') + + record_set = set([text1, text2, text3, text4]) + + assert len(record_set) == 4 + + record_set.add(text1) + assert len(record_set) == 4 + + text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + + record_set.add(text1_dupe) + assert len(record_set) == 4 + + +def test_dns_service_record_hashablity(): + """Test DNSService are hashable.""" + srv1 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a') + srv2 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 1, 80, 'a') + srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a') + srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab') + + record_set = set([srv1, srv2, srv3, srv4]) + + assert len(record_set) == 4 + + record_set.add(srv1) + assert len(record_set) == 4 + + srv1_dupe = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a' + ) + + record_set.add(srv1_dupe) + assert len(record_set) == 4 diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index dcc809a1..8eb62593 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -70,6 +70,10 @@ def __init__(self, name: str, type_: int, class_: int) -> None: self.class_ = class_ & _CLASS_MASK self.unique = (class_ & _CLASS_UNIQUE) != 0 + def _entry_tuple(self) -> Tuple[str, int, int]: + """Entry Tuple for DNSEntry.""" + return (self.key, self.type, self.class_) + def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" return ( @@ -105,9 +109,6 @@ class DNSQuestion(DNSEntry): """A DNS question entry""" - def __init__(self, name: str, type_: int, class_: int) -> None: - DNSEntry.__init__(self, name, type_, class_) - def answered_by(self, rec: 'DNSRecord') -> bool: """Returns true if the question is answered by the record""" return ( @@ -141,7 +142,7 @@ class DNSRecord(DNSEntry): # TODO: Switch to just int ttl def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: - DNSEntry.__init__(self, name, type_, class_) + super().__init__(name, type_, class_) self.ttl = ttl self.created = current_time_millis() self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) @@ -205,7 +206,7 @@ class DNSAddress(DNSRecord): """A DNS address record""" def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl) self.address = address def write(self, out: 'DNSOutgoing') -> None: @@ -218,6 +219,10 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address ) + def __hash__(self) -> int: + """Hash to compare like DNSAddresses.""" + return hash((*self._entry_tuple(), self.address)) + def __repr__(self) -> str: """String representation""" try: @@ -235,7 +240,7 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl) self.cpu = cpu self.os = os @@ -253,6 +258,10 @@ def __eq__(self, other: Any) -> bool: and self.os == other.os ) + def __hash__(self) -> int: + """Hash to compare like DNSHinfo.""" + return hash((*self._entry_tuple(), self.cpu, self.os)) + def __repr__(self) -> str: """String representation""" return self.to_string(self.cpu + " " + self.os) @@ -263,7 +272,7 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl) self.alias = alias def write(self, out: 'DNSOutgoing') -> None: @@ -274,6 +283,10 @@ def __eq__(self, other: Any) -> bool: """Tests equality on alias""" return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) + def __hash__(self) -> int: + """Hash to compare like DNSPointer.""" + return hash((*self._entry_tuple(), self.alias)) + def __repr__(self) -> str: """String representation""" return self.to_string(self.alias) @@ -285,13 +298,17 @@ class DNSText(DNSRecord): def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: assert isinstance(text, (bytes, type(None))) - DNSRecord.__init__(self, name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl) self.text = text def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" out.write_string(self.text) + def __hash__(self) -> int: + """Hash to compare like DNSText.""" + return hash((*self._entry_tuple(), self.text)) + def __eq__(self, other: Any) -> bool: """Tests equality on text""" return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) @@ -318,7 +335,7 @@ def __init__( port: int, server: str, ) -> None: - DNSRecord.__init__(self, name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl) self.priority = priority self.weight = weight self.port = port @@ -342,6 +359,10 @@ def __eq__(self, other: Any) -> bool: and DNSEntry.__eq__(self, other) ) + def __hash__(self) -> int: + """Hash to compare like DNSService.""" + return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server)) + def __repr__(self) -> str: """String representation""" return self.to_string("%s:%s" % (self.server, self.port)) From aea2c8ab24d4be19b34f407c854241e0d73d0525 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 11:30:55 -1000 Subject: [PATCH 286/608] Add the ability for ServiceInfo.dns_addresses to filter by address type (#612) --- tests/test_services.py | 22 +++++++++++++++++++++- zeroconf/_services/__init__.py | 6 ++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index c78b9f9c..e2aa93f0 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -11,11 +11,12 @@ import os import unittest from threading import Event +from typing import List import pytest import zeroconf as r -from zeroconf import const +from zeroconf import DNSAddress, const import zeroconf._services as s from zeroconf import Zeroconf from zeroconf._services import ( @@ -1170,3 +1171,22 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.remove_listener(listener) zc.close() + + +def test_filter_address_by_type_from_service_info(): + """Verify dns_addresses can filter by ipversion.""" + desc = {'path': '/~paulsm/'} + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + registration_name = "%s.%s" % (name, type_) + ipv4 = socket.inet_aton("10.0.1.2") + ipv6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[ipv4, ipv6]) + + def dns_addresses_to_addresses(dns_address: List[DNSAddress]): + return [address.address for address in dns_address] + + assert dns_addresses_to_addresses(info.dns_addresses()) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 09aa4c73..f6092aae 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -655,7 +655,9 @@ def _process_record(self, record: DNSRecord, now: float) -> None: if record.key == self.key: self._set_text(record.text) - def dns_addresses(self, override_ttl: Optional[int] = None) -> List[DNSAddress]: + def dns_addresses( + self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All + ) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return [ DNSAddress( @@ -665,7 +667,7 @@ def dns_addresses(self, override_ttl: Optional[int] = None) -> List[DNSAddress]: override_ttl if override_ttl is not None else self.host_ttl, address, ) - for address in self._addresses + for address in self.addresses_by_version(version) ] def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: From 219aa3e54c944b2935c9a40cc15de19284aded3c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 12:07:28 -1000 Subject: [PATCH 287/608] Avoid including additionals when the answer is suppressed by known-answer supression (#614) --- tests/test_handlers.py | 112 ++++++++++++++++++++++++++++++++++++++++- zeroconf/_handlers.py | 11 +++- 2 files changed, 120 insertions(+), 3 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ea7ab589..379d5e8e 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -12,7 +12,7 @@ import unittest.mock import zeroconf as r -from zeroconf import ServiceInfo, Zeroconf +from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const from . import _clear_cache @@ -292,3 +292,113 @@ def test_unicast_response(): # unregister zc.unregister_service(info) zc.close() + + +def test_known_answer_supression(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownservice._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.register_service(info) + + now = current_time_millis() + + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is not None and multicast_out.answers + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_pointer(), now) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + # If the answer is suppressed, the additional should be suppresed as well + assert not multicast_out or not multicast_out.answers + + # Test A supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is not None and multicast_out.answers + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + for dns_address in info.dns_addresses(): + generated.add_answer_at_time(dns_address, now) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert not multicast_out or not multicast_out.answers + + # Test SRV supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is not None and multicast_out.answers + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_service(), now) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + # If the answer is suppressed, the additional should be suppresed as well + assert not multicast_out or not multicast_out.answers + + # Test TXT supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is not None and multicast_out.answers + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time(info.dns_text(), now) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert not multicast_out or not multicast_out.answers + + # unregister + zc.unregister_service(info) + zc.close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 65eb472b..dc29bbf6 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -75,6 +75,9 @@ def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgo def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: """Answer a PTR query.""" for service in self.registry.get_infos_type(question.name): + dns_pointer = service.dns_pointer() + if dns_pointer.suppressed_by(msg): + continue out.add_answer(msg, service.dns_pointer()) # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. @@ -103,8 +106,12 @@ def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DN if question.type in (_TYPE_TXT, _TYPE_ANY): out.add_answer(msg, service.dns_text()) if question.type == _TYPE_SRV: - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) + dns_service = service.dns_service() + if not dns_service.suppressed_by(msg): + # Add recommended additional answers according to + # https://datatracker.ietf.org/doc/html/rfc6763#section-12.2 + for dns_address in service.dns_addresses(): + out.add_additional_answer(dns_address) def response( # pylint: disable=unused-argument self, msg: DNSIncoming, addr: Optional[str], port: int From c828c7555ed1fb82ff95ed578262d1553f19d903 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 12:53:46 -1000 Subject: [PATCH 288/608] Breakout the query response handler into its own class (#615) --- tests/test_handlers.py | 83 +++++++++++++++- tests/test_services.py | 2 +- zeroconf/_handlers.py | 217 +++++++++++++++++++++++++++-------------- 3 files changed, 222 insertions(+), 80 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 379d5e8e..e8efefeb 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -96,10 +96,14 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet( - zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] - ) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 + multicast_out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[ + 1 + ] + _process_outgoing_packet(multicast_out) + + # The additonals should all be suppresed since they are all in the answers section + # + assert nbr_answers == 4 and nbr_additionals == 3 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -132,7 +136,7 @@ def _process_outgoing_packet(out): _process_outgoing_packet( zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] ) - assert nbr_answers == 4 and nbr_additionals == 4 and nbr_authorities == 0 + assert nbr_answers == 4 and nbr_additionals == 3 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -402,3 +406,72 @@ def test_known_answer_supression(): # unregister zc.unregister_service(info) zc.close() + + +def test_known_answer_supression_service_type_enumeration_query(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownservice._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.register_service(info) + + type_2 = "_knownservice2._tcp.local." + name = "knownname" + registration_name2 = "%s.%s" % (name, type_2) + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.register_service(info) + now = current_time_millis() + + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is not None and multicast_out.answers + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + generated.add_answer_at_time( + r.DNSPointer( + const._SERVICE_TYPE_ENUMERATION_NAME, + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + type_, + ), + now, + ) + generated.add_answer_at_time( + r.DNSPointer( + const._SERVICE_TYPE_ENUMERATION_NAME, + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + type_2, + ), + now, + ) + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert not multicast_out or not multicast_out.answers + + # unregister + zc.unregister_service(info) + zc.close() diff --git a/tests/test_services.py b/tests/test_services.py index e2aa93f0..6677bfcb 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -623,7 +623,7 @@ def update_service(self, zeroconf, type, name): assert info.properties[b'prop_true'] == b'1' assert info.properties[b'prop_false'] == b'0' assert info.addresses == addresses[:1] # no V6 by default - assert info.addresses_by_version(r.IPVersion.All) == addresses + assert set(info.addresses_by_version(r.IPVersion.All)) == set(addresses) cached_info = ServiceInfo(type_, registration_name) cached_info.load_from_cache(zeroconf_browser) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index dc29bbf6..4884977c 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -20,8 +20,9 @@ USA """ +import enum import itertools -from typing import List, Optional, TYPE_CHECKING, Tuple, Union +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from ._logger import log @@ -48,97 +49,165 @@ from ._core import Zeroconf # pylint: disable=cyclic-import +@enum.unique +class RecordSetKeys(enum.Enum): + Answers = 1 + Additionals = 2 + + +# Switch to a TypedDict once Python 3.8 is the minimum supported version +_RecordSetType = Dict[RecordSetKeys, Set[DNSRecord]] + + +class _QueryResponse: + """A pair for unicast and multicast DNSOutgoing responses.""" + + def __init__(self, msg: DNSIncoming, ucast_source: bool) -> None: + """Build a query response.""" + self._msg = msg + self._ucast_source = ucast_source + self._is_probe = msg.num_authorities > 0 + self._now = current_time_millis() + self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} + self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} + + def add_ucast_question_response(self, answers: Set[DNSRecord], additionals: Set[DNSRecord]) -> None: + """Generate a response to a unicast query.""" + self._ucast[RecordSetKeys.Answers].update(answers) + self._ucast[RecordSetKeys.Additionals].update(additionals) + + def add_mcast_question_response(self, answers: Set[DNSRecord], additionals: Set[DNSRecord]) -> None: + """Generate a response to a multicast query.""" + self._mcast[RecordSetKeys.Answers].update(answers) + self._mcast[RecordSetKeys.Additionals].update(additionals) + + def outgoing_unicast(self) -> Optional[DNSOutgoing]: + """Build the outgoing unicast response.""" + ucastout = self._construct_outgoing_from_record_set(self._ucast, False) + # Adding the questions back when the source is + # unicast (not MDNS port) is legacy behavior + # Is this correct? + if ucastout and self._ucast_source: + for question in self._msg.questions: + ucastout.add_question(question) + return ucastout + + def outgoing_multicast(self) -> Optional[DNSOutgoing]: + """Build the outgoing multicast response.""" + return self._construct_outgoing_from_record_set(self._mcast, True) + + def _construct_outgoing_from_record_set( + self, rrset: _RecordSetType, multicast: bool + ) -> Optional[DNSOutgoing]: + """Add answers and additionals to a DNSOutgoing.""" + if not rrset[RecordSetKeys.Answers] and not rrset[RecordSetKeys.Additionals]: + return None + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=multicast, id_=self._msg.id) + for answer in rrset[RecordSetKeys.Answers]: + out.add_answer_at_time(answer, 0) + for additional in rrset[RecordSetKeys.Additionals]: + out.add_additional_answer(additional) + return out + + class QueryHandler: """Query the ServiceRegistry.""" - def __init__(self, registry: ServiceRegistry): + def __init__(self, registry: ServiceRegistry) -> None: """Init the query handler.""" self.registry = registry - def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgoing) -> None: + def _answer_service_type_enumeration_query( + self, + msg: DNSIncoming, + ) -> Set[DNSRecord]: """Provide an answer to a service type enumeration query. https://datatracker.ietf.org/doc/html/rfc6763#section-9 """ - for stype in self.registry.get_types(): - out.add_answer( - msg, - DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, - _TYPE_PTR, - _CLASS_IN, - _DNS_OTHER_TTL, - stype, - ), - ) - - def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a PTR query.""" - for service in self.registry.get_infos_type(question.name): - dns_pointer = service.dns_pointer() - if dns_pointer.suppressed_by(msg): - continue - out.add_answer(msg, service.dns_pointer()) + records: Set[DNSRecord] = set( + DNSPointer(_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype) + for stype in self.registry.get_types() + ) + records -= set(dns_pointer for dns_pointer in records if dns_pointer.suppressed_by(msg)) + return records + + def _add_pointer_answers( + self, name: str, msg: DNSIncoming, answers: Set[DNSRecord], additionals: Set[DNSRecord] + ) -> None: + """Answer PTR/ANY question.""" + for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. - out.add_additional_answer(service.dns_service()) - out.add_additional_answer(service.dns_text()) + dns_pointer = service.dns_pointer() + if not dns_pointer.suppressed_by(msg): + answers.add(service.dns_pointer()) + additionals.add(service.dns_service()) + additionals.add(service.dns_text()) + additionals.update(service.dns_addresses()) + + def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord]) -> None: + """Answer address question.""" + for service in self.registry.get_infos_server(name): for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) - - def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None: - """Answer a query any query other then PTR. - - Add answer(s) for A, AAAA, SRV, or TXT queries. - """ - # Answer A record queries for any service addresses we know - if question.type in (_TYPE_A, _TYPE_ANY): - for service in self.registry.get_infos_server(question.name): - for dns_address in service.dns_addresses(): - out.add_answer(msg, dns_address) - - service = self.registry.get_info_name(question.name) # type: ignore - if service is None: - return - - if question.type in (_TYPE_SRV, _TYPE_ANY): - out.add_answer(msg, service.dns_service()) - if question.type in (_TYPE_TXT, _TYPE_ANY): - out.add_answer(msg, service.dns_text()) - if question.type == _TYPE_SRV: - dns_service = service.dns_service() - if not dns_service.suppressed_by(msg): - # Add recommended additional answers according to - # https://datatracker.ietf.org/doc/html/rfc6763#section-12.2 - for dns_address in service.dns_addresses(): - out.add_additional_answer(dns_address) + if not dns_address.suppressed_by(msg): + answers.add(dns_address) + + def _answer_question( + self, msg: DNSIncoming, question: DNSQuestion + ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: + answers: Set[DNSRecord] = set() + additionals: Set[DNSRecord] = set() + type_ = question.type + + if type_ == _TYPE_PTR: + self._add_pointer_answers(question.name, msg, answers, additionals) + + if type_ in (_TYPE_A, _TYPE_ANY): + self._add_address_answers(question.name, msg, answers) + + if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): + service = self.registry.get_info_name(question.name) # type: ignore + if service is not None: + if type_ in (_TYPE_SRV, _TYPE_ANY): + dns_service = service.dns_service() + if not dns_service.suppressed_by(msg): + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.2. + answers.add(service.dns_service()) + additionals.update(service.dns_addresses()) + if type_ in (_TYPE_TXT, _TYPE_ANY): + dns_text = service.dns_text() + if not dns_text.suppressed_by(msg): + answers.add(service.dns_text()) + + return answers, additionals + + def _answer_any_question( + self, msg: DNSIncoming, question: DNSQuestion + ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: + if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + empty_additionals: Set[DNSRecord] = set() + return self._answer_service_type_enumeration_query(msg), empty_additionals + + return self._answer_question(msg, question) def response( # pylint: disable=unused-argument self, msg: DNSIncoming, addr: Optional[str], port: int ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: """Deal with incoming query packets. Provides a response if possible.""" - unicast_out = None - multicast_out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, id_=msg.id) - outputs = [multicast_out] - - if port != _MDNS_PORT: - unicast_out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=msg.id) - outputs.append(unicast_out) - for question in msg.questions: - unicast_out.add_question(question) - - for out in outputs: - for question in msg.questions: - if question.type == _TYPE_PTR: - if question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._answer_service_type_enumeration_query(msg, out) - else: - self._answer_ptr_query(msg, out, question) - continue - - self._answer_non_ptr_query(msg, out, question) - - return unicast_out, multicast_out + ucast_source = port != _MDNS_PORT + query_res = _QueryResponse(msg, ucast_source) + + for question in msg.questions: + all_answers = self._answer_any_question(msg, question) + if ucast_source: + query_res.add_ucast_question_response(*all_answers) + # We always multicast as well even if its a unicast + # source as long as we haven't done it recently (75% of ttl) + query_res.add_mcast_question_response(*all_answers) + + return query_res.outgoing_unicast(), query_res.outgoing_multicast() class RecordManager: From 0100c08c5a3fb90d0795cf57f0bd3e11c7a94a0b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 13:10:20 -1000 Subject: [PATCH 289/608] Fix queries for AAAA records (#616) --- tests/test_handlers.py | 24 ++++++++++++++++++++++++ zeroconf/_handlers.py | 13 ++++++++----- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e8efefeb..e1eec79b 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -256,6 +256,30 @@ def test_ptr_optimization(): zc.close() +def test_aaaa_query(): + """Test that queries for AAAA records work.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownservice._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.register_service(info) + + _clear_cache(zc) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + _, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT) + assert multicast_out.answers[0][0].address == ipv6_address + # unregister + zc.unregister_service(info) + zc.close() + + def test_unicast_response(): """Ensure we send a unicast response when the source port is not the MDNS port.""" # instantiate a zeroconf instance diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 4884977c..48d6529a 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -28,6 +28,7 @@ from ._logger import log from ._services import RecordUpdateListener from ._services.registry import ServiceRegistry +from ._utils.net import IPVersion from ._utils.time import current_time_millis from .const import ( _CLASS_IN, @@ -37,12 +38,14 @@ _MDNS_PORT, _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_A, + _TYPE_AAAA, _TYPE_ANY, _TYPE_PTR, _TYPE_SRV, _TYPE_TXT, ) +_TYPE_TO_IP_VERSION = {_TYPE_A: IPVersion.V4Only, _TYPE_AAAA: IPVersion.V6Only, _TYPE_ANY: IPVersion.All} if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 @@ -146,10 +149,10 @@ def _add_pointer_answers( additionals.add(service.dns_text()) additionals.update(service.dns_addresses()) - def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord]) -> None: - """Answer address question.""" + def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord], type_: int) -> None: + """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): - for dns_address in service.dns_addresses(): + for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): if not dns_address.suppressed_by(msg): answers.add(dns_address) @@ -163,8 +166,8 @@ def _answer_question( if type_ == _TYPE_PTR: self._add_pointer_answers(question.name, msg, answers, additionals) - if type_ in (_TYPE_A, _TYPE_ANY): - self._add_address_answers(question.name, msg, answers) + if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): + self._add_address_answers(question.name, msg, answers, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore From 427b7285269984cbb6f28c87a8bf8f864a5e15d7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 13:17:42 -1000 Subject: [PATCH 290/608] Suppress additionals when they are already in the answers section (#617) --- tests/test_handlers.py | 4 ++-- zeroconf/_handlers.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e1eec79b..5eef3a8d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -103,7 +103,7 @@ def _process_outgoing_packet(out): # The additonals should all be suppresed since they are all in the answers section # - assert nbr_answers == 4 and nbr_additionals == 3 and nbr_authorities == 0 + assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -136,7 +136,7 @@ def _process_outgoing_packet(out): _process_outgoing_packet( zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] ) - assert nbr_answers == 4 and nbr_additionals == 3 and nbr_authorities == 0 + assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 48d6529a..824e9ce5 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -105,6 +105,10 @@ def _construct_outgoing_from_record_set( """Add answers and additionals to a DNSOutgoing.""" if not rrset[RecordSetKeys.Answers] and not rrset[RecordSetKeys.Additionals]: return None + + # Suppress any additionals that are already in answers + rrset[RecordSetKeys.Additionals] -= rrset[RecordSetKeys.Answers] + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=multicast, id_=self._msg.id) for answer in rrset[RecordSetKeys.Answers]: out.add_answer_at_time(answer, 0) From b6365aa1f889a3045aa185f67354de622bd7ebd3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 13:30:44 -1000 Subject: [PATCH 291/608] Ensure matching PTR queries are returned with the ANY query (#618) Fixes #464 --- tests/test_handlers.py | 25 +++++++++++++++++++++++++ zeroconf/_handlers.py | 2 +- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 5eef3a8d..71a28259 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -256,6 +256,31 @@ def test_ptr_optimization(): zc.close() +def test_any_query_for_ptr(): + """Test that queries for ANY will return PTR records.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knownservice._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) + zc.register_service(info) + + _clear_cache(zc) + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + _, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT) + assert multicast_out.answers[0][0].name == type_ + assert multicast_out.answers[0][0].alias == registration_name + # unregister + zc.unregister_service(info) + zc.close() + + def test_aaaa_query(): """Test that queries for AAAA records work.""" zc = Zeroconf(interfaces=['127.0.0.1']) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 824e9ce5..590a6c96 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -167,7 +167,7 @@ def _answer_question( additionals: Set[DNSRecord] = set() type_ = question.type - if type_ == _TYPE_PTR: + if type_ in (_TYPE_PTR, _TYPE_ANY): self._add_pointer_answers(question.name, msg, answers, additionals) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): From 0e644ad650627024c7a3f926a86f7d9ecc66e591 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 14:15:24 -1000 Subject: [PATCH 292/608] Protect the network against excessive packet flooding (#619) --- tests/test_handlers.py | 23 ++++++++++++++++++----- zeroconf/_core.py | 2 +- zeroconf/_handlers.py | 24 +++++++++++++++++++++--- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 71a28259..e0eeb353 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -15,7 +15,7 @@ from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const -from . import _clear_cache +from . import _clear_cache, _inject_response log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -227,7 +227,19 @@ def test_ptr_optimization(): # register zc.register_service(info) - # query + # Verify we won't respond for 1s with the same multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is None + + # Clear the cache to allow responding again + _clear_cache(zc) + + # Verify we will now respond query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( @@ -320,6 +332,7 @@ def test_unicast_response(): ) # register zc.register_service(info) + _clear_cache(zc) # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) @@ -329,8 +342,8 @@ def test_unicast_response(): assert out.id == query.id has_srv = has_txt = has_a = False nbr_additionals = 0 - nbr_answers = len(multicast_out.answers) - nbr_authorities = len(multicast_out.authorities) + nbr_answers = len(out.answers) + nbr_authorities = len(out.authorities) for answer in out.additionals: nbr_additionals += 1 if answer.type == const._TYPE_SRV: @@ -360,7 +373,7 @@ def test_known_answer_supression(): zc.register_service(info) now = current_time_millis() - + _clear_cache(zc) # Test PTR supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 37985cf3..1cc6bdce 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -264,8 +264,8 @@ def __init__( self._notify_listeners: List[NotifyListener] = [] self.browsers: Dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() - self.query_handler = QueryHandler(self.registry) self.cache = DNSCache() + self.query_handler = QueryHandler(self.registry, self.cache) self.record_manager = RecordManager(self) self.condition = threading.Condition() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 590a6c96..25ea8aa5 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -24,6 +24,7 @@ import itertools from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union +from ._cache import DNSCache from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord from ._logger import log from ._services import RecordUpdateListener @@ -65,12 +66,13 @@ class RecordSetKeys(enum.Enum): class _QueryResponse: """A pair for unicast and multicast DNSOutgoing responses.""" - def __init__(self, msg: DNSIncoming, ucast_source: bool) -> None: + def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> None: """Build a query response.""" self._msg = msg self._ucast_source = ucast_source self._is_probe = msg.num_authorities > 0 self._now = current_time_millis() + self._cache = cache self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} @@ -97,6 +99,9 @@ def outgoing_unicast(self) -> Optional[DNSOutgoing]: def outgoing_multicast(self) -> Optional[DNSOutgoing]: """Build the outgoing multicast response.""" + if not self._is_probe: + self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Answers]) + self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Additionals]) return self._construct_outgoing_from_record_set(self._mcast, True) def _construct_outgoing_from_record_set( @@ -116,13 +121,26 @@ def _construct_outgoing_from_record_set( out.add_additional_answer(additional) return out + def _suppress_mcasts_from_last_second(self, records: Set[DNSRecord]) -> None: + """Remove any records that were already sent in the last second.""" + records -= set(record for record in records if self._has_mcast_record_in_last_second(record)) + + def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: + """Remove answers that were just broadcast + Protect the network against excessive packet flooding + https://datatracker.ietf.org/doc/html/rfc6762#section-14 + """ + maybe_entry = self._cache.get(record) + return bool(maybe_entry and self._now - maybe_entry.created < 1000) + class QueryHandler: """Query the ServiceRegistry.""" - def __init__(self, registry: ServiceRegistry) -> None: + def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: """Init the query handler.""" self.registry = registry + self.cache = cache def _answer_service_type_enumeration_query( self, @@ -204,7 +222,7 @@ def response( # pylint: disable=unused-argument ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: """Deal with incoming query packets. Provides a response if possible.""" ucast_source = port != _MDNS_PORT - query_res = _QueryResponse(msg, ucast_source) + query_res = _QueryResponse(self.cache, msg, ucast_source) for question in msg.questions: all_answers = self._answer_any_question(msg, question) From 1f36754f3964738e496a1da9c24380e204aaff01 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 15:00:23 -1000 Subject: [PATCH 293/608] Add is_recent property to DNSRecord (#620) - RFC 6762 defines recent as not multicast within one quarter of its TTL https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 --- tests/test_dns.py | 23 +++++++++++++++++++++++ zeroconf/_dns.py | 7 +++++++ zeroconf/const.py | 1 + 3 files changed, 31 insertions(+) diff --git a/tests/test_dns.py b/tests/test_dns.py index 3f46171c..18bffcce 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -137,6 +137,29 @@ def test_dns_outgoing_repr(self): dns_outgoing = r.DNSOutgoing(const._FLAGS_QR_QUERY) repr(dns_outgoing) + def test_dns_record_is_expired(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + now = current_time_millis() + assert record.is_expired(now) is False + assert record.is_expired(now + (8 / 2 * 1000)) is False + assert record.is_expired(now + (8 * 1000)) is True + + def test_dns_record_is_stale(self): + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + now = current_time_millis() + assert record.is_stale(now) is False + assert record.is_stale(now + (8 / 4.1 * 1000)) is False + assert record.is_stale(now + (8 / 2 * 1000)) is True + assert record.is_stale(now + (8 * 1000)) is True + + def test_dns_record_is_recent(self): + now = current_time_millis() + record = r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, 8) + assert record.is_recent(now + (8 / 4.1 * 1000)) is True + assert record.is_recent(now + (8 / 3 * 1000)) is False + assert record.is_recent(now + (8 / 2 * 1000)) is False + assert record.is_recent(now + (8 * 1000)) is False + class PacketGeneration(unittest.TestCase): def test_parse_own_packet_simple(self): diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 8eb62593..dcb8c9a3 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -42,6 +42,7 @@ _FLAGS_TC, _MAX_MSG_ABSOLUTE, _MAX_MSG_TYPICAL, + _RECENT_TIME_PERCENT, _TYPES, _TYPE_A, _TYPE_AAAA, @@ -147,6 +148,7 @@ def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) - self.created = current_time_millis() self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) + self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" @@ -183,6 +185,10 @@ def is_stale(self, now: float) -> bool: """Returns true if this record is at least half way expired.""" return self._stale_time <= now + def is_recent(self, now: float) -> bool: + """Returns true if the record more than one quarter of its TTL remaining.""" + return self._recent_time > now + def reset_ttl(self, other: 'DNSRecord') -> None: """Sets this record's TTL and created time to that of another record.""" @@ -190,6 +196,7 @@ def reset_ttl(self, other: 'DNSRecord') -> None: self.ttl = other.ttl self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) + self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use """Abstract method""" diff --git a/zeroconf/const.py b/zeroconf/const.py index 365fee09..3ec12427 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -130,6 +130,7 @@ _EXPIRE_FULL_TIME_PERCENT = 100 _EXPIRE_STALE_TIME_PERCENT = 50 _EXPIRE_REFRESH_TIME_PERCENT = 75 +_RECENT_TIME_PERCENT = 25 _LOCAL_TRAILER = '.local.' _TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' From 9a32db8582588e4bf812fd5670a7e61c50631a2e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 15:08:48 -1000 Subject: [PATCH 294/608] Add support for handling QU questions (#621) - Implements RFC 6762 sec 5.4: Questions Requesting Unicast Responses https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 --- tests/test_handlers.py | 101 +++++++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 47 +++++++++++++++++-- 2 files changed, 143 insertions(+), 5 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e0eeb353..71f6aff2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -360,6 +360,107 @@ def test_unicast_response(): zc.close() +def test_qu_response(): + """Handle multicast incoming with the QU bit set.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + other_type_ = "_notthesame._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name, other_type_) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + other_type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-other.local.", + addresses=[socket.inet_aton("10.0.4.2")], + ) + # register + zc.register_service(info) + + def _validate_complete_response(query, out): + assert out.id == query.id + has_srv = has_txt = has_a = False + nbr_additionals = 0 + nbr_answers = len(out.answers) + nbr_authorities = len(out.authorities) + for answer in out.additionals: + nbr_additionals += 1 + if answer.type == const._TYPE_SRV: + has_srv = True + elif answer.type == const._TYPE_TXT: + has_txt = True + elif answer.type == const._TYPE_A: + has_a = True + assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert has_srv and has_txt and has_a + + # With QU should respond to only unicast when the answer has been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out is None + _validate_complete_response(query, unicast_out) + + _clear_cache(zc) + # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + _validate_complete_response(query, multicast_out) + + # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + query.add_authorative_answer(info2.dns_pointer()) + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + ) + _validate_complete_response(query, unicast_out) + _validate_complete_response(query, multicast_out) + + _inject_response(zc, r.DNSIncoming(multicast_out.packets()[0])) + # With the cache repopulated; should respond to only unicast when the answer has been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + unicast_out, multicast_out = zc.query_handler.response( + r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out is None + _validate_complete_response(query, unicast_out) + # unregister + zc.unregister_service(info) + zc.close() + + def test_known_answer_supression(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_knownservice._tcp.local." diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 25ea8aa5..15a853b2 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -76,6 +76,25 @@ def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> Non self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} + def add_qu_question_response( + self, + answers: Set[DNSRecord], + additionals: Set[DNSRecord], + ) -> None: + """Generate a response to a multicast QU query.""" + self._add_qu_question_response_to_target(answers, RecordSetKeys.Answers) + self._add_qu_question_response_to_target(additionals, RecordSetKeys.Additionals) + + def _add_qu_question_response_to_target(self, target: Set[DNSRecord], answer_type: RecordSetKeys) -> None: + """Add part of the QU response.""" + for record in target: + if self._is_probe: + self._ucast[answer_type].add(record) + if not self._has_mcast_within_one_quarter_ttl(record): + self._mcast[answer_type].add(record) + elif not self._is_probe: + self._ucast[answer_type].add(record) + def add_ucast_question_response(self, answers: Set[DNSRecord], additionals: Set[DNSRecord]) -> None: """Generate a response to a unicast query.""" self._ucast[RecordSetKeys.Answers].update(answers) @@ -119,8 +138,23 @@ def _construct_outgoing_from_record_set( out.add_answer_at_time(answer, 0) for additional in rrset[RecordSetKeys.Additionals]: out.add_additional_answer(additional) + return out + def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: + """Check to see if a record has been mcasted recently. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + When receiving a question with the unicast-response bit set, a + responder SHOULD usually respond with a unicast packet directed back + to the querier. However, if the responder has not multicast that + record recently (within one quarter of its TTL), then the responder + SHOULD instead multicast the response so as to keep all the peer + caches up to date + """ + maybe_entry = self._cache.get(record) + return bool(maybe_entry and maybe_entry.is_recent(self._now)) + def _suppress_mcasts_from_last_second(self, records: Set[DNSRecord]) -> None: """Remove any records that were already sent in the last second.""" records -= set(record for record in records if self._has_mcast_record_in_last_second(record)) @@ -226,11 +260,14 @@ def response( # pylint: disable=unused-argument for question in msg.questions: all_answers = self._answer_any_question(msg, question) - if ucast_source: - query_res.add_ucast_question_response(*all_answers) - # We always multicast as well even if its a unicast - # source as long as we haven't done it recently (75% of ttl) - query_res.add_mcast_question_response(*all_answers) + if not ucast_source and question.unicast: + query_res.add_qu_question_response(*all_answers) + else: + if ucast_source: + query_res.add_ucast_question_response(*all_answers) + # We always multicast as well even if its a unicast + # source as long as we haven't done it recently (75% of ttl) + query_res.add_mcast_question_response(*all_answers) return query_res.outgoing_unicast(), query_res.outgoing_multicast() From 8f00cfca0e67dde6afda399da6984ed7d8f929df Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 15:47:25 -1000 Subject: [PATCH 295/608] Replace select loop with asyncio loop (#504) --- tests/__init__.py | 8 +- tests/test_asyncio.py | 34 ------ tests/test_core.py | 3 +- tests/test_init.py | 30 +---- zeroconf/_core.py | 260 +++++++++++++++++++++++------------------- zeroconf/asyncio.py | 32 ------ 6 files changed, 154 insertions(+), 213 deletions(-) delete mode 100644 tests/test_asyncio.py delete mode 100644 zeroconf/asyncio.py diff --git a/tests/__init__.py b/tests/__init__.py index 420541d7..3439a044 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -20,6 +20,7 @@ USA """ +import asyncio import socket from functools import lru_cache @@ -32,7 +33,12 @@ def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: """Inject a DNSIncoming response.""" - zc.handle_response(msg) + assert zc.loop is not None + + async def _wait_for_response(): + zc.handle_response(msg) + + asyncio.run_coroutine_threadsafe(_wait_for_response(), zc.loop).result() @lru_cache(maxsize=None) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py deleted file mode 100644 index bf4d887e..00000000 --- a/tests/test_asyncio.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - - -"""Unit tests for asyncio.py.""" - -import pytest -import threading - -from zeroconf.asyncio import AsyncZeroconf - - -@pytest.fixture(autouse=True) -def verify_threads_ended(): - """Verify that the threads are not running after the test.""" - threads_before = frozenset(threading.enumerate()) - yield - threads_after = frozenset(threading.enumerate()) - non_executor_threads = frozenset( - [ - thread - for thread in threads_after - if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name - ] - ) - threads = non_executor_threads - threads_before - assert not threads - - -@pytest.mark.asyncio -async def test_async_basic_usage() -> None: - """Test we can create and close the instance.""" - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - await aiozc.async_close() diff --git a/tests/test_core.py b/tests/test_core.py index b8a5499b..906a9508 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -46,8 +46,7 @@ def test_reaper(self): zeroconf.cache.add(record_with_1s_ttl) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) time.sleep(1) - with zeroconf.engine.condition: - zeroconf.engine._notify() + zeroconf.notify_all() time.sleep(0.1) entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) zeroconf.close() diff --git a/tests/test_init.py b/tests/test_init.py index 4710a994..6ccb9cff 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -148,15 +148,11 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.send(out) assert mocked_log_warn.call_count == call_counts[0] - # force a receive of a packet - packet = out.packets()[0] - s = zc._respond_sockets[0] - # mock the zeroconf logger and check for the correct logging backoff call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count # force receive on oversized packet - s.sendto(packet, 0, (const._MDNS_ADDR, const._MDNS_PORT)) - s.sendto(packet, 0, (const._MDNS_ADDR, const._MDNS_PORT)) + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) time.sleep(2.0) zeroconf.log.debug( 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts @@ -166,28 +162,6 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # close our zeroconf which will close the sockets zc.close() - # pop the big chunk off the end of the data and send on a closed socket - out.data.pop() - zc._GLOBAL_DONE = False - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # send on a closed socket (force a socket error) - zc.send(out) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_warn.call_count > call_counts[0] - assert mocked_log_debug.call_count > call_counts[0] - zc.send(out) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_debug.call_count > call_counts[0] + 2 - - mocked_log_warn.stop() - mocked_log_debug.stop() - def verify_name_change(self, zc, type_, name, number_hosts): desc = {'path': '/~paulsm/'} info_service = ServiceInfo( diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1cc6bdce..4668b51c 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -20,13 +20,14 @@ USA """ +import asyncio import errno +import itertools import platform -import select import socket import threading from types import TracebackType # noqa # used in type hints -from typing import Dict, List, Optional, Type, Union, cast +from typing import Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion @@ -41,6 +42,7 @@ instance_name_from_service_info, ) from ._services.registry import ServiceRegistry +from ._utils.aio import get_running_loop from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -77,83 +79,84 @@ def notify_all(self) -> None: raise NotImplementedError() -class Engine(threading.Thread): +class AsyncEngine: + """An engine wraps sockets in the event loop.""" - """An engine wraps read access to sockets, allowing objects that - need to receive data from sockets to be called back when the - sockets are ready. - - A reader needs a handle_read() method, which is called when the socket - it is interested in is ready for reading. - - Writers are not implemented here, because we only send short - packets. - """ - - def __init__(self, zc: 'Zeroconf') -> None: - threading.Thread.__init__(self) - self.daemon = True - self.zc = zc - self.readers = {} # type: Dict[socket.socket, Listener] - self.timeout = 5 - self.condition = threading.Condition() - self.socketpair = socket.socketpair() - self._last_cache_cleanup = 0.0 - self.name = "zeroconf-Engine-%s" % (getattr(self, 'native_id', self.ident),) - - def run(self) -> None: + def __init__( + self, + zeroconf: 'Zeroconf', + listen_socket: Optional[socket.socket], + respond_sockets: List[socket.socket], + ) -> None: + self.loop: Optional[asyncio.AbstractEventLoop] = None + self.zc = zeroconf + self.readers: List[asyncio.DatagramTransport] = [] + self.senders: List[asyncio.DatagramTransport] = [] + self._listen_socket = listen_socket + self._respond_sockets = respond_sockets + self._cache_cleanup_task: Optional[asyncio.Task] = None + self._running_event: Optional[asyncio.Event] = None + + def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + self.loop = loop + self._running_event = asyncio.Event() + self.loop.create_task(self._async_setup(loop_thread_ready)) + + async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: + """Set up the instance.""" + assert self.loop is not None + await self._async_create_endpoints() + self._cache_cleanup_task = self.loop.create_task(self._async_cache_cleanup()) + assert self._running_event is not None + self._running_event.set() + if loop_thread_ready: + loop_thread_ready.set() + + async def async_wait_for_start(self) -> None: + """Wait for start up.""" + assert self._running_event is not None + await self._running_event.wait() + + async def _async_create_endpoints(self) -> None: + """Create endpoints to send and receive.""" + assert self.loop is not None + loop = self.loop + reader_sockets = [] + sender_sockets = [] + if self._listen_socket: + reader_sockets.append(self._listen_socket) + for s in self._respond_sockets: + if s not in reader_sockets: + reader_sockets.append(s) + sender_sockets.append(s) + + for s in reader_sockets: + transport, _ = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) + self.readers.append(cast(asyncio.DatagramTransport, transport)) + if s in sender_sockets: + self.senders.append(cast(asyncio.DatagramTransport, transport)) + + async def _async_cache_cleanup(self) -> None: + """Periodic cache cleanup.""" while not self.zc.done: - try: - rr, _wr, _er = select.select([*self.readers.keys(), self.socketpair[0]], [], [], self.timeout) - - if self.zc.done: - return - - for socket_ in rr: - reader = self.readers.get(socket_) - if reader: - reader.handle_read(socket_) - - if self.socketpair[0] in rr: - # Clear the socket's buffer - self.socketpair[0].recv(128) - - except (select.error, socket.error) as e: - # If the socket was closed by another thread, during - # shutdown, ignore it and exit - if e.args[0] not in (errno.EBADF, errno.ENOTCONN) or not self.zc.done: - raise - now = current_time_millis() - if now - self._last_cache_cleanup >= _CACHE_CLEANUP_INTERVAL: - self._last_cache_cleanup = now - self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) - self.zc.record_manager.updates_complete() - - self.socketpair[0].close() - self.socketpair[1].close() - - def _notify(self) -> None: - self.condition.notify() - try: - self.socketpair[1].send(b'x') - except socket.error: - # The socketpair may already be closed during shutdown, ignore it - if not self.zc.done: - raise - - def add_reader(self, reader: 'Listener', socket_: socket.socket) -> None: - with self.condition: - self.readers[socket_] = reader - self._notify() + self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) + self.zc.record_manager.updates_complete() + await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) - def del_reader(self, socket_: socket.socket) -> None: - with self.condition: - del self.readers[socket_] - self._notify() + def close(self) -> None: + """Close the engine.""" + if self._cache_cleanup_task: + self._cache_cleanup_task.cancel() + self._cache_cleanup_task = None + for transport in itertools.chain(self.senders, self.readers): + transport.close() + for s in self._respond_sockets: + s.close() -class Listener(QuietLogger): +class AsyncListener(asyncio.Protocol, QuietLogger): """A Listener is used by this module to listen on the multicast group to which DNS messages are sent, allowing the implementation @@ -165,12 +168,18 @@ class Listener(QuietLogger): def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.data = None # type: Optional[bytes] + self.transport: Optional[asyncio.DatagramTransport] = None + super().__init__() - def handle_read(self, socket_: socket.socket) -> None: - try: - data, (addr, port, *_v6) = socket_.recvfrom(_MAX_MSG_ABSOLUTE) - except Exception: # pylint: disable=broad-except - self.log_exception_warning('Error reading from socket %d', socket_.fileno()) + def datagram_received( + self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] + ) -> None: + assert self.transport is not None + if len(addrs) == 2: + addr, port = addrs # type: ignore + elif len(addrs) == 4: + addr, port, _flow, _scope = addrs # type: ignore + else: return if self.data == data: @@ -178,7 +187,7 @@ def handle_read(self, socket_: socket.socket) -> None: 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', addr, port, - socket_.fileno(), + self.transport.get_extra_info('socket').fileno(), len(data), data, ) @@ -191,7 +200,7 @@ def handle_read(self, socket_: socket.socket) -> None: 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', addr, port, - socket_.fileno(), + self.transport.get_extra_info('socket').fileno(), msg, len(data), data, @@ -201,7 +210,7 @@ def handle_read(self, socket_: socket.socket) -> None: 'Received from %r:%r (socket %d): (%d bytes) [%r]', addr, port, - socket_.fileno(), + self.transport.get_extra_info('socket').fileno(), len(data), data, ) @@ -215,6 +224,12 @@ def handle_read(self, socket_: socket.socket) -> None: self.zc.handle_query(msg, addr, port) + def error_received(self, exc: Exception) -> None: + """Likely socket closed or IPv6.""" + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + self.transport = cast(asyncio.DatagramTransport, transport) + class Zeroconf(QuietLogger): @@ -250,16 +265,14 @@ def __init__( # hook for threads self._GLOBAL_DONE = False - self.unicast = unicast if apple_p2p and not platform.system() == 'Darwin': raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') - self._listen_socket, self._respond_sockets = create_sockets( - interfaces, unicast, ip_version, apple_p2p=apple_p2p - ) - log.debug('Listen socket %s, respond sockets %s', self._listen_socket, self._respond_sockets) - self.multi_socket = unicast or interfaces is not InterfaceChoice.Default + listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) + log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets) + + self.engine = AsyncEngine(self, listen_socket, respond_sockets) self._notify_listeners: List[NotifyListener] = [] self.browsers: Dict[ServiceListener, ServiceBrowser] = {} @@ -269,18 +282,36 @@ def __init__( self.record_manager = RecordManager(self) self.condition = threading.Condition() + self.loop: Optional[asyncio.AbstractEventLoop] = None + self._loop_thread: Optional[threading.Thread] = None - self.engine = Engine(self) - self.listener = Listener(self) - if not unicast: - self.engine.add_reader(self.listener, cast(socket.socket, self._listen_socket)) - if self.multi_socket: - for s in self._respond_sockets: - self.engine.add_reader(self.listener, s) - # Start the engine only after all - # the readers have been added to avoid - # missing any packets that are on the wire - self.engine.start() + self.start() + + def start(self) -> None: + """Start Zeroconf.""" + self.loop = get_running_loop() + if self.loop: + self.engine.setup(self.loop, None) + return + self._start_thread() + + def _start_thread(self) -> None: + """Start a thread with a running event loop.""" + loop_thread_ready = threading.Event() + + def _run_loop() -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.engine.setup(self.loop, loop_thread_ready) + self.loop.run_forever() + + self._loop_thread = threading.Thread(target=_run_loop, daemon=True) + self._loop_thread.start() + loop_thread_ready.wait() + + async def async_wait_for_start(self) -> None: + """Wait for start up.""" + await self.engine.async_wait_for_start() @property def done(self) -> bool: @@ -504,11 +535,16 @@ def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None possible.""" unicast_out, multicast_out = self.query_handler.response(msg, addr, port) if unicast_out and unicast_out.answers: - self.send(unicast_out, addr, port) + self.async_send(unicast_out, addr, port) if multicast_out and multicast_out.answers: - self.send(multicast_out, None, _MDNS_PORT) + self.async_send(multicast_out, None, _MDNS_PORT) def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + """Sends an outgoing packet threadsafe.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_send, out, addr, port) + + def async_send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: """Sends an outgoing packet.""" for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: @@ -523,9 +559,10 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P out, packet, ) - for s in self._respond_sockets: + for transport in self.engine.senders: if self._GLOBAL_DONE: return + s = transport.get_extra_info('socket') try: if addr is None: real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR @@ -533,7 +570,7 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P continue else: real_addr = addr - bytes_sent = s.sendto(packet, 0, (real_addr, port)) + transport.sendto(packet, (real_addr, port or _MDNS_PORT)) except OSError as exc: if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: # with IPv6 we don't have a reliable way to determine if an interface actually has @@ -544,9 +581,6 @@ def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_P except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions # on send errors, log the exception and keep going self.log_exception_warning('Error sending through socket %d', s.fileno()) - else: - if bytes_sent != len(packet): - self.log_warning_once('!!! sent %d of %d bytes to %r' % (bytes_sent, len(packet), s)) def close(self) -> None: """Ends the background threads, and prevent this instance from @@ -557,19 +591,13 @@ def close(self) -> None: self.remove_all_service_listeners() self.unregister_all_services() self._GLOBAL_DONE = True - - # shutdown recv socket and thread - if not self.unicast: - self.engine.del_reader(cast(socket.socket, self._listen_socket)) - cast(socket.socket, self._listen_socket).close() - if self.multi_socket: - for s in self._respond_sockets: - self.engine.del_reader(s) - self.engine.join() + self.engine.close() # shutdown the rest self.notify_all() - for s in self._respond_sockets: - s.close() + if self._loop_thread: + assert self.loop is not None + self.loop.call_soon_threadsafe(self.loop.stop) + self._loop_thread.join() def __enter__(self) -> 'Zeroconf': return self diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py deleted file mode 100644 index 3de171f7..00000000 --- a/zeroconf/asyncio.py +++ /dev/null @@ -1,32 +0,0 @@ -""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine - Copyright 2003 Paul Scott-Murphy, 2014 William McBrine - - This module provides a framework for the use of DNS Service Discovery - using IP multicast. - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 - USA -""" - -from ._logger import log -from .aio import AsyncZeroconf # pylint: disable=unused-import # noqa - -# The asyncio module would shadow system asyncio in some import cases -# to resolve this, the module has been renamed zeroconf.aio - -log.warning( - "zeroconf.asyncio namespace has changed to zeroconf.aio; " - "This compatibility module will be removed in the next version" -) From f15e84f3ee7a644792fe98edde84dd216b3497cb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 15:58:46 -1000 Subject: [PATCH 296/608] Eliminate aio sender thread (#622) --- zeroconf/aio.py | 52 +++++++------------------------------------------ 1 file changed, 7 insertions(+), 45 deletions(-) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 3df58eae..6f445e72 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -21,13 +21,10 @@ """ import asyncio import contextlib -import queue -import threading from types import TracebackType # noqa # used in type hints from typing import Awaitable, Callable, Dict, List, Optional, Type, Union from ._core import NotifyListener, Zeroconf -from ._dns import DNSOutgoing from ._exceptions import NonUniqueNameException from ._services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info from ._utils.aio import wait_condition_or_timeout @@ -36,42 +33,6 @@ from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME -def _get_best_available_queue() -> queue.Queue: - """Create the best available queue type.""" - if hasattr(queue, "SimpleQueue"): - return queue.SimpleQueue() # type: ignore # pylint: disable=all - return queue.Queue() - - -class _AsyncSender(threading.Thread): - """A thread to handle sending DNSOutgoing for asyncio.""" - - def __init__(self, zc: 'Zeroconf'): - """Create the sender thread.""" - super().__init__() - self.zc = zc - self.queue = _get_best_available_queue() - self.start() - self.name = "AsyncZeroconfSender" - - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: - """Queue a send to be processed by the thread.""" - self.queue.put((out, addr, port)) - - def close(self) -> None: - """Close the instance.""" - self.queue.put(None) - self.join() - - def run(self) -> None: - """Runner that processes sends FIFO.""" - while True: - event = self.queue.get() - if event is None: - return - self.zc.send(*event) - - class AsyncNotifyListener(NotifyListener): """A NotifyListener that async code can use to wait for events.""" @@ -115,6 +76,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: delay = _LISTENER_TIME next_ = now last = now + timeout + await aiozc.zeroconf.async_wait_for_start() try: aiozc.zeroconf.add_listener(self, None) while not self._is_complete: @@ -124,7 +86,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: out = self.generate_request_query(aiozc.zeroconf, now) if not out.questions: return self.load_from_cache(aiozc.zeroconf) - aiozc.sender.send(out) + aiozc.zeroconf.async_send(out) next_ = now + delay delay *= 2 @@ -180,7 +142,7 @@ async def async_run(self) -> None: out = self.generate_ready_queries() if out: - self.aiozc.sender.send(out, addr=self.addr, port=self.port) + self.aiozc.zeroconf.async_send(out, addr=self.addr, port=self.port) if not self._handlers_to_call: continue @@ -236,7 +198,6 @@ def __init__( self.async_notify = AsyncNotifyListener(self) self.zeroconf.add_notify_listener(self.async_notify) self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} - self.sender = _AsyncSender(self.zeroconf) self.condition = asyncio.Condition() async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: @@ -244,7 +205,7 @@ async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: for i in range(3): if i != 0: await asyncio.sleep(millis_to_seconds(interval)) - self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl)) + self.zeroconf.async_send(self.zeroconf.generate_service_broadcast(info, ttl)) async def async_register_service( self, @@ -261,6 +222,7 @@ async def async_register_service( The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ + await self.zeroconf.async_wait_for_start() await self.async_check_service(info, cooperating_responders) self.zeroconf.registry.add(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) @@ -274,7 +236,7 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b for i in range(3): if i != 0: await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) - self.sender.send(self.zeroconf.generate_service_query(info)) + self.zeroconf.async_send(self.zeroconf.generate_service_query(info)) self._raise_on_name_conflict(info) def _raise_on_name_conflict(self, info: ServiceInfo) -> None: @@ -304,13 +266,13 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: def _close(self) -> None: """Shutdown zeroconf and the sender.""" - self.sender.close() self.zeroconf.remove_notify_listener(self.async_notify) self.zeroconf.close() async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" + await self.zeroconf.async_wait_for_start() await self.async_remove_all_service_listeners() await self.loop.run_in_executor(None, self._close) From 4d05961088efa8b503cad5658afade874eaeec76 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 16:00:02 -1000 Subject: [PATCH 297/608] Update changelog (#623) --- README.rst | 93 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 19 deletions(-) diff --git a/README.rst b/README.rst index dc944425..20d37bfa 100644 --- a/README.rst +++ b/README.rst @@ -134,35 +134,21 @@ See examples directory for more. Changelog ========= -0.33.0 (Unreleased) -=================== - -* Breaking change: zeroconf.asyncio has been removed in favor of zeroconf.aio - TBD - - The asyncio name could shadow system asyncio in some cases. If - zeroconf is in sys.path, this would result in loading zeroconf.asyncio - when system asyncio was intended. - 0.32.0 (Unreleased) =================== -* Breaking change: zeroconf.asyncio has been renamed zeroconf.aio (#503) @bdraco +* BREAKING CHANGE: zeroconf.asyncio has been renamed zeroconf.aio (#503) @bdraco The asyncio name could shadow system asyncio in some cases. If zeroconf is in sys.path, this would result in loading zeroconf.asyncio when system asyncio was intended. - An `zeroconf.asyncio` shim module has been added that imports `zeroconf.aio` - that was available in 0.31 to provide backwards compatibility in 0.32.0 - This module will be removed in 0.33.0 to fix the underlying problem - detailed in #502 - -* Breaking change: Update internal version check to match docs (3.6+) (#491) @bdraco +* BREAKING CHANGE: Update internal version check to match docs (3.6+) (#491) @bdraco Python version eariler then 3.6 were likely broken with zeroconf already, however the version is now explictly checked. -* Breaking change: RecordUpdateListener now uses update_records instead of update_record (#419) @bdraco +* BREAKING CHANGE: RecordUpdateListener now uses update_records instead of update_record (#419) @bdraco This allows the listener to receive all the records that have been updated in a single transaction such as a packet or @@ -181,7 +167,7 @@ Changelog has been updated as its a common pattern to call for ServiceInfo when a ServiceBrowser handler fires. -* Breaking change: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco +* BREAKING CHANGE: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco When manually creating a zeroconf.Engine object, it is no longer started automatically. It must manually be started by calling .start() on the created object. @@ -189,7 +175,7 @@ Changelog The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. -* Breaking change: Remove DNSOutgoing.packet backwards compatibility (#569) @bdraco +* BREAKING CHANGE: Remove DNSOutgoing.packet backwards compatibility (#569) @bdraco DNSOutgoing.packet only returned a partial message when the DNSOutgoing contents exceeded _MAX_MSG_ABSOLUTE or _MAX_MSG_TYPICAL @@ -198,6 +184,75 @@ Changelog should not be used since it will end up missing data, it has been removed +* TRAFFIC REDUCTION: Add support for handling QU questions (#621) @bdraco + + Implements RFC 6762 sec 5.4: + Questions Requesting Unicast Responses + datatracker.ietf.org/doc/html/rfc6762#section-5.4 + +* TRAFFIC REDUCTION: Protect the network against excessive packet flooding (#619) @bdraco + +* TRAFFIC REDUCTION: Suppress additionals when they are already in the answers section (#617) @bdraco + +* TRAFFIC REDUCTION: Avoid including additionals when the answer is suppressed by known-answer supression (#614) @bdraco + +* MAJOR BUG: Ensure matching PTR queries are returned with the ANY query (#618) @bdraco + +* MAJOR BUG: Fix lookup of uppercase names in registry (#597) @bdraco + + If the ServiceInfo was registered with an uppercase name and the query was + for a lowercase name, it would not be found and vice-versa. + +* MAJOR BUG: Ensure unicast responses can be sent to any source port (#598) @bdraco + + Unicast responses were only being sent if the source port + was 53, this prevented responses when testing with dig: + + dig -p 5353 @224.0.0.251 media-12.local + + The above query will now see a response + +* MAJOR BUG: Fix queries for AAAA records (#616) @bdraco + +* Eliminate aio sender thread (#622) @bdraco + +* Replace select loop with asyncio loop (#504) @bdraco + +* Add is_recent property to DNSRecord (#620) @bdraco + + RFC 6762 defines recent as not multicast within one quarter of its TTL + datatracker.ietf.org/doc/html/rfc6762#section-5.4 + +* Breakout the query response handler into its own class (#615) @bdraco + +* Add the ability for ServiceInfo.dns_addresses to filter by address type (#612) @bdraco + +* Make DNSRecords hashable (#611) @bdraco + + Allows storing them in a set for de-duplication + + Needed to be able to check for duplicates to solve #604 + +* Ensure the QU bit is set for probe queries (#609) @bdraco + + The bit should be set per + datatracker.ietf.org/doc/html/rfc6762#section-8.1 + +* Log destination when sending packets (#606) @bdraco + +* Fix docs version to match readme (cpython 3.6+) (#602) @bdraco + +* Add ZeroconfServiceTypes to zeroconf.__all__ (#601) @bdraco + + This class is in the readme, but is not exported by + default + +* Add id_ param to allow setting the id in the DNSOutgoing constructor (#599) @bdraco + +* Add unicast property to DNSQuestion to determine if the QU bit is set (#593) @bdraco + +* Reduce branching in DNSOutgoing.add_answer_at_time (#592) @bdraco + * Breakout DNSCache into zeroconf.cache (#568) @bdraco * Removed protected imports from zeroconf namespace (#567) @bdraco From 42d53c7c04a7bbf4e60e691e2e58fe7acfec8ad9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 16:27:30 -1000 Subject: [PATCH 298/608] Ensure zeroconf can be loaded when the system disables IPv6 (#624) --- zeroconf/const.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/zeroconf/const.py b/zeroconf/const.py index 3ec12427..96f536df 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -38,7 +38,10 @@ _MDNS_ADDR = '224.0.0.251' _MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) _MDNS_ADDR6 = 'ff02::fb' -_MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) +try: + _MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) +except OSError: # can't use AF_INET6, IPv6 is disabled + pass _MDNS_PORT = 5353 _DNS_PORT = 53 _DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 From 5750f7ceef0441fe1cedc0d96e7ef5ccc232d875 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 17:52:06 -1000 Subject: [PATCH 299/608] Fix random test failures due to monkey patching not being undone between tests (#626) - Switch patching to use unitest.mock.patch to ensure the patch is reverted when the test is completed Fixes #505 --- tests/test_init.py | 135 +++++------ tests/test_services.py | 526 ++++++++++++++++++++--------------------- 2 files changed, 329 insertions(+), 332 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index 6ccb9cff..6e5457ff 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -81,7 +81,7 @@ def test_lots_of_names(self): # verify that name changing works self.verify_name_change(zc, type_, name, server_count) - # we are going to monkey patch the zeroconf send to check packet sizes + # we are going to patch the zeroconf send to check packet sizes old_send = zc.send longest_packet_len = 0 @@ -96,71 +96,74 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): longest_packet = out old_send(out, addr=addr, port=port) - # monkey patch the zeroconf send - setattr(zc, "send", send) - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - # start a browser - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - # wait until the browse request packet has maxed out in size - sleep_count = 0 - # we will never get to this large of a packet given the application-layer - # splitting of packets, but we still want to track the longest_packet_len - # for the debug message below - while sleep_count < 100 and longest_packet_len < const._MAX_MSG_ABSOLUTE - 100: - sleep_count += 1 - time.sleep(0.1) - - browser.cancel() - time.sleep(0.5) - - import zeroconf - - zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) - - # now the browser has sent at least one request, verify the size - assert longest_packet_len <= const._MAX_MSG_TYPICAL - assert longest_packet_len >= const._MAX_MSG_TYPICAL - 100 - - # mock zeroconf's logger warning() and debug() - from unittest.mock import patch - - patch_warn = patch('zeroconf._logger.log.warning') - patch_debug = patch('zeroconf._logger.log.debug') - mocked_log_warn = patch_warn.start() - mocked_log_debug = patch_debug.start() - - # now that we have a long packet in our possession, let's verify the - # exception handling. - out = longest_packet - assert out is not None - out.data.append(b'\0' * 1000) - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # try to send an oversized packet - zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] - zc.send(out) - assert mocked_log_warn.call_count == call_counts[0] - - # mock the zeroconf logger and check for the correct logging backoff - call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count - # force receive on oversized packet - zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) - zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) - time.sleep(2.0) - zeroconf.log.debug( - 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, call_counts - ) - assert mocked_log_debug.call_count > call_counts[0] - - # close our zeroconf which will close the sockets - zc.close() + # patch the zeroconf send + with unittest.mock.patch.object(zc, "send", send): + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + # wait until the browse request packet has maxed out in size + sleep_count = 0 + # we will never get to this large of a packet given the application-layer + # splitting of packets, but we still want to track the longest_packet_len + # for the debug message below + while sleep_count < 100 and longest_packet_len < const._MAX_MSG_ABSOLUTE - 100: + sleep_count += 1 + time.sleep(0.1) + + browser.cancel() + time.sleep(0.5) + + import zeroconf + + zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) + + # now the browser has sent at least one request, verify the size + assert longest_packet_len <= const._MAX_MSG_TYPICAL + assert longest_packet_len >= const._MAX_MSG_TYPICAL - 100 + + # mock zeroconf's logger warning() and debug() + from unittest.mock import patch + + patch_warn = patch('zeroconf._logger.log.warning') + patch_debug = patch('zeroconf._logger.log.debug') + mocked_log_warn = patch_warn.start() + mocked_log_debug = patch_debug.start() + + # now that we have a long packet in our possession, let's verify the + # exception handling. + out = longest_packet + assert out is not None + out.data.append(b'\0' * 1000) + + # mock the zeroconf logger and check for the correct logging backoff + call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count + # try to send an oversized packet + zc.send(out) + assert mocked_log_warn.call_count == call_counts[0] + zc.send(out) + assert mocked_log_warn.call_count == call_counts[0] + + # mock the zeroconf logger and check for the correct logging backoff + call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count + # force receive on oversized packet + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) + zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) + time.sleep(2.0) + zeroconf.log.debug( + 'warn %d debug %d was %s', + mocked_log_warn.call_count, + mocked_log_debug.call_count, + call_counts, + ) + assert mocked_log_debug.call_count > call_counts[0] + + # close our zeroconf which will close the sockets + zc.close() def verify_name_change(self, zc, type_, name, number_hosts): desc = {'path': '/~paulsm/'} diff --git a/tests/test_services.py b/tests/test_services.py index 6677bfcb..04300602 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -221,119 +221,119 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): last_sent = out send_event.set() - # monkey patch the zeroconf send - setattr(zc, "send", send) + # patch the zeroconf send + with unittest.mock.patch.object(zc, "send", send): - def mock_incoming_msg(records) -> r.DNSIncoming: + def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - for record in records: - generated.add_answer_at_time(record, 0) + for record in records: + generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packets()[0]) + return r.DNSIncoming(generated.packets()[0]) - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for SRV, A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - service_text, - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 3 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 2 - assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - last_sent = None - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for SRV, A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 3 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 2 + assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + last_sent = None + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() def test_get_info_single(self): @@ -358,83 +358,83 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): last_sent = out send_event.set() - # monkey patch the zeroconf send - setattr(zc, "send", send) + # patch the zeroconf send + with unittest.mock.patch.object(zc, "send", send): - def mock_incoming_msg(records) -> r.DNSIncoming: + def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - for record in records: - generated.add_answer_at_time(record, 0) + for record in records: + generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packets()[0]) + return r.DNSIncoming(generated.packets()[0]) - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - service_text, - ), - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ), - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ), - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ), + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() class TestServiceBrowserMultipleTypes(unittest.TestCase): @@ -953,7 +953,7 @@ def test_backoff(): type_ = "_http._tcp.local." zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - # we are going to monkey patch the zeroconf send to check query transmission + # we are going to patch the zeroconf send to check query transmission old_send = zeroconf_browser.send time_offset = 0.0 @@ -969,55 +969,52 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): got_query.set() old_send(out, addr=addr, port=port) - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) - - # monkey patch the zeroconf current_time_millis - s.current_time_millis = current_time_millis + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to prevent test running forever + with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + s, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(s, "_BROWSER_BACKOFF_LIMIT", 10): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass - # monkey patch the backoff limit to prevent test running forever - s._BROWSER_BACKOFF_LIMIT = 10 # seconds + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - try: - # Test that queries are sent at increasing intervals - sleep_count = 0 - next_query_interval = 0.0 - expected_query_time = 0.0 - while True: - sleep_count += 1 - for _ in range(2): - # If the browser thread is starting up - # its possible we notify before the initial sleep - # which means the test will fail so we need to d - # this twice to eliminate the race condition - zeroconf_browser.notify_all() - got_query.wait(0.05) - if time_offset == expected_query_time: - assert got_query.is_set() - got_query.clear() - if next_query_interval == s._BROWSER_BACKOFF_LIMIT: - # Only need to test up to the point where we've seen a query - # after the backoff limit has been hit - break - elif next_query_interval == 0: - next_query_interval = initial_query_interval - expected_query_time = initial_query_interval + try: + # Test that queries are sent at increasing intervals + sleep_count = 0 + next_query_interval = 0.0 + expected_query_time = 0.0 + while True: + sleep_count += 1 + for _ in range(2): + # If the browser thread is starting up + # its possible we notify before the initial sleep + # which means the test will fail so we need to d + # this twice to eliminate the race condition + zeroconf_browser.notify_all() + got_query.wait(0.05) + if time_offset == expected_query_time: + assert got_query.is_set() + got_query.clear() + if next_query_interval == s._BROWSER_BACKOFF_LIMIT: + # Only need to test up to the point where we've seen a query + # after the backoff limit has been hit + break + elif next_query_interval == 0: + next_query_interval = initial_query_interval + expected_query_time = initial_query_interval + else: + next_query_interval = min(2 * next_query_interval, s._BROWSER_BACKOFF_LIMIT) + expected_query_time += next_query_interval else: - next_query_interval = min(2 * next_query_interval, s._BROWSER_BACKOFF_LIMIT) - expected_query_time += next_query_interval - else: - assert not got_query.is_set() - time_offset += initial_query_interval + assert not got_query.is_set() + time_offset += initial_query_interval - finally: - browser.cancel() - zeroconf_browser.close() + finally: + browser.cancel() + zeroconf_browser.close() def test_integration(): @@ -1038,7 +1035,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - # we are going to monkey patch the zeroconf send to check packet sizes + # we are going to patch the zeroconf send to check packet sizes old_send = zeroconf_browser.send time_offset = 0.0 @@ -1063,54 +1060,51 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): got_query.set() old_send(out, addr=addr, port=port) - # monkey patch the zeroconf send - setattr(zeroconf_browser, "send", send) + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL + with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + s, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(s, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): + service_added = Event() + service_removed = Event() - # monkey patch the zeroconf current_time_millis - s.current_time_millis = current_time_millis + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - # monkey patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - s._BROWSER_BACKOFF_LIMIT = int(expected_ttl / 4) + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zeroconf_registrar.register_service(info) - service_added = Event() - service_removed = Event() + try: + service_added.wait(1) + assert service_added.is_set() - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + # Test that we receive queries containing answers only if the remaining TTL + # is greater than half the original TTL + sleep_count = 0 + test_iterations = 50 + while nbr_answers < test_iterations: + # Increase simulated time shift by 1/4 of the TTL in seconds + time_offset += expected_ttl / 4 + zeroconf_browser.notify_all() + sleep_count += 1 + got_query.wait(0.1) + got_query.clear() + # Prevent the test running indefinitely in an error condition + assert sleep_count < test_iterations * 4 + assert not unexpected_ttl.is_set() - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - zeroconf_registrar.register_service(info) - - try: - service_added.wait(1) - assert service_added.is_set() - - # Test that we receive queries containing answers only if the remaining TTL - # is greater than half the original TTL - sleep_count = 0 - test_iterations = 50 - while nbr_answers < test_iterations: - # Increase simulated time shift by 1/4 of the TTL in seconds - time_offset += expected_ttl / 4 - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - got_query.clear() - # Prevent the test running indefinitely in an error condition - assert sleep_count < test_iterations * 4 - assert not unexpected_ttl.is_set() - - # Don't remove service, allow close() to cleanup - - finally: - zeroconf_registrar.close() - service_removed.wait(1) - assert service_removed.is_set() - browser.cancel() - zeroconf_browser.close() + # Don't remove service, allow close() to cleanup + + finally: + zeroconf_registrar.close() + service_removed.wait(1) + assert service_removed.is_set() + browser.cancel() + zeroconf_browser.close() def test_legacy_record_update_listener(): From 113874a7b59ac9cc887b1b626ac1486781c7d56f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 17:58:50 -1000 Subject: [PATCH 300/608] Add test to ensure ServiceBrowser sees port change as an update (#625) --- tests/test_services.py | 55 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/test_services.py b/tests/test_services.py index 04300602..db1f7579 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1184,3 +1184,58 @@ def dns_addresses_to_addresses(dns_address: List[DNSAddress]): assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6] assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4] assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] + + +def test_service_browser_is_aware_of_port_changes(): + """Test that the ServiceBrowser is aware of port changes.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + callbacks = [] + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + nonlocal callbacks + if name == registration_name: + callbacks.append((service_type, state_change, name)) + + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + zc.wait(100) + + assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')] + assert zc.get_service_info(type_, registration_name).port == 80 + + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + zc.wait(100) + + assert callbacks == [ + ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'), + ('_hap._tcp.local.', ServiceStateChange.Updated, 'xxxyyy._hap._tcp.local.'), + ] + assert zc.get_service_info(type_, registration_name).port == 400 + browser.cancel() + + zc.close() From 215d6badb3db796b13a000b26953cb57c557e5e5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 17:58:59 -1000 Subject: [PATCH 301/608] Update changelog (#627) --- README.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.rst b/README.rst index 20d37bfa..e12d11f8 100644 --- a/README.rst +++ b/README.rst @@ -214,6 +214,15 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Add test to ensure ServiceBrowser sees port change as an update (#625) @bdraco + +* Fix random test failures due to monkey patching not being undone between tests (#626) @bdraco + + Switch patching to use unitest.mock.patch to ensure the patch + is reverted when the test is completed + +* Ensure zeroconf can be loaded when the system disables IPv6 (#624) @bdraco + * Eliminate aio sender thread (#622) @bdraco * Replace select loop with asyncio loop (#504) @bdraco From 28a614e0586a0ca1c5c1651b59c9a4d9c1af9a1b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 21:13:37 -1000 Subject: [PATCH 302/608] Return early on invalid data received (#628) - Improve coverage for handling invalid incoming data --- tests/test_core.py | 31 +++++++++++++++++++++++++++++++ tests/test_dns.py | 10 ++++++++++ zeroconf/_core.py | 6 ++---- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 906a9508..97457592 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -286,3 +286,34 @@ def test_generate_service_query_set_qu_bit(): out = zeroconf_registrar.generate_service_query(info) assert out.questions[0].unicast is True zeroconf_registrar.close() + + +def test_invalid_packets_ignored_and_does_not_cause_loop_exception(): + """Ensure an invalid packet cannot cause the loop to collapse.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + generated = r.DNSOutgoing(0) + packet = generated.packets()[0] + packet = packet[:8] + b'deadbeef' + packet[8:] + parsed = r.DNSIncoming(packet) + assert parsed.valid is False + + mock_out = unittest.mock.Mock() + mock_out.packets = lambda: [packet] + zc.send(mock_out) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + entry = r.DNSText( + "didnotcrashincoming._crash._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + assert isinstance(entry, r.DNSText) + assert isinstance(entry, r.DNSRecord) + assert isinstance(entry, r.DNSEntry) + + generated.add_answer_at_time(entry, 0) + zc.send(generated) + time.sleep(0.2) + zc.close() + assert zc.cache.get(entry) is not None diff --git a/tests/test_dns.py b/tests/test_dns.py index 18bffcce..664fccde 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -586,6 +586,16 @@ def test_incoming_unknown_type(self): assert len(parsed.answers) == 0 assert parsed.is_query() != parsed.is_response() + def test_incoming_circular_reference(self): + assert not r.DNSIncoming( + bytes.fromhex( + '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb' + '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e' + '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966' + '792d636f6e6e656374045f746370c023' + ) + ).valid + def test_incoming_ipv6(self): addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com packed = socket.inet_pton(socket.AF_INET6, addr) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 4668b51c..1fe5907b 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -214,11 +214,9 @@ def datagram_received( len(data), data, ) + return - if not msg.valid: - pass - - elif not msg.is_query(): + if not msg.is_query(): self.zc.handle_response(msg) return From 2065b1d7ec7cb5d41c34826c2d8887bdd8a018b6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 21:32:38 -1000 Subject: [PATCH 303/608] Add test for wait_condition_or_timeout_times_out util (#630) --- tests/utils/test_aio.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index a74d991d..65eaf255 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -5,6 +5,7 @@ """Unit tests for zeroconf._utils.aio.""" import asyncio +import contextlib import pytest @@ -20,3 +21,25 @@ async def test_get_running_loop_from_async() -> None: def test_get_running_loop_no_loop() -> None: """Test we get None when there is no loop running.""" assert aioutils.get_running_loop() is None + + +@pytest.mark.asyncio +async def test_wait_condition_or_timeout_times_out() -> None: + """Test wait_condition_or_timeout will timeout.""" + test_cond = asyncio.Condition() + async with test_cond: + await aioutils.wait_condition_or_timeout(test_cond, 0.1) + + async def _hold_condition(): + async with test_cond: + await test_cond.wait() + + task = asyncio.ensure_future(_hold_condition()) + await asyncio.sleep(0.1) + + async with test_cond: + await aioutils.wait_condition_or_timeout(test_cond, 0.1) + + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task From 2b31612e3f128b1193da9e0d2640f4e93fab2e3a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 21:32:54 -1000 Subject: [PATCH 304/608] Remove unreachable cache check for DNSAddresses (#629) - The ServiceBrowser would check to see if a DNSAddress was already in the cache and return early to avoid sending updates when the address already was held in the cache. This check was not needed since there is already a check a few lines before as `self.zc.cache.get(record)` which effectively does the same thing. This lead to the check never being covered in the tests and 2 cache lookups when only one was needed. --- zeroconf/_services/__init__.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index f6092aae..8a6d7d06 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -248,12 +248,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def _process_record_update( - self, - zc: 'Zeroconf', - now: float, - record: DNSRecord, - ) -> None: + def _process_record_update(self, now: float, record: DNSRecord) -> None: """Process a single record update from a batch of updates.""" expired = record.is_expired(now) @@ -281,14 +276,6 @@ def _process_record_update( return if isinstance(record, DNSAddress): - # Only trigger an updated event if the address is new - if record.address in set( - service.address - for service in zc.cache.entries_with_name(record.name) - if isinstance(service, DNSAddress) - ): - return - # Iterate through the DNSCache and callback any services that use this address for service in self.zc.cache.entries_with_server(record.name): type_ = self._record_matching_type(service) @@ -310,7 +297,7 @@ def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) - Ensures that there is are no unecessary duplicates in the list. """ for record in records: - self._process_record_update(zc, now, record) + self._process_record_update(now, record) def update_records_complete(self) -> None: """Called when a record update has completed for all handlers. From 64f6dd7e244c86d58b962f48a50d07625f2a2a33 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 21:34:42 -1000 Subject: [PATCH 305/608] Update changelog (#631) --- README.rst | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.rst b/README.rst index e12d11f8..e062483a 100644 --- a/README.rst +++ b/README.rst @@ -214,6 +214,23 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Remove unreachable cache check for DNSAddresses (#629) @bdraco + + The ServiceBrowser would check to see if a DNSAddress was + already in the cache and return early to avoid sending + updates when the address already was held in the cache. + This check was not needed since there is already a check + a few lines before as `self.zc.cache.get(record)` which + effectively does the same thing. This lead to the check + never being covered in the tests and 2 cache lookups when + only one was needed. + +* Add test for wait_condition_or_timeout_times_out util (#630) @bdraco + +* Return early on invalid data received (#628) @bdraco + + Improve coverage for handling invalid incoming data + * Add test to ensure ServiceBrowser sees port change as an update (#625) @bdraco * Fix random test failures due to monkey patching not being undone between tests (#626) @bdraco From 4ce33e48e2094f17d8358cf221c7e2f9a8cb3568 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:04:53 -1000 Subject: [PATCH 306/608] Return early in the shutdown/close process (#632) --- zeroconf/_core.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1fe5907b..577da9b8 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -147,13 +147,14 @@ async def _async_cache_cleanup(self) -> None: def close(self) -> None: """Close the engine.""" - if self._cache_cleanup_task: - self._cache_cleanup_task.cancel() - self._cache_cleanup_task = None for transport in itertools.chain(self.senders, self.readers): transport.close() for s in self._respond_sockets: s.close() + if not self._cache_cleanup_task: + return + self._cache_cleanup_task.cancel() + self._cache_cleanup_task = None class AsyncListener(asyncio.Protocol, QuietLogger): @@ -592,10 +593,11 @@ def close(self) -> None: self.engine.close() # shutdown the rest self.notify_all() - if self._loop_thread: - assert self.loop is not None - self.loop.call_soon_threadsafe(self.loop.stop) - self._loop_thread.join() + if not self._loop_thread: + return + assert self.loop is not None + self.loop.call_soon_threadsafe(self.loop.stop) + self._loop_thread.join() def __enter__(self) -> 'Zeroconf': return self From 5f66caaccf44c1504988cb82c1cba78d28dde7e7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:13:13 -1000 Subject: [PATCH 307/608] Mark DNSOutgoing write functions as protected (#633) --- zeroconf/_dns.py | 84 +++++++++++++++++++++++------------------------- 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index dcb8c9a3..41daee4e 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -572,7 +572,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: self.multicast = multicast self.packets_data: List[bytes] = [] - # these 3 are per-packet -- see also reset_for_next_packet() + # these 3 are per-packet -- see also _reset_for_next_packet() self.names: Dict[str, int] = {} self.data: List[bytes] = [] self.size: int = 12 @@ -585,7 +585,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: self.authorities: List[DNSPointer] = [] self.additionals: List[DNSRecord] = [] - def reset_for_next_packet(self) -> None: + def _reset_for_next_packet(self) -> None: self.names = {} self.data = [] self.size = 12 @@ -686,29 +686,29 @@ def add_question_or_all_cache( for cached_entry in cached_entries: self.add_answer_at_time(cached_entry, now) - def pack(self, format_: Union[bytes, str], value: Any) -> None: + def _pack(self, format_: Union[bytes, str], value: Any) -> None: self.data.append(struct.pack(format_, value)) self.size += struct.calcsize(format_) - def write_byte(self, value: int) -> None: + def _write_byte(self, value: int) -> None: """Writes a single byte to the packet""" - self.pack(b'!c', int2byte(value)) + self._pack(b'!c', int2byte(value)) - def insert_short_at_start(self, value: int) -> None: + def _insert_short_at_start(self, value: int) -> None: """Inserts an unsigned short at the start of the packet""" self.data.insert(0, struct.pack(b'!H', value)) - def replace_short(self, index: int, value: int) -> None: + def _replace_short(self, index: int, value: int) -> None: """Replaces an unsigned short in a certain position in the packet""" self.data[index] = struct.pack(b'!H', value) def write_short(self, value: int) -> None: """Writes an unsigned short to the packet""" - self.pack(b'!H', value) + self._pack(b'!H', value) - def write_int(self, value: Union[float, int]) -> None: + def _write_int(self, value: Union[float, int]) -> None: """Writes an unsigned integer to the packet""" - self.pack(b'!I', int(value)) + self._pack(b'!I', int(value)) def write_string(self, value: bytes) -> None: """Writes a string to the packet""" @@ -716,13 +716,13 @@ def write_string(self, value: bytes) -> None: self.data.append(value) self.size += len(value) - def write_utf(self, s: str) -> None: + def _write_utf(self, s: str) -> None: """Writes a UTF-8 string of a given length to the packet""" utfstr = s.encode('utf-8') length = len(utfstr) if length > 64: raise NamePartTooLongException - self.write_byte(length) + self._write_byte(length) self.write_string(utfstr) def write_character_string(self, value: bytes) -> None: @@ -730,7 +730,7 @@ def write_character_string(self, value: bytes) -> None: length = len(value) if length > 256: raise NamePartTooLongException - self.write_byte(length) + self._write_byte(length) self.write_string(value) def write_name(self, name: str) -> None: @@ -768,49 +768,45 @@ def write_name(self, name: str) -> None: # write the new names out. for part in parts[:count]: - self.write_utf(part) + self._write_utf(part) # if we wrote part of the name, create a pointer to the rest if count != len(name_suffices): # Found substring in packet, create pointer index = self.names[name_suffices[count]] - self.write_byte((index >> 8) | 0xC0) - self.write_byte(index & 0xFF) + self._write_byte((index >> 8) | 0xC0) + self._write_byte(index & 0xFF) else: # this is the end of a name - self.write_byte(0) + self._write_byte(0) - def write_question(self, question: DNSQuestion) -> bool: + def _write_question(self, question: DNSQuestion) -> bool: """Writes a question to the packet""" start_data_length, start_size = len(self.data), self.size self.write_name(question.name) self.write_short(question.type) - self.write_record_class(question) + self._write_record_class(question) return self._check_data_limit_or_rollback(start_data_length, start_size) - def write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: + def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: """Write out the record class including the unique/unicast (QU) bit.""" if record.unique and self.multicast: self.write_short(record.class_ | _CLASS_UNIQUE) else: self.write_short(record.class_) - def write_record(self, record: DNSRecord, now: float) -> bool: + def _write_record(self, record: DNSRecord, now: float) -> bool: """Writes a record (answer, authoritative answer, additional) to - the packet. Returns True on success, or False if we did not (either - because the packet was already finished or because the record does - not fit.""" - if self.state == self.State.finished: - return False - + the packet. Returns True on success, or False if we did not + because the packet because the record does not fit.""" start_data_length, start_size = len(self.data), self.size self.write_name(record.name) self.write_short(record.type) - self.write_record_class(record) + self._write_record_class(record) if now == 0: - self.write_int(record.ttl) + self._write_int(record.ttl) else: - self.write_int(record.get_remaining_ttl(now)) + self._write_int(record.get_remaining_ttl(now)) index = len(self.data) self.write_short(0) # Will get replaced with the actual size @@ -819,7 +815,7 @@ def write_record(self, record: DNSRecord, now: float) -> bool: length = sum((len(d) for d in self.data[index + 1 :])) # Here we replace the 0 length short we wrote # before with the actual length - self.replace_short(index, length) + self._replace_short(index, length) return self._check_data_limit_or_rollback(start_data_length, start_size) def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: @@ -844,7 +840,7 @@ def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) def _write_questions_from_offset(self, questions_offset: int) -> int: questions_written = 0 for question in self.questions[questions_offset:]: - if not self.write_question(question): + if not self._write_question(question): break questions_written += 1 return questions_written @@ -852,7 +848,7 @@ def _write_questions_from_offset(self, questions_offset: int) -> int: def _write_answers_from_offset(self, answer_offset: int) -> int: answers_written = 0 for answer, time_ in self.answers[answer_offset:]: - if not self.write_record(answer, time_): + if not self._write_record(answer, time_): break answers_written += 1 return answers_written @@ -860,7 +856,7 @@ def _write_answers_from_offset(self, answer_offset: int) -> int: def _write_authorities_from_offset(self, authority_offset: int) -> int: authorities_written = 0 for authority in self.authorities[authority_offset:]: - if not self.write_record(authority, 0): + if not self._write_record(authority, 0): break authorities_written += 1 return authorities_written @@ -868,7 +864,7 @@ def _write_authorities_from_offset(self, authority_offset: int) -> int: def _write_additionals_from_offset(self, additional_offset: int) -> int: additionals_written = 0 for additional in self.additionals[additional_offset:]: - if not self.write_record(additional, 0): + if not self._write_record(additional, 0): break additionals_written += 1 return additionals_written @@ -928,10 +924,10 @@ def packets(self) -> List[bytes]: authorities_written = self._write_authorities_from_offset(authority_offset) additionals_written = self._write_additionals_from_offset(additional_offset) - self.insert_short_at_start(additionals_written) - self.insert_short_at_start(authorities_written) - self.insert_short_at_start(answers_written) - self.insert_short_at_start(questions_written) + self._insert_short_at_start(additionals_written) + self._insert_short_at_start(authorities_written) + self._insert_short_at_start(answers_written) + self._insert_short_at_start(questions_written) questions_offset += questions_written answer_offset += answers_written @@ -950,17 +946,17 @@ def packets(self) -> List[bytes]: ): # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 log.debug("Setting TC flag") - self.insert_short_at_start(self.flags | _FLAGS_TC) + self._insert_short_at_start(self.flags | _FLAGS_TC) else: - self.insert_short_at_start(self.flags) + self._insert_short_at_start(self.flags) if self.multicast: - self.insert_short_at_start(0) + self._insert_short_at_start(0) else: - self.insert_short_at_start(self.id) + self._insert_short_at_start(self.id) self.packets_data.append(b''.join(self.data)) - self.reset_for_next_packet() + self._reset_for_next_packet() if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) From a0977a1ddfd7a7a1abcf74c1d90c18021aebc910 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:21:10 -1000 Subject: [PATCH 308/608] Clear cache in ZeroconfServiceTypes tests to ensure responses can be mcast before the timeout (#634) - We prevent the same record from being multicast within 1s because of RFC6762 sec 14. Since these test timeout after 0.5s, the answers they are looking for many be suppressed. Since a legitimate querier will retry again later, we need to clear the cache to simulate that the record has not been multicast recently --- tests/services/test_types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index e8e9911f..9d681667 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -35,7 +35,7 @@ def test_integration_with_listener(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) - + _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types @@ -68,7 +68,7 @@ def test_integration_with_listener_v6_records(self): addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) zeroconf_registrar.register_service(info) - + _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types @@ -100,7 +100,7 @@ def test_integration_with_listener_ipv6(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) - + _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) assert type_ in service_types @@ -132,7 +132,7 @@ def test_integration_with_subtype_and_listener(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) - + _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert discovery_type in service_types From c854d03efd31e1d002518a43221b347fa6ca5de5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:23:04 -1000 Subject: [PATCH 309/608] Update changelog (#635) --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index e062483a..def634de 100644 --- a/README.rst +++ b/README.rst @@ -184,6 +184,11 @@ Changelog should not be used since it will end up missing data, it has been removed +* BREAKING CHANGE: Mark DNSOutgoing write functions as protected (#633) @bdraco + + These functions are not intended to be used by external + callers and the API is not likely to be stable in the future + * TRAFFIC REDUCTION: Add support for handling QU questions (#621) @bdraco Implements RFC 6762 sec 5.4: @@ -214,6 +219,8 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Return early in the shutdown/close process (#632) @bdraco + * Remove unreachable cache check for DNSAddresses (#629) @bdraco The ServiceBrowser would check to see if a DNSAddress was From bbbbddf40d78dbd62a84f2439763d0a59211c5b9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:56:59 -1000 Subject: [PATCH 310/608] Ensure eventloop shutdown is threadsafe (#636) - Prevent ConnectionResetError from being thrown on Windows with ProactorEventLoop on cpython 3.8+ --- zeroconf/_core.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 577da9b8..eae2f49d 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -21,6 +21,7 @@ """ import asyncio +import contextlib import errno import itertools import platform @@ -145,16 +146,22 @@ async def _async_cache_cleanup(self) -> None: self.zc.record_manager.updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) + async def _async_stop_cleanup_task(self) -> None: + """Stop the cleanup task.""" + assert self._cache_cleanup_task is not None + self._cache_cleanup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._cache_cleanup_task + self._cache_cleanup_task = None + def close(self) -> None: """Close the engine.""" for transport in itertools.chain(self.senders, self.readers): transport.close() for s in self._respond_sockets: s.close() - if not self._cache_cleanup_task: - return - self._cache_cleanup_task.cancel() - self._cache_cleanup_task = None + assert self.loop is not None + asyncio.run_coroutine_threadsafe(self._async_stop_cleanup_task(), self.loop).result() class AsyncListener(asyncio.Protocol, QuietLogger): From 09c18a4173a013e67da5a1cdc7089452ba6f67ee Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 14 Jun 2021 22:58:33 -1000 Subject: [PATCH 311/608] Update changelog (#637) --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index def634de..07a9b2d6 100644 --- a/README.rst +++ b/README.rst @@ -219,6 +219,8 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Ensure eventloop shutdown is threadsafe (#636) @bdraco + * Return early in the shutdown/close process (#632) @bdraco * Remove unreachable cache check for DNSAddresses (#629) @bdraco From ce6912a75392cde41d8950b224ba3d14460993ff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 00:42:18 -1000 Subject: [PATCH 312/608] Ensure AsyncZeroconf.async_close can be called multiple times like Zeroconf.close (#638) --- tests/test_aio.py | 18 ++++++++++++++++++ zeroconf/_core.py | 22 ++++++++++++++++------ zeroconf/aio.py | 11 ++++------- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 2b222242..f4fdde4c 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -43,6 +43,14 @@ async def test_async_basic_usage() -> None: await aiozc.async_close() +@pytest.mark.asyncio +async def test_async_close_twice() -> None: + """Test we can close twice.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.async_close() + await aiozc.async_close() + + @pytest.mark.asyncio async def test_async_with_sync_passed_in() -> None: """Test we can create and close the instance when passing in a sync Zeroconf.""" @@ -52,6 +60,16 @@ async def test_async_with_sync_passed_in() -> None: await aiozc.async_close() +@pytest.mark.asyncio +async def test_async_with_sync_passed_in_closed_in_async() -> None: + """Test caller closes the sync version in async.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(zc=zc) + assert aiozc.zeroconf is zc + zc.close() + await aiozc.async_close() + + @pytest.mark.asyncio async def test_async_service_registration() -> None: """Test registering services broadcasts the registration by default.""" diff --git a/zeroconf/_core.py b/zeroconf/_core.py index eae2f49d..89f6dd95 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,22 +146,31 @@ async def _async_cache_cleanup(self) -> None: self.zc.record_manager.updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) - async def _async_stop_cleanup_task(self) -> None: - """Stop the cleanup task.""" + async def _async_close(self) -> None: + """Cancel and wait for the cleanup task to finish.""" + self._async_shutdown() assert self._cache_cleanup_task is not None self._cache_cleanup_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._cache_cleanup_task self._cache_cleanup_task = None + await asyncio.sleep(0) # flush out any call soons - def close(self) -> None: - """Close the engine.""" + def _async_shutdown(self) -> None: + """Shutdown transports and sockets.""" for transport in itertools.chain(self.senders, self.readers): transport.close() for s in self._respond_sockets: s.close() + + def close(self) -> None: + """Close from sync context.""" assert self.loop is not None - asyncio.run_coroutine_threadsafe(self._async_stop_cleanup_task(), self.loop).result() + # Guard against Zeroconf.close() being called from the eventloop + if get_running_loop() == self.loop: + self._async_shutdown() + return + asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result() class AsyncListener(asyncio.Protocol, QuietLogger): @@ -355,7 +364,8 @@ def add_notify_listener(self, listener: NotifyListener) -> None: def remove_notify_listener(self, listener: NotifyListener) -> None: """Removes a listener from the set that is currently listening.""" - self._notify_listeners.remove(listener) + with contextlib.suppress(ValueError): + self._notify_listeners.remove(listener) def add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 6f445e72..626b75f8 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -264,17 +264,14 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: self.zeroconf.registry.update(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - def _close(self) -> None: - """Shutdown zeroconf and the sender.""" - self.zeroconf.remove_notify_listener(self.async_notify) - self.zeroconf.close() - async def async_close(self) -> None: """Ends the background threads, and prevent this instance from servicing further queries.""" - await self.zeroconf.async_wait_for_start() + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) await self.async_remove_all_service_listeners() - await self.loop.run_in_executor(None, self._close) + self.zeroconf.remove_notify_listener(self.async_notify) + await self.loop.run_in_executor(None, self.zeroconf.close) async def async_get_service_info( self, type_: str, name: str, timeout: int = 3000 From 5ebd95452b16e76c37649486b232856a80390ac3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 00:50:25 -1000 Subject: [PATCH 313/608] Ensure cache is cleared before starting known answer enumeration query test (#639) --- tests/test_handlers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 71f6aff2..1e8109eb 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -593,6 +593,7 @@ def test_known_answer_supression_service_type_enumeration_query(): ) zc.register_service(info) now = current_time_millis() + _clear_cache(zc) # Test PTR supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) From 330e36ceb4202c579fe979958c63c37033ababbb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 08:29:54 -1000 Subject: [PATCH 314/608] Ensure the ServiceInfo.key gets updated when the name is changed externally (#645) --- tests/test_services.py | 19 +++++++++++++++++++ zeroconf/_services/__init__.py | 13 ++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_services.py b/tests/test_services.py index db1f7579..867c546a 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1239,3 +1239,22 @@ def mock_incoming_msg(records) -> r.DNSIncoming: browser.cancel() zc.close() + + +def test_changing_name_updates_serviceinfo_key(): + """Verify a name change will adjust the underlying key value.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + assert info_service.key == "mytesthome._homeassistant._tcp.local." + info_service.name = "YourTestHome._homeassistant._tcp.local." + assert info_service.key == "yourtesthome._homeassistant._tcp.local." diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 8a6d7d06..857197bd 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -466,7 +466,7 @@ def __init__( if not type_.endswith(service_type_name(name, strict=False)): raise BadTypeInNameException self.type = type_ - self.name = name + self._name = name self.key = name.lower() if addresses is not None: self._addresses = addresses @@ -494,6 +494,17 @@ def __init__( self.host_ttl = host_ttl self.other_ttl = other_ttl + @property + def name(self) -> str: + """The name of the service.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + """Replace the the name and reset the key.""" + self._name = name + self.key = name.lower() + @property def addresses(self) -> List[bytes]: """IPv4 addresses of this service. From 9354ab39f350e4e6451dc4965225591761ada40d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 08:33:29 -1000 Subject: [PATCH 315/608] Add missing coverage to ServiceRegistry (#646) --- zeroconf/_services/registry.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 4c4c1706..e9db74f1 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -86,18 +86,13 @@ def get_infos_server(self, server: str) -> List[ServiceInfo]: def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" - service_infos = [] - - for name in getattr(self, attr).get(key.lower(), [])[:]: - info = self._services.get(name) - # Since we do not get under a lock since it would be - # a performance issue, its possible - # the service can be unregistered during the get - # so we must check if info is None - if info is not None: - service_infos.append(info) - - return service_infos + # Since we do not get under a lock since it would be + # a performance issue, its possible + # the service can be unregistered during the get + # so we must check if info is None + return list( + filter(None, [self._services.get(name) for name in getattr(self, attr).get(key.lower(), [])[:]]) + ) def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" From a83d390bef042da51d93014c222c65af81723a20 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 08:38:30 -1000 Subject: [PATCH 316/608] Use ServiceInfo.key/ServiceInfo.server_key instead of lowering in ServiceRegistry (#647) --- zeroconf/_services/registry.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index e9db74f1..058717ce 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -96,18 +96,16 @@ def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" - lower_name = info.name.lower() - if lower_name in self._services: + if info.key in self._services: raise ServiceNameAlreadyRegistered - self._services[lower_name] = info - self.types.setdefault(info.type.lower(), []).append(lower_name) - self.servers.setdefault(info.server.lower(), []).append(lower_name) + self._services[info.key] = info + self.types.setdefault(info.type.lower(), []).append(info.key) + self.servers.setdefault(info.server_key, []).append(info.key) def _remove(self, info: ServiceInfo) -> None: """Remove a service under the lock.""" - lower_name = info.name.lower() - old_service_info = self._services[lower_name] - self.types[old_service_info.type.lower()].remove(lower_name) - self.servers[old_service_info.server.lower()].remove(lower_name) - del self._services[lower_name] + old_service_info = self._services[info.key] + self.types[old_service_info.type.lower()].remove(info.key) + self.servers[old_service_info.server_key].remove(info.key) + del self._services[info.key] From cf0b5b9e2cfa4779425401b3d205f5d913621864 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 08:45:30 -1000 Subject: [PATCH 317/608] Ensure services are removed from the registry when calling unregister_all_services (#644) - There was a race condition where a query could be answered for a service in the registry while goodbye packets which could result a fresh record being broadcast after the goodbye if a query came in at just the right time. To avoid this, we now remove the services from the registry right after we generate the goodbye packet --- tests/test_core.py | 29 +++++++++++++++++++++++++++++ tests/test_init.py | 17 +++++++++++++++-- zeroconf/_core.py | 21 +++++++++++++++------ zeroconf/_services/registry.py | 23 ++++++++++++----------- 4 files changed, 71 insertions(+), 19 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 97457592..6252488b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -317,3 +317,32 @@ def test_invalid_packets_ignored_and_does_not_cause_loop_exception(): time.sleep(0.2) zc.close() assert zc.cache.get(entry) is not None + + +def test_goodbye_all_services(): + """Verify generating the goodbye query does not change with time.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + out = zc.generate_unregister_all_services() + assert out is None + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info) + out = zc.generate_unregister_all_services() + assert out is not None + first_packet = out.packets() + zc.registry.add(info) + out2 = zc.generate_unregister_all_services() + assert out2 is not None + second_packet = out.packets() + assert second_packet == first_packet + + # Verify the registery is empty + out3 = zc.generate_unregister_all_services() + assert out3 is None + assert zc.registry.get_service_infos() == [] + + zc.close() diff --git a/tests/test_init.py b/tests/test_init.py index 6e5457ff..3cc16b22 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -184,8 +184,21 @@ def verify_name_change(self, zc, type_, name, number_hosts): # verify no name conflict https://tools.ietf.org/html/rfc6762#section-6.6 zc.register_service(info_service, cooperating_responders=True) - zc.register_service(info_service, allow_name_change=True) - assert info_service.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) + # Create a new object since allow_name_change will mutate the + # original object and then we will have the wrong service + # in the registry + info_service2 = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zc.register_service(info_service2, allow_name_change=True) + assert info_service2.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) def generate_many_hosts(self, zc, type_, name, number_hosts): records_per_server = 2 diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 89f6dd95..c8e7736e 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -476,10 +476,22 @@ def unregister_service(self, info: ServiceInfo) -> None: self.registry.remove(info) self._broadcast_service(info, _UNREGISTER_TIME, 0) - def unregister_all_services(self) -> None: - """Unregister all registered services.""" + def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: + """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" service_infos = self.registry.get_service_infos() if not service_infos: + return None + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) + for info in service_infos: + self._add_broadcast_answer(out, info, 0) + self.registry.remove(service_infos) + return out + + def unregister_all_services(self) -> None: + """Unregister all registered services.""" + # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 + out = self.generate_unregister_all_services() + if not out: return now = current_time_millis() next_time = now @@ -489,9 +501,6 @@ def unregister_all_services(self) -> None: self.wait(next_time - now) now = current_time_millis() continue - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - for info in service_infos: - self._add_broadcast_answer(out, info, 0) self.send(out) i += 1 next_time += _UNREGISTER_TIME @@ -604,8 +613,8 @@ def close(self) -> None: if self._GLOBAL_DONE: return # remove service listeners - self.remove_all_service_listeners() self.unregister_all_services() + self.remove_all_service_listeners() self._GLOBAL_DONE = True self.engine.close() # shutdown the rest diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 058717ce..244ff294 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -21,7 +21,7 @@ """ import threading -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from .._exceptions import ServiceNameAlreadyRegistered @@ -47,21 +47,21 @@ def __init__( def add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" - with self._lock: self._add(info) - def remove(self, info: ServiceInfo) -> None: + def remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: """Remove a new service from the registry.""" + infos = info if isinstance(info, list) else [info] with self._lock: - self._remove(info) + self._remove(infos) def update(self, info: ServiceInfo) -> None: """Update new service in the registry.""" with self._lock: - self._remove(info) + self._remove([info]) self._add(info) def get_service_infos(self) -> List[ServiceInfo]: @@ -103,9 +103,10 @@ def _add(self, info: ServiceInfo) -> None: self.types.setdefault(info.type.lower(), []).append(info.key) self.servers.setdefault(info.server_key, []).append(info.key) - def _remove(self, info: ServiceInfo) -> None: - """Remove a service under the lock.""" - old_service_info = self._services[info.key] - self.types[old_service_info.type.lower()].remove(info.key) - self.servers[old_service_info.server_key].remove(info.key) - del self._services[info.key] + def _remove(self, infos: List[ServiceInfo]) -> None: + """Remove a services under the lock.""" + for info in infos: + old_service_info = self._services[info.key] + self.types[old_service_info.type.lower()].remove(info.key) + self.servers[old_service_info.server_key].remove(info.key) + del self._services[info.key] From 79e39c0e923a1f6d87353761809f34f0fe1f0800 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 09:02:46 -1000 Subject: [PATCH 318/608] Use cache clear helper in aio tests (#648) --- tests/test_aio.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index f4fdde4c..49aa7ef2 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -19,6 +19,9 @@ from zeroconf._utils.time import current_time_millis +from . import _clear_cache + + @pytest.fixture(autouse=True) def verify_threads_ended(): """Verify that the threads are not running after the test.""" @@ -387,10 +390,7 @@ async def test_service_info_async_request() -> None: assert aiosinfos[1].addresses == [socket.inet_aton("10.0.1.5")] aiosinfo = AsyncServiceInfo(type_, registration_name) - zc_cache = aiozc.zeroconf.cache - for name in zc_cache.names(): - for record in zc_cache.entries_with_name(name): - zc_cache.remove(record) + _clear_cache(aiozc.zeroconf) # Generating the race condition is almost impossible # without patching since its a TOCTOU race with unittest.mock.patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): From 72e709b40caed016ba981be3752c439bbbf40ec7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 09:18:24 -1000 Subject: [PATCH 319/608] Add async_unregister_all_services to AsyncZeroconf (#649) --- tests/test_aio.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++ zeroconf/aio.py | 15 ++++++++++++ 2 files changed, 75 insertions(+) diff --git a/tests/test_aio.py b/tests/test_aio.py index 49aa7ef2..388668b2 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -18,6 +18,8 @@ from zeroconf._services import ServiceInfo, ServiceListener from zeroconf._utils.time import current_time_millis +from . import _clear_cache + from . import _clear_cache @@ -498,3 +500,61 @@ async def test_async_context_manager() -> None: await task aiosinfo = await aiozc.async_get_service_info(type_, registration_name) assert aiosinfo is not None + + +@pytest.mark.asyncio +async def test_async_unregister_all_services() -> None: + """Test unregistering all services.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "abc" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name2, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-1.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-5.local.", + addresses=[socket.inet_aton("10.0.1.5")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + tasks = [] + tasks.append(aiozc.async_get_service_info(type_, registration_name)) + tasks.append(aiozc.async_get_service_info(type_, registration_name2)) + results = await asyncio.gather(*tasks) + assert results[0] is not None + assert results[1] is not None + + await aiozc.async_unregister_all_services() + + tasks = [] + tasks.append(aiozc.async_get_service_info(type_, registration_name)) + tasks.append(aiozc.async_get_service_info(type_, registration_name2)) + results = await asyncio.gather(*tasks) + assert results[0] is None + assert results[1] is None + + # Verify we can call again + await aiozc.async_unregister_all_services() + + await aiozc.async_close() diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 626b75f8..446f6cf0 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -227,6 +227,21 @@ async def async_register_service( self.zeroconf.registry.add(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + async def async_unregister_all_services(self) -> None: + """Unregister all registered services. + + Unlike async_register_service and async_unregister_service, this + method does not return a future and is always expected to be + awaited since its only called at shutdown. + """ + out = self.zeroconf.generate_unregister_all_services() + if not out: + return + for i in range(3): + if i != 0: + await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) + self.zeroconf.async_send(out) + async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: """Checks the network for a unique service name.""" instance_name_from_service_info(info) From df9f8d9a0110cc9135b7c2f0b4cd47e985da9a7e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 09:23:13 -1000 Subject: [PATCH 320/608] Ensure interface_index_to_ip6_address skips ipv4 adapters (#651) --- tests/utils/test_net.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 1a8beebe..1fd4d113 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -46,9 +46,15 @@ def test_interface_index_to_ip6_address(): """Test we can extract from mocked adapters.""" adapters = _generate_mock_adapters() assert netutils.interface_index_to_ip6_address(adapters, 1) == ('2001:db8::', 1, 1) + + # call with invalid adapter with pytest.raises(RuntimeError): assert netutils.interface_index_to_ip6_address(adapters, 6) + # call with adapter that has ipv4 address only + with pytest.raises(RuntimeError): + assert netutils.interface_index_to_ip6_address(adapters, 2) + def test_ip6_addresses_to_indexes(): """Test we can extract from mocked adapters.""" From b940f878fe1f8e6b8dfe2554b781cd6034dee722 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 09:49:22 -1000 Subject: [PATCH 321/608] Set __all__ in zeroconf.aio to ensure private functions do now show in the docs (#652) --- zeroconf/aio.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 446f6cf0..974a2cb9 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -33,6 +33,14 @@ from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME +__all__ = [ + "AsyncZeroconf", + "AsyncServiceInfo", + "AsyncServiceBrowser", + "AsyncServiceListener", +] + + class AsyncNotifyListener(NotifyListener): """A NotifyListener that async code can use to wait for events.""" From 7d8994bc3cb4d5978bb1ff189bb5a4b7c81b5c4c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 09:56:51 -1000 Subject: [PATCH 322/608] Remove all calls to the executor in AsyncZeroconf (#653) --- zeroconf/_core.py | 50 +++++++++++++++++++++++++++++++++++------------ zeroconf/aio.py | 3 ++- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index c8e7736e..e859c7b1 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -149,11 +149,11 @@ async def _async_cache_cleanup(self) -> None: async def _async_close(self) -> None: """Cancel and wait for the cleanup task to finish.""" self._async_shutdown() - assert self._cache_cleanup_task is not None - self._cache_cleanup_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._cache_cleanup_task - self._cache_cleanup_task = None + if self._cache_cleanup_task: + self._cache_cleanup_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._cache_cleanup_task + self._cache_cleanup_task = None await asyncio.sleep(0) # flush out any call soons def _async_shutdown(self) -> None: @@ -170,6 +170,8 @@ def close(self) -> None: if get_running_loop() == self.loop: self._async_shutdown() return + if not self.loop.is_running(): + return asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result() @@ -607,17 +609,15 @@ def async_send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _ # on send errors, log the exception and keep going self.log_exception_warning('Error sending through socket %d', s.fileno()) - def close(self) -> None: - """Ends the background threads, and prevent this instance from - servicing further queries.""" + def _close(self) -> None: + """Set global done and remove all service listeners.""" if self._GLOBAL_DONE: return - # remove service listeners - self.unregister_all_services() self.remove_all_service_listeners() self._GLOBAL_DONE = True - self.engine.close() - # shutdown the rest + + def _shutdown_threads(self) -> None: + """Shutdown any threads.""" self.notify_all() if not self._loop_thread: return @@ -625,6 +625,32 @@ def close(self) -> None: self.loop.call_soon_threadsafe(self.loop.stop) self._loop_thread.join() + def close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + """ + self.unregister_all_services() + self._close() + self.engine.close() + self._shutdown_threads() + + async def _async_close(self) -> None: + """Ends the background threads, and prevent this instance from + servicing further queries. + + This method is idempotent and irreversible. + + This call only intended to be used by AsyncZeroconf + + Callers are responsible for unregistering all services + before calling this function + """ + self._close() + await self.engine._async_close() # pylint: disable=protected-access + self._shutdown_threads() + def __enter__(self) -> 'Zeroconf': return self diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 974a2cb9..8fff385f 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -294,7 +294,8 @@ async def async_close(self) -> None: await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) await self.async_remove_all_service_listeners() self.zeroconf.remove_notify_listener(self.async_notify) - await self.loop.run_in_executor(None, self.zeroconf.close) + await self.async_unregister_all_services() + await self.zeroconf._async_close() # pylint: disable=protected-access async def async_get_service_info( self, type_: str, name: str, timeout: int = 3000 From 3c61d03f5954c3e45229d6c1399a63c0f7331d55 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 10:02:43 -1000 Subject: [PATCH 323/608] Add test coverage for normalize_interface_choice exception paths (#654) --- tests/utils/test_net.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 1fd4d113..7890f381 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -67,6 +67,17 @@ def test_ip6_addresses_to_indexes(): assert netutils.ip6_addresses_to_indexes(interfaces) == [(('2001:db8::', 1, 1), 1)] +def test_normalize_interface_choice_errors(): + """Test we generate exception on invalid input.""" + with patch("zeroconf._utils.net.get_all_addresses", return_value=[]), patch( + "zeroconf._utils.net.get_all_addresses_v6", return_value=[] + ), pytest.raises(RuntimeError): + netutils.normalize_interface_choice(r.InterfaceChoice.All) + + with pytest.raises(TypeError): + netutils.normalize_interface_choice("1.2.3.4") + + @pytest.mark.parametrize( "errno,expected_result", [(errno.EADDRINUSE, False), (errno.EADDRNOTAVAIL, False), (errno.EINVAL, False), (0, True)], From efd6bfbe81f448da2ee68b91d49cbe1982271da3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 10:05:17 -1000 Subject: [PATCH 324/608] Improve aio utils tests to validate high lock contention (#655) --- tests/utils/test_aio.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index 65eaf255..1f0a1d7e 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -37,8 +37,12 @@ async def _hold_condition(): task = asyncio.ensure_future(_hold_condition()) await asyncio.sleep(0.1) - async with test_cond: - await aioutils.wait_condition_or_timeout(test_cond, 0.1) + async def _async_wait_or_timeout(): + async with test_cond: + await aioutils.wait_condition_or_timeout(test_cond, 0.1) + + # Test high lock contention + await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)]) task.cancel() with contextlib.suppress(asyncio.CancelledError): From 87fe529a33b920532b2af688bb66182ae832a3ad Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 10:15:08 -1000 Subject: [PATCH 325/608] Add coverage for registering a service with a custom ttl (#656) --- tests/test_core.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 6252488b..9c04d03d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -346,3 +346,28 @@ def test_goodbye_all_services(): assert zc.registry.get_service_infos() == [] zc.close() + + +def test_register_service_with_custom_ttl(): + """Test a registering a service with a custom ttl.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = r.ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service, ttl=30) + assert zc.cache.get(info_service.dns_pointer()).ttl == 30 + zc.close() From 5752ace7727bffa34cdac0455125a941014ab123 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 10:23:33 -1000 Subject: [PATCH 326/608] Add test for Zeroconf.get_service_info failure case (#657) --- tests/test_core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 9c04d03d..97799d95 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -371,3 +371,10 @@ def test_register_service_with_custom_ttl(): zc.register_service(info_service, ttl=30) assert zc.cache.get(info_service.dns_pointer()).ttl == 30 zc.close() + + +def test_get_service_info_failure_path(): + """Verify get_service_info return None when the underlying call returns False.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None + zc.close() From 0e52be059065e23ebe9e11c465adc20655b6080e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 11:14:38 -1000 Subject: [PATCH 327/608] Add test for launching with apple_p2p=True (#660) - Switch to using `sys.platform` to detect Mac instead of `platform.system()` since `platform.system()` is not intended to be machine parsable and is only for humans. Closes #650 --- tests/test_core.py | 11 +++++++++++ zeroconf/_core.py | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 97799d95..19ab81d1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -9,6 +9,7 @@ import os import pytest import socket +import sys import time import unittest import unittest.mock @@ -99,6 +100,16 @@ def test_launch_and_close_v6_only(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) rv.close() + @unittest.skipIf(sys.platform != 'darwin', reason="apple_p2p only available on mac") + def test_launch_and_close_apple_p2p(self): + rv = r.Zeroconf(apple_p2p=True) + rv.close() + + @unittest.skipIf(sys.platform == 'darwin', reason="apple_p2p available on mac") + def test_launch_and_close_apple_p2p(self): + with pytest.raises(RuntimeError): + r.Zeroconf(apple_p2p=True) + def test_handle_response(self): def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: ttl = 120 diff --git a/zeroconf/_core.py b/zeroconf/_core.py index e859c7b1..21bc9a6a 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -24,8 +24,8 @@ import contextlib import errno import itertools -import platform import socket +import sys import threading from types import TracebackType # noqa # used in type hints from typing import Dict, List, Optional, Tuple, Type, Union, cast @@ -283,7 +283,7 @@ def __init__( # hook for threads self._GLOBAL_DONE = False - if apple_p2p and not platform.system() == 'Darwin': + if apple_p2p and sys.platform != 'darwin': raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) From 72db0c10246e948c15d9a53f60a54b835ccc67bc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 13:29:39 -1000 Subject: [PATCH 328/608] Fix flakey ZeroconfServiceTypes types test (#662) --- tests/services/test_types.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 9d681667..8cff4431 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -8,6 +8,7 @@ import unittest import socket import sys +import time import zeroconf as r from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes @@ -35,6 +36,8 @@ def test_integration_with_listener(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) + # Ensure we do not clear the cache until after the last broadcast is processed + time.sleep(0.2) _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) @@ -68,6 +71,8 @@ def test_integration_with_listener_v6_records(self): addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) zeroconf_registrar.register_service(info) + # Ensure we do not clear the cache until after the last broadcast is processed + time.sleep(0.2) _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) @@ -100,6 +105,8 @@ def test_integration_with_listener_ipv6(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) + # Ensure we do not clear the cache until after the last broadcast is processed + time.sleep(0.2) _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) @@ -132,6 +139,8 @@ def test_integration_with_subtype_and_listener(self): addresses=[socket.inet_aton("10.0.1.2")], ) zeroconf_registrar.register_service(info) + # Ensure we do not clear the cache until after the last broadcast is processed + time.sleep(0.2) _clear_cache(zeroconf_registrar) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) From aaf8a368063f080be4a9c01fe671243e63bdf576 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:04:06 -1000 Subject: [PATCH 329/608] Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.aio (#658) --- tests/test_aio.py | 50 +++++++++++++++++++++++++++++++++++++++++-- zeroconf/aio.py | 54 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 100 insertions(+), 4 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 388668b2..47c1e2d9 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -5,13 +5,14 @@ """Unit tests for aio.py.""" import asyncio +import logging import socket import threading import unittest.mock import pytest -from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf +from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf, AsyncZeroconfServiceTypes from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered @@ -20,8 +21,19 @@ from . import _clear_cache +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) -from . import _clear_cache + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) @pytest.fixture(autouse=True) @@ -558,3 +570,37 @@ async def test_async_unregister_all_services() -> None: await aiozc.async_unregister_all_services() await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_zeroconf_service_types(): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + task = await zeroconf_registrar.async_register_service(info) + await task + # Ensure we do not clear the cache until after the last broadcast is processed + await asyncio.sleep(0.2) + _clear_cache(zeroconf_registrar.zeroconf) + try: + service_types = await AsyncZeroconfServiceTypes.async_find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar.zeroconf) + service_types = await AsyncZeroconfServiceTypes.async_find(aiozc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types + + finally: + await zeroconf_registrar.async_close() diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 8fff385f..01211bb4 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -22,15 +22,24 @@ import asyncio import contextlib from types import TracebackType # noqa # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Type, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from ._core import NotifyListener, Zeroconf from ._exceptions import NonUniqueNameException from ._services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info +from ._services.types import ZeroconfServiceTypes from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType from ._utils.time import current_time_millis, millis_to_seconds -from .const import _BROWSER_TIME, _CHECK_TIME, _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _UNREGISTER_TIME +from .const import ( + _BROWSER_TIME, + _CHECK_TIME, + _LISTENER_TIME, + _MDNS_PORT, + _REGISTER_TIME, + _SERVICE_TYPE_ENUMERATION_NAME, + _UNREGISTER_TIME, +) __all__ = [ @@ -38,6 +47,7 @@ "AsyncServiceInfo", "AsyncServiceBrowser", "AsyncServiceListener", + "AsyncZeroconfServiceTypes", ] @@ -137,6 +147,7 @@ async def async_cancel(self) -> None: async def async_run(self) -> None: """Run the browser task.""" self.run() + await self.aiozc.zeroconf.async_wait_for_start() while True: timeout = self._seconds_to_wait() if timeout: @@ -164,6 +175,45 @@ async def async_run(self) -> None: ) +class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): + """An async version of ZeroconfServiceTypes.""" + + @classmethod + async def async_find( + cls, + aiozc: Optional['AsyncZeroconf'] = None, + timeout: Union[int, float] = 5, + interfaces: InterfacesType = InterfaceChoice.All, + ip_version: Optional[IPVersion] = None, + ) -> Tuple[str, ...]: + """ + Return all of the advertised services on any local networks. + + :param aiozc: AsyncZeroconf() instance. Pass in if already have an + instance running or if non-default interfaces are needed + :param timeout: seconds to wait for any responses + :param interfaces: interfaces to listen on. + :param ip_version: IP protocol version to use. + :return: tuple of service type strings + """ + local_zc = aiozc or AsyncZeroconf(interfaces=interfaces, ip_version=ip_version) + listener = cls() + async_browser = AsyncServiceBrowser( + local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener # type: ignore + ) + + # wait for responses + await asyncio.sleep(timeout) + + await async_browser.async_cancel() + + # close down anything we opened + if aiozc is None: + await local_zc.async_close() + + return tuple(sorted(listener.found_services)) + + class AsyncZeroconf: """Implementation of Zeroconf Multicast DNS Service Discovery From e76c7a5b76485efce0929ee8417aa2e0f262c04c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:08:01 -1000 Subject: [PATCH 330/608] Permit the ServiceBrowser to browse overlong types (#666) - At least one type "tivo-videostream" exists in the wild so we are permissive about what we will look for, and strict about what we will announce. Fixes #661 --- tests/test_exceptions.py | 1 - tests/test_services.py | 17 +++++++++++++++++ tests/utils/test_name.py | 26 ++++++++++++++++++++++++++ zeroconf/_utils/name.py | 6 +++++- 4 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 tests/utils/test_name.py diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index cfc4c19d..aa2f74f6 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -60,7 +60,6 @@ def test_bad_service_names(self): '_x-._tcp.local.', '_22._udp.local.', '_2-2._tcp.local.', - '_1234567890-abcde._udp.local.', '\x00._x._udp.local.', ) for name in bad_names_to_try: diff --git a/tests/test_services.py b/tests/test_services.py index 867c546a..b8edc4ac 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1258,3 +1258,20 @@ def test_changing_name_updates_serviceinfo_key(): assert info_service.key == "mytesthome._homeassistant._tcp.local." info_service.name = "YourTestHome._homeassistant._tcp.local." assert info_service.key == "yourtesthome._homeassistant._tcp.local." + + +def test_servicebrowser_uses_non_strict_names(): + """Verify we can look for technically invalid names as we cannot change what others do.""" + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + browser = ServiceBrowser(zc, ["_tivo-videostream._tcp.local."], [on_service_state_change]) + browser.cancel() + + # Still fail on completely invalid + with pytest.raises(r.BadTypeInNameException): + browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change]) + zc.close() diff --git a/tests/utils/test_name.py b/tests/utils/test_name.py new file mode 100644 index 00000000..6f8b417d --- /dev/null +++ b/tests/utils/test_name.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for zeroconf._utils.name.""" + +import pytest + +from zeroconf._utils import name as nameutils +from zeroconf import BadTypeInNameException + + +def test_service_type_name_overlong_type(): + """Test overlong service_type_name type.""" + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.") + nameutils.service_type_name("Tivo1._tivo-videostream._tcp.local.", strict=False) + + +def test_service_type_name_overlong_full_name(): + """Test overlong service_type_name full name.""" + long_name = "Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1Tivo1" * 100 + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.") + with pytest.raises(BadTypeInNameException): + nameutils.service_type_name(f"{long_name}._tivo-videostream._tcp.local.", strict=False) diff --git a/zeroconf/_utils/name.py b/zeroconf/_utils/name.py index 10a0ccf8..c59ac33a 100644 --- a/zeroconf/_utils/name.py +++ b/zeroconf/_utils/name.py @@ -74,6 +74,9 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis :param type_: Type, SubType or service name to validate :return: fully qualified service name (eg: _http._tcp.local.) """ + if len(type_) > 256: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 + raise BadTypeInNameException("Full name (%s) must be > 256 bytes" % type_) if type_.endswith((_TCP_PROTOCOL_LOCAL_TRAILER, _NONTCP_PROTOCOL_LOCAL_TRAILER)): remaining = type_[: -len(_TCP_PROTOCOL_LOCAL_TRAILER)].split('.') @@ -104,7 +107,8 @@ def service_type_name(type_: str, *, strict: bool = True) -> str: # pylint: dis test_service_name = service_name[1:] - if len(test_service_name) > 15: + if strict and len(test_service_name) > 15: + # https://datatracker.ietf.org/doc/html/rfc6763#section-7.2 raise BadTypeInNameException("Service name (%s) must be <= 15 bytes" % test_service_name) if '--' in test_service_name: From 481cc42d000f5b0258f1be3b6df7cb7b24428b7f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:15:42 -1000 Subject: [PATCH 331/608] Update async_browser.py example to use AsyncZeroconfServiceTypes (#665) --- examples/async_browser.py | 46 ++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 13 deletions(-) diff --git a/examples/async_browser.py b/examples/async_browser.py index 4a3861cb..cba30223 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -8,10 +8,10 @@ import argparse import asyncio import logging -from typing import cast +from typing import Any, Optional, cast from zeroconf import IPVersion, ServiceStateChange -from zeroconf.aio import AsyncServiceBrowser, AsyncZeroconf +from zeroconf.aio import AsyncServiceBrowser, AsyncZeroconf, AsyncZeroconfServiceTypes def async_on_service_state_change( @@ -43,11 +43,39 @@ async def async_display_service_info(zeroconf: AsyncZeroconf, service_type: str, print('\n') +class AsyncRunner: + def __init__(self, args: Any) -> None: + self.args = args + self.aiobrowser: Optional[AsyncServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + + services = ["_http._tcp.local.", "_hap._tcp.local."] + if self.args.find: + services = list( + await AsyncZeroconfServiceTypes.async_find(aiozc=self.aiozc, ip_version=ip_version) + ) + + print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services) + self.aiobrowser = AsyncServiceBrowser(self.aiozc, services, handlers=[async_on_service_state_change]) + while True: + await asyncio.sleep(1) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.aiobrowser is not None + await self.aiobrowser.async_cancel() + await self.aiozc.async_close() + + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) parser = argparse.ArgumentParser() parser.add_argument('--debug', action='store_true') + parser.add_argument('--find', action='store_true', help='Browse all available services') version_group = parser.add_mutually_exclusive_group() version_group.add_argument('--v6', action='store_true') version_group.add_argument('--v6-only', action='store_true') @@ -62,17 +90,9 @@ async def async_display_service_info(zeroconf: AsyncZeroconf, service_type: str, else: ip_version = IPVersion.V4Only - aiozc = AsyncZeroconf(ip_version=ip_version) - - services = ["_http._tcp.local.", "_hap._tcp.local."] - print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services) - aiobrowser = AsyncServiceBrowser(aiozc, services, handlers=[async_on_service_state_change]) - loop = asyncio.get_event_loop() + runner = AsyncRunner(args) try: - loop.run_forever() + loop.run_until_complete(runner.async_run()) except KeyboardInterrupt: - pass - finally: - loop.run_until_complete(aiobrowser.async_cancel()) - loop.run_until_complete(aiozc.async_close()) + loop.run_until_complete(runner.async_close()) From 75347b4e30429e130716b666da52953700f0f8e9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:25:17 -1000 Subject: [PATCH 332/608] Add missing coverage for ServiceListener (#668) --- tests/test_services.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_services.py b/tests/test_services.py index b8edc4ac..fa343af3 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1275,3 +1275,27 @@ def on_service_state_change(zeroconf, service_type, state_change, name): with pytest.raises(r.BadTypeInNameException): browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change]) zc.close() + + +def test_servicelisteners_raise_not_implemented(): + """Verify service listeners raise when one of the methods is not implemented.""" + + class MyPartialListener(r.ServiceListener): + """A listener that does not implement anything.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(NotImplementedError): + MyPartialListener().add_service( + zc, "_tivo-videostream._tcp.local.", "Tivo1._tivo-videostream._tcp.local." + ) + with pytest.raises(NotImplementedError): + MyPartialListener().remove_service( + zc, "_tivo-videostream._tcp.local.", "Tivo1._tivo-videostream._tcp.local." + ) + with pytest.raises(NotImplementedError): + MyPartialListener().update_service( + zc, "_tivo-videostream._tcp.local.", "Tivo1._tivo-videostream._tcp.local." + ) + + zc.close() From d59fb8be29d8602ad66d89f595b26671a528fd77 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:32:12 -1000 Subject: [PATCH 333/608] Add missing coverage for ServiceInfo address changes (#669) --- tests/test_services.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_services.py b/tests/test_services.py index fa343af3..45a5b91e 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1299,3 +1299,36 @@ class MyPartialListener(r.ServiceListener): ) zc.close() + + +def test_serviceinfo_address_updates(): + """Verify adding/removing/setting addresses on ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + + # Verify addresses and parsed_addresses are mutually exclusive + with pytest.raises(TypeError): + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + parsed_addresses=["10.0.1.2"], + ) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info_service.addresses = [socket.inet_aton("10.0.1.3")] + assert info_service.addresses == [socket.inet_aton("10.0.1.3")] From d274cd3a3409997b764c49d3eae7e8ee2fba33b6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:35:58 -1000 Subject: [PATCH 334/608] Add test for sending unicast responses (#670) --- tests/test_core.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 19ab81d1..8cb04369 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -389,3 +389,30 @@ def test_get_service_info_failure_path(): zc = Zeroconf(interfaces=['127.0.0.1']) assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None zc.close() + + +def test_sending_unicast(): + """Test sending unicast response.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + entry = r.DNSText( + "didnotcrashincoming._crash._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + generated.add_answer_at_time(entry, 0) + zc.send(generated, "2001:db8::1", const._MDNS_PORT) # https://www.iana.org/go/rfc3849 + time.sleep(0.2) + assert zc.cache.get(entry) is None + + zc.send(generated, "198.51.100.0", const._MDNS_PORT) # Documentation (TEST-NET-2) + time.sleep(0.2) + assert zc.cache.get(entry) is None + + zc.send(generated) + time.sleep(0.2) + assert zc.cache.get(entry) is not None + + zc.close() From 8535110dd661ce406904930994a9f86faf897597 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 14:41:26 -1000 Subject: [PATCH 335/608] Add oversized packet to the invalid packet test (#671) --- tests/test_core.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 8cb04369..1e6d0d92 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -308,9 +308,16 @@ def test_invalid_packets_ignored_and_does_not_cause_loop_exception(): parsed = r.DNSIncoming(packet) assert parsed.valid is False + # Invalid Packet mock_out = unittest.mock.Mock() mock_out.packets = lambda: [packet] zc.send(mock_out) + + # Invalid oversized packet + mock_out = unittest.mock.Mock() + mock_out.packets = lambda: [packet * 1000] + zc.send(mock_out) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) entry = r.DNSText( "didnotcrashincoming._crash._tcp.local.", From ba2a4f960d0f9478198968a1466a8b48c963b772 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 16:22:52 -1000 Subject: [PATCH 336/608] Make calculation of times in DNSRecord lazy (#676) - Most of the time we only check one of the time attrs or none at all. Wait to calculate them until they are requested. --- zeroconf/_dns.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 41daee4e..073b95f7 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -146,9 +146,9 @@ def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) - super().__init__(name, type_, class_) self.ttl = ttl self.created = current_time_millis() - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) + self._expiration_time: Optional[float] = None + self._stale_time: Optional[float] = None + self._recent_time: Optional[float] = None def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" @@ -157,10 +157,7 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use def suppressed_by(self, msg: 'DNSIncoming') -> bool: """Returns true if any answer in a message can suffice for the information held in this record.""" - for record in msg.answers: - if self.suppressed_by_answer(record): - return True - return False + return any(self.suppressed_by_answer(record) for record in msg.answers) def suppressed_by_answer(self, other: 'DNSRecord') -> bool: """Returns true if another record has same name, type and class, @@ -175,18 +172,26 @@ def get_expiration_time(self, percent: int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" + if self._expiration_time is None: + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) return max(0, millis_to_seconds(self._expiration_time - now)) def is_expired(self, now: float) -> bool: """Returns true if this record has expired.""" + if self._expiration_time is None: + self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) return self._expiration_time <= now def is_stale(self, now: float) -> bool: """Returns true if this record is at least half way expired.""" + if self._stale_time is None: + self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) return self._stale_time <= now def is_recent(self, now: float) -> bool: """Returns true if the record more than one quarter of its TTL remaining.""" + if self._recent_time is None: + self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) return self._recent_time > now def reset_ttl(self, other: 'DNSRecord') -> None: @@ -194,9 +199,9 @@ def reset_ttl(self, other: 'DNSRecord') -> None: another record.""" self.created = other.created self.ttl = other.ttl - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) + self._expiration_time = None + self._stale_time = None + self._recent_time = None def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use """Abstract method""" From 57c94bb25e056e1827f15c234d7e0bcb5702a0e3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 16:26:39 -1000 Subject: [PATCH 337/608] Remove unreachable BadTypeInNameException check in _ServiceBrowser (#677) --- zeroconf/_services/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 857197bd..19fdccaf 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -168,8 +168,8 @@ def __init__( assert handlers or listener, 'You need to specify at least one handler' self.types = set(type_ if isinstance(type_, list) else [type_]) # type: Set[str] for check_type_ in self.types: - if not check_type_.endswith(service_type_name(check_type_, strict=False)): - raise BadTypeInNameException + # Will generate BadTypeInNameException on a bad name + service_type_name(check_type_, strict=False) self.zc = zc self.addr = addr self.port = port From d3d439ad5d475cff094a4ea83f19d17939527021 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 16:35:57 -1000 Subject: [PATCH 338/608] Allow unregistering a service multiple times (#679) --- tests/services/test_registry.py | 23 +++++++++++++++++++++++ zeroconf/_services/registry.py | 2 ++ 2 files changed, 25 insertions(+) diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 52726a04..496cc629 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -28,6 +28,29 @@ def test_only_register_once(self): registry.remove(info) registry.add(info) + def test_unregister_multiple_times(self): + """Verify we can unregister a service multiple times. + + In production unregister_service and unregister_all_services + may happen at the same time during shutdown. We want to treat + this as non-fatal since its expected to happen and it is unlikely + that the callers know about each other. + """ + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + registry = r.ServiceRegistry() + registry.add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) + registry.remove(info) + registry.remove(info) + def test_lookups(self): type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 244ff294..8ec34120 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -106,6 +106,8 @@ def _add(self, info: ServiceInfo) -> None: def _remove(self, infos: List[ServiceInfo]) -> None: """Remove a services under the lock.""" for info in infos: + if info.key not in self._services: + continue old_service_info = self._services[info.key] self.types[old_service_info.type.lower()].remove(info.key) self.servers[old_service_info.server_key].remove(info.key) From 691c29eeb049e17a12d6f0a6e3bce2c3f8c2aa02 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 16:36:20 -1000 Subject: [PATCH 339/608] Add DNSRRSet class for quick hashtable lookups of records (#678) - This class will be used to do fast checks to see if records should be suppressed by a set of answers. --- tests/test_dns.py | 98 ++++++++++++++++++----------------------------- zeroconf/_dns.py | 20 +++++++++- 2 files changed, 57 insertions(+), 61 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index 664fccde..eab2f2a3 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -15,6 +15,7 @@ import zeroconf as r from zeroconf import DNSIncoming, const, current_time_millis +from zeroconf._dns import DNSRRSet from zeroconf import ( DNSHinfo, DNSText, @@ -900,6 +901,8 @@ def test_dns_record_hashablity_does_not_consider_ttl(): assert len(record_set) == 1 record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') + assert record2 == record3_dupe + assert record2.__hash__() == record3_dupe.__hash__() record_set.add(record3_dupe) assert len(record_set) == 1 @@ -941,6 +944,8 @@ def test_dns_hinfo_record_hashablity(): assert len(record_set) == 2 hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') + assert hinfo2 == hinfo2_dupe + assert hinfo2.__hash__() == hinfo2_dupe.__hash__() record_set.add(hinfo2_dupe) assert len(record_set) == 2 @@ -958,71 +963,13 @@ def test_dns_pointer_record_hashablity(): assert len(record_set) == 2 ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') + assert ptr2 == ptr2 + assert ptr2.__hash__() == ptr2_dupe.__hash__() record_set.add(ptr2_dupe) assert len(record_set) == 2 -def test_dns_text_record_hashablity(): - """Test DNSText are hashable.""" - text1 = r.DNSText('irrelevant', 0, 0, 0, b'12345678901') - text2 = r.DNSText('irrelevant', 1, 0, 0, b'12345678901') - text3 = r.DNSText('irrelevant', 0, 1, 0, b'12345678901') - text4 = r.DNSText('irrelevant', 0, 0, 1, b'12345678901') - text5 = r.DNSText('irrelevant', 0, 0, 0, b'ABCDEFGHIJK') - - record_set = set([text1, text2, text3, text4, text5]) - assert len(record_set) == 5 - - record_set.add(text1) - assert len(record_set) == 5 - - text1_dupe = r.DNSText('irrelevant', 0, 0, 0, b'12345678901') - - record_set.add(text1_dupe) - assert len(record_set) == 5 - - -def test_dns_text_record_hashablity(): - """Test DNSText are hashable.""" - text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') - text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') - text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') - text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') - - record_set = set([text1, text2, text3, text4]) - - assert len(record_set) == 4 - - record_set.add(text1) - assert len(record_set) == 4 - - text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') - - record_set.add(text1_dupe) - assert len(record_set) == 4 - - -def test_dns_text_record_hashablity(): - """Test DNSText are hashable.""" - text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') - text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901') - text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') - text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') - - record_set = set([text1, text2, text3, text4]) - - assert len(record_set) == 4 - - record_set.add(text1) - assert len(record_set) == 4 - - text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') - - record_set.add(text1_dupe) - assert len(record_set) == 4 - - def test_dns_text_record_hashablity(): """Test DNSText are hashable.""" text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') @@ -1038,6 +985,8 @@ def test_dns_text_record_hashablity(): assert len(record_set) == 4 text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901') + assert text1 == text1_dupe + assert text1.__hash__() == text1_dupe.__hash__() record_set.add(text1_dupe) assert len(record_set) == 4 @@ -1060,6 +1009,35 @@ def test_dns_service_record_hashablity(): srv1_dupe = r.DNSService( 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a' ) + assert srv1 == srv1_dupe + assert srv1.__hash__() == srv1_dupe.__hash__() record_set.add(srv1_dupe) assert len(record_set) == 4 + + +def test_rrset_does_not_consider_ttl(): + """Test DNSRRSet does not consider the ttl in the hash.""" + + longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same') + shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same') + longaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 100, b'same') + shortaaaarec = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 10, b'same') + + rrset = DNSRRSet([longarec, shortaaaarec]) + + assert rrset.suppresses(longarec) + assert rrset.suppresses(shortarec) + assert not rrset.suppresses(longaaaarec) + assert rrset.suppresses(shortaaaarec) + + verylongarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1000, b'same') + longarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 100, b'same') + mediumarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 60, b'same') + shortarec = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 10, b'same') + + rrset2 = DNSRRSet([mediumarec]) + assert not rrset2.suppresses(verylongarec) + assert rrset2.suppresses(longarec) + assert rrset2.suppresses(mediumarec) + assert rrset2.suppresses(shortarec) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 073b95f7..ed4390c9 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -23,7 +23,7 @@ import enum import socket import struct -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast from ._exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log @@ -970,3 +970,21 @@ def packets(self) -> List[bytes]: break self.state = self.State.finished return self.packets_data + + +class DNSRRSet: + """A set of dns records independent of the ttl.""" + + def __init__(self, records: Iterable[DNSRecord]) -> None: + """Create an RRset from records.""" + self._records = records + self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None + + def suppresses(self, record: DNSRecord) -> bool: + """Returns true if any answer in the rrset can suffice for the + information held in this record.""" + if self._lookup is None: + # Build the hash table so we can lookup the record independent of the ttl + self._lookup = {record: record for record in self._records} + other = self._lookup.get(record) + return bool(other and other.ttl > (record.ttl / 2)) From e5ea9bb6c0a3bce7d05241f275a205ddd9e6b615 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 16:45:33 -1000 Subject: [PATCH 340/608] Use DNSRRSet for known answer suppression (#680) - DNSRRSet uses hash table lookups under the hood which is much faster than the linear searches used by DNSRecord.suppressed_by --- zeroconf/_handlers.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 15a853b2..f8590e86 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -25,7 +25,7 @@ from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union from ._cache import DNSCache -from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord +from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log from ._services import RecordUpdateListener from ._services.registry import ServiceRegistry @@ -178,7 +178,7 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: def _answer_service_type_enumeration_query( self, - msg: DNSIncoming, + answers_rrset: DNSRRSet, ) -> Set[DNSRecord]: """Provide an answer to a service type enumeration query. @@ -188,68 +188,70 @@ def _answer_service_type_enumeration_query( DNSPointer(_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype) for stype in self.registry.get_types() ) - records -= set(dns_pointer for dns_pointer in records if dns_pointer.suppressed_by(msg)) + records -= set(dns_pointer for dns_pointer in records if answers_rrset.suppresses(dns_pointer)) return records def _add_pointer_answers( - self, name: str, msg: DNSIncoming, answers: Set[DNSRecord], additionals: Set[DNSRecord] + self, name: str, answers_rrset: DNSRRSet, answers: Set[DNSRecord], additionals: Set[DNSRecord] ) -> None: """Answer PTR/ANY question.""" for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer() - if not dns_pointer.suppressed_by(msg): - answers.add(service.dns_pointer()) + if not answers_rrset.suppresses(dns_pointer): + answers.add(dns_pointer) additionals.add(service.dns_service()) additionals.add(service.dns_text()) additionals.update(service.dns_addresses()) - def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord], type_: int) -> None: + def _add_address_answers( + self, name: str, answers_rrset: DNSRRSet, answers: Set[DNSRecord], type_: int + ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): - if not dns_address.suppressed_by(msg): + if not answers_rrset.suppresses(dns_address): answers.add(dns_address) def _answer_question( - self, msg: DNSIncoming, question: DNSQuestion + self, answers_rrset: DNSRRSet, question: DNSQuestion ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: answers: Set[DNSRecord] = set() additionals: Set[DNSRecord] = set() type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, msg, answers, additionals) + self._add_pointer_answers(question.name, answers_rrset, answers, additionals) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, msg, answers, type_) + self._add_address_answers(question.name, answers_rrset, answers, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore if service is not None: if type_ in (_TYPE_SRV, _TYPE_ANY): dns_service = service.dns_service() - if not dns_service.suppressed_by(msg): + if not answers_rrset.suppresses(dns_service): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.2. answers.add(service.dns_service()) additionals.update(service.dns_addresses()) if type_ in (_TYPE_TXT, _TYPE_ANY): dns_text = service.dns_text() - if not dns_text.suppressed_by(msg): + if not answers_rrset.suppresses(dns_text): answers.add(service.dns_text()) return answers, additionals def _answer_any_question( - self, msg: DNSIncoming, question: DNSQuestion + self, answers_rrset: DNSRRSet, question: DNSQuestion ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: empty_additionals: Set[DNSRecord] = set() - return self._answer_service_type_enumeration_query(msg), empty_additionals + return self._answer_service_type_enumeration_query(answers_rrset), empty_additionals - return self._answer_question(msg, question) + return self._answer_question(answers_rrset, question) def response( # pylint: disable=unused-argument self, msg: DNSIncoming, addr: Optional[str], port: int @@ -257,9 +259,10 @@ def response( # pylint: disable=unused-argument """Deal with incoming query packets. Provides a response if possible.""" ucast_source = port != _MDNS_PORT query_res = _QueryResponse(self.cache, msg, ucast_source) + answers_rrset = DNSRRSet(msg.answers) for question in msg.questions: - all_answers = self._answer_any_question(msg, question) + all_answers = self._answer_any_question(answers_rrset, question) if not ucast_source and question.unicast: query_res.add_qu_question_response(*all_answers) else: From d2b5e51d0dcde801e171a4c1e43ef1f86abde825 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 17:12:06 -1000 Subject: [PATCH 341/608] Check if SO_REUSEPORT exists instead of using an exception catch (#682) --- zeroconf/_utils/net.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 963faf55..8d1b60bc 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -196,13 +196,9 @@ def new_socket( # pylint: disable=too-many-branches # versions of Python have SO_REUSEPORT available. # Catch OSError and socket.error for kernel versions <3.9 because lacking # SO_REUSEPORT support. - try: - reuseport = socket.SO_REUSEPORT - except AttributeError: - pass - else: + if hasattr(socket, 'SO_REUSEPORT'): try: - s.setsockopt(socket.SOL_SOCKET, reuseport, 1) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member except OSError as err: if err.errno != errno.ENOPROTOOPT: raise From 00b972c062fd0ed3f2fcc4ceaec84c43b9a613be Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 17:12:26 -1000 Subject: [PATCH 342/608] Fix logic reversal in apple_p2p test (#681) --- tests/test_core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 1e6d0d92..1a48c5ea 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -100,16 +100,16 @@ def test_launch_and_close_v6_only(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.Default, ip_version=r.IPVersion.V6Only) rv.close() - @unittest.skipIf(sys.platform != 'darwin', reason="apple_p2p only available on mac") - def test_launch_and_close_apple_p2p(self): - rv = r.Zeroconf(apple_p2p=True) - rv.close() - - @unittest.skipIf(sys.platform == 'darwin', reason="apple_p2p available on mac") - def test_launch_and_close_apple_p2p(self): + @unittest.skipIf(sys.platform == 'darwin', reason="apple_p2p failure path not testable on mac") + def test_launch_and_close_apple_p2p_not_mac(self): with pytest.raises(RuntimeError): r.Zeroconf(apple_p2p=True) + @unittest.skipIf(sys.platform != 'darwin', reason="apple_p2p happy path only testable on mac") + def test_launch_and_close_apple_p2p_on_mac(self): + rv = r.Zeroconf(apple_p2p=True) + rv.close() + def test_handle_response(self): def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: ttl = 120 From 95ddb36de64ddf3be9e93f07a1daa8389410f73d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 17:19:24 -1000 Subject: [PATCH 343/608] Add coverage to verify ServiceInfo tolerates bytes or string in the txt record (#683) --- tests/test_services.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_services.py b/tests/test_services.py index 45a5b91e..521efd1f 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1332,3 +1332,48 @@ def test_serviceinfo_address_updates(): ) info_service.addresses = [socket.inet_aton("10.0.1.3")] assert info_service.addresses == [socket.inet_aton("10.0.1.3")] + + +def test_serviceinfo_accepts_bytes_or_string_dict(): + """Verify a bytes or string dict can be passed to ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + addresses = [socket.inet_aton("10.0.1.2")] + server_name = "ash-2.local." + info_service = ServiceInfo( + type_, '%s.%s' % (name, type_), 80, 0, 0, {b'path': b'/~paulsm/'}, server_name, addresses=addresses + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {b'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': b'/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' From 6fd1bf2364da4fc2949a905d2e4acb7da003e84d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 17:39:48 -1000 Subject: [PATCH 344/608] Update changelog (#684) --- README.rst | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/README.rst b/README.rst index 07a9b2d6..06a6c64e 100644 --- a/README.rst +++ b/README.rst @@ -219,6 +219,49 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Check if SO_REUSEPORT exists instead of using an exception catch (#682) @bdraco + +* Use DNSRRSet for known answer suppression (#680) @bdraco + + DNSRRSet uses hash table lookups under the hood which + is much faster than the linear searches used by + DNSRecord.suppressed_by + +* Add DNSRRSet class for quick hashtable lookups of records (#678) @bdraco + + This class will be used to do fast checks to see + if records should be suppressed by a set of answers. + +* Allow unregistering a service multiple times (#679) @bdraco + +* Remove unreachable BadTypeInNameException check in _ServiceBrowser (#677) @bdraco + +* Update async_browser.py example to use AsyncZeroconfServiceTypes (#665) @bdraco + +* Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.aio (#658) @bdraco + +* Remove all calls to the executor in AsyncZeroconf (#653) @bdraco + +* Set __all__ in zeroconf.aio to ensure private functions do now show in the docs (#652) @bdraco + +* Ensure interface_index_to_ip6_address skips ipv4 adapters (#651) @bdraco + +* Add async_unregister_all_services to AsyncZeroconf (#649) @bdraco + +* Ensure services are removed from the registry when calling unregister_all_services (#644) @bdraco + + There was a race condition where a query could be answered for a service + in the registry while goodbye packets which could result a fresh record + being broadcast after the goodbye if a query came in at just the right + time. To avoid this, we now remove the services from the registry right + after we generate the goodbye packet + +* Use ServiceInfo.key/ServiceInfo.server_key instead of lowering in ServiceRegistry (#647) @bdraco + +* Ensure the ServiceInfo.key gets updated when the name is changed externally (#645) @bdraco + +* Ensure AsyncZeroconf.async_close can be called multiple times like Zeroconf.close (#638) @bdraco + * Ensure eventloop shutdown is threadsafe (#636) @bdraco * Return early in the shutdown/close process (#632) @bdraco From e816053af4d900f57100c07c48f384165ba28b9a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 19:17:54 -1000 Subject: [PATCH 345/608] Add truncated property to DNSMessage to lookup the TC bit (#686) --- tests/test_dns.py | 24 ++++++++++++------------ zeroconf/_dns.py | 6 ++++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index eab2f2a3..4096aa94 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -378,15 +378,15 @@ def test_many_questions_with_many_known_answers(self): parsed1 = r.DNSIncoming(packets[0]) assert len(parsed1.questions) == 30 assert len(parsed1.answers) == 88 - assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC + assert parsed1.truncated parsed2 = r.DNSIncoming(packets[1]) assert len(parsed2.questions) == 0 assert len(parsed2.answers) == 101 - assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC + assert parsed2.truncated parsed3 = r.DNSIncoming(packets[2]) assert len(parsed3.questions) == 0 assert len(parsed3.answers) == 11 - assert parsed3.flags & const._FLAGS_TC == 0 + assert not parsed3.truncated def test_massive_probe_packet_split(self): """Test probe with many authorative answers.""" @@ -419,15 +419,15 @@ def test_massive_probe_packet_split(self): assert parsed1.questions[0].unicast is True assert len(parsed1.questions) == 30 assert parsed1.num_authorities == 88 - assert parsed1.flags & const._FLAGS_TC == const._FLAGS_TC + assert parsed1.truncated parsed2 = r.DNSIncoming(packets[1]) assert len(parsed2.questions) == 0 assert parsed2.num_authorities == 101 - assert parsed2.flags & const._FLAGS_TC == const._FLAGS_TC + assert parsed2.truncated parsed3 = r.DNSIncoming(packets[2]) assert len(parsed3.questions) == 0 assert parsed3.num_authorities == 11 - assert parsed3.flags & const._FLAGS_TC == 0 + assert not parsed3.truncated def test_only_one_answer_can_by_large(self): """Test that only the first answer in each packet can be large. @@ -823,15 +823,15 @@ def test_tc_bit_in_query_packet(): assert len(packets) == 3 first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & const._FLAGS_TC == const._FLAGS_TC + assert first_packet.truncated assert first_packet.valid is True second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & const._FLAGS_TC == const._FLAGS_TC + assert second_packet.truncated assert second_packet.valid is True third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & const._FLAGS_TC == 0 + assert not third_packet.truncated assert third_packet.valid is True @@ -855,15 +855,15 @@ def test_tc_bit_not_set_in_answer_packet(): assert len(packets) == 3 first_packet = r.DNSIncoming(packets[0]) - assert first_packet.flags & const._FLAGS_TC == 0 + assert not first_packet.truncated assert first_packet.valid is True second_packet = r.DNSIncoming(packets[1]) - assert second_packet.flags & const._FLAGS_TC == 0 + assert not second_packet.truncated assert second_packet.valid is True third_packet = r.DNSIncoming(packets[2]) - assert third_packet.flags & const._FLAGS_TC == 0 + assert not third_packet.truncated assert third_packet.valid is True diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index ed4390c9..e4b1080f 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -395,6 +395,11 @@ def is_response(self) -> bool: """Returns true if this is a response.""" return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + @property + def truncated(self) -> bool: + """Returns true if this is a truncated.""" + return (self.flags & _FLAGS_TC) == _FLAGS_TC + class DNSIncoming(DNSMessage, QuietLogger): @@ -428,6 +433,7 @@ def __repr__(self) -> str: [ 'id=%s' % self.id, 'flags=%s' % self.flags, + 'truncated=%s' % self.truncated, 'n_q=%s' % self.num_questions, 'n_ans=%s' % self.num_answers, 'n_auth=%s' % self.num_authorities, From 4865d2ba782d0313c0f7d878f5887453086febaa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 15 Jun 2021 23:47:55 -1000 Subject: [PATCH 346/608] Remove sleeps from services types test (#688) - Instead of registering the services and doing the broadcast we now put them in the registry directly. --- tests/services/test_types.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 8cff4431..6f6645db 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -35,10 +35,7 @@ def test_integration_with_listener(self): "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - zeroconf_registrar.register_service(info) - # Ensure we do not clear the cache until after the last broadcast is processed - time.sleep(0.2) - _clear_cache(zeroconf_registrar) + zeroconf_registrar.registry.add(info) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types @@ -70,10 +67,7 @@ def test_integration_with_listener_v6_records(self): "ash-2.local.", addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) - zeroconf_registrar.register_service(info) - # Ensure we do not clear the cache until after the last broadcast is processed - time.sleep(0.2) - _clear_cache(zeroconf_registrar) + zeroconf_registrar.registry.add(info) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert type_ in service_types @@ -104,10 +98,7 @@ def test_integration_with_listener_ipv6(self): "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - zeroconf_registrar.register_service(info) - # Ensure we do not clear the cache until after the last broadcast is processed - time.sleep(0.2) - _clear_cache(zeroconf_registrar) + zeroconf_registrar.registry.add(info) try: service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) assert type_ in service_types @@ -138,10 +129,7 @@ def test_integration_with_subtype_and_listener(self): "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - zeroconf_registrar.register_service(info) - # Ensure we do not clear the cache until after the last broadcast is processed - time.sleep(0.2) - _clear_cache(zeroconf_registrar) + zeroconf_registrar.registry.add(info) try: service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) assert discovery_type in service_types From 8a25a44ec5e4f21c6bdb282fefb8f6c2d296a70b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 00:03:55 -1000 Subject: [PATCH 347/608] Implement multi-packet known answer supression (#687) - Implements https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - Fixes https://github.com/jstasiak/python-zeroconf/issues/499 --- tests/test_core.py | 184 ++++++++++++++++++++++++++++++++++++++++- tests/test_handlers.py | 144 +++++++++++++++++++++++--------- zeroconf/_core.py | 38 ++++++++- zeroconf/_handlers.py | 8 +- 4 files changed, 324 insertions(+), 50 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 1a48c5ea..8f459da1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._core """ +import asyncio import itertools import logging import os @@ -18,7 +19,7 @@ import zeroconf as r from zeroconf import _core, const, ServiceBrowser, Zeroconf -from . import has_working_ipv6, _inject_response +from . import has_working_ipv6, _clear_cache, _inject_response log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -423,3 +424,184 @@ def test_sending_unicast(): assert zc.cache.get(entry) is not None zc.close() + + +def test_tc_bit_defers(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_tcbitdefer._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = "%s.%s" % (name, type_) + registration2_name = "%s.%s" % (name2, type_) + registration3_name = "%s.%s" % (name3, type_) + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = r.ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = r.ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info) + zc.registry.add(info2) + zc.registry.add(info3) + + def threadsafe_query(*args): + async def make_query(): + zc.handle_query(*args) + + asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() + + now = r.current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(300): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) == 4 + expected_deferred = [] + source_ip = '203.0.113.13' + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert source_ip not in zc._deferred + assert source_ip not in zc._timers + + # unregister + zc.unregister_service(info) + zc.close() + + +def test_tc_bit_defers_last_response_missing(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_knowndefer._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = "%s.%s" % (name, type_) + registration2_name = "%s.%s" % (name2, type_) + registration3_name = "%s.%s" % (name3, type_) + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = r.ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = r.ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = r.ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info) + zc.registry.add(info2) + zc.registry.add(info3) + + def threadsafe_query(*args): + async def make_query(): + zc.handle_query(*args) + + asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() + + now = r.current_time_millis() + _clear_cache(zc) + source_ip = '203.0.113.12' + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(300): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) == 4 + expected_deferred = [] + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + timer1 = zc._timers[source_ip] + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + timer2 = zc._timers[source_ip] + if sys.version_info >= (3, 7): + assert timer1.cancelled() + assert timer2 != timer1 + + # Send the same packet again to similar multi interfaces + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + timer3 = zc._timers[source_ip] + if sys.version_info >= (3, 7): + assert not timer3.cancelled() + assert timer3 == timer2 + + next_packet = r.DNSIncoming(packets.pop(0)) + expected_deferred.append(next_packet) + threadsafe_query(next_packet, source_ip, const._MDNS_PORT) + assert zc._deferred[source_ip] == expected_deferred + assert source_ip in zc._timers + timer4 = zc._timers[source_ip] + if sys.version_info >= (3, 7): + assert timer3.cancelled() + assert timer4 != timer3 + + for _ in range(7): + time.sleep(0.1) + if source_ip not in zc._timers: + break + + assert source_ip not in zc._deferred + assert source_ip not in zc._timers + + # unregister + zc.registry.remove(info) + zc.close() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 1e8109eb..fcf094c5 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -96,9 +96,9 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - multicast_out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[ - 1 - ] + multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT + )[1] _process_outgoing_packet(multicast_out) # The additonals should all be suppresed since they are all in the answers section @@ -134,7 +134,9 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) _process_outgoing_packet( - zc.query_handler.response(r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT)[1] + zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT + )[1] ) assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -231,7 +233,7 @@ def test_ptr_optimization(): query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT ) assert unicast_out is None assert multicast_out is None @@ -243,7 +245,7 @@ def test_ptr_optimization(): query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT ) assert multicast_out.id == query.id assert unicast_out is None @@ -271,49 +273,52 @@ def test_ptr_optimization(): def test_any_query_for_ptr(): """Test that queries for ANY will return PTR records.""" zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_knownservice._tcp.local." + type_ = "_anyptr._tcp.local." name = "knownname" registration_name = "%s.%s" % (name, type_) desc = {'path': '/~paulsm/'} server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) - zc.register_service(info) + zc.registry.add(info) _clear_cache(zc) generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT) + _, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) assert multicast_out.answers[0][0].name == type_ assert multicast_out.answers[0][0].alias == registration_name # unregister - zc.unregister_service(info) + zc.registry.remove(info) zc.close() def test_aaaa_query(): """Test that queries for AAAA records work.""" zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_knownservice._tcp.local." + type_ = "_knownaaaservice._tcp.local." name = "knownname" registration_name = "%s.%s" % (name, type_) desc = {'path': '/~paulsm/'} server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) - zc.register_service(info) + zc.registry.add(info) - _clear_cache(zc) generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT) + _, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) assert multicast_out.answers[0][0].address == ipv6_address # unregister - zc.unregister_service(info) + zc.registry.remove(info) zc.close() @@ -331,13 +336,15 @@ def test_unicast_response(): type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) # register - zc.register_service(info) + zc.registry.add(info) _clear_cache(zc) # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.response(r.DNSIncoming(query.packets()[0]), "1.2.3.4", 1234) + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", 1234 + ) for out in (unicast_out, multicast_out): assert out.id == query.id has_srv = has_txt = has_a = False @@ -356,7 +363,7 @@ def test_unicast_response(): assert has_srv and has_txt and has_a # unregister - zc.unregister_service(info) + zc.registry.remove(info) zc.close() @@ -413,7 +420,7 @@ def _validate_complete_response(query, out): query.add_question(question) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None _validate_complete_response(query, unicast_out) @@ -426,7 +433,7 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None _validate_complete_response(query, multicast_out) @@ -439,7 +446,7 @@ def _validate_complete_response(query, out): query.add_question(question) query.add_authorative_answer(info2.dns_pointer()) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) _validate_complete_response(query, unicast_out) _validate_complete_response(query, multicast_out) @@ -452,7 +459,7 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(query.packets()[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None _validate_complete_response(query, unicast_out) @@ -463,7 +470,7 @@ def _validate_complete_response(query, out): def test_known_answer_supression(): zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_knownservice._tcp.local." + type_ = "_knownanswersv8._tcp.local." name = "knownname" registration_name = "%s.%s" % (name, type_) desc = {'path': '/~paulsm/'} @@ -471,7 +478,7 @@ def test_known_answer_supression(): info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.register_service(info) + zc.registry.add(info) now = current_time_millis() _clear_cache(zc) @@ -481,7 +488,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert multicast_out is not None and multicast_out.answers @@ -492,7 +499,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_pointer(), now) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None # If the answer is suppressed, the additional should be suppresed as well @@ -504,7 +511,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert multicast_out is not None and multicast_out.answers @@ -516,7 +523,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(dns_address, now) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert not multicast_out or not multicast_out.answers @@ -527,7 +534,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert multicast_out is not None and multicast_out.answers @@ -538,7 +545,7 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_service(), now) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None # If the answer is suppressed, the additional should be suppresed as well @@ -550,7 +557,7 @@ def test_known_answer_supression(): generated.add_question(question) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert multicast_out is not None and multicast_out.answers @@ -561,19 +568,73 @@ def test_known_answer_supression(): generated.add_answer_at_time(info.dns_text(), now) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert not multicast_out or not multicast_out.answers # unregister - zc.unregister_service(info) + zc.registry.remove(info) + zc.close() + + +def test_multi_packet_known_answer_supression(): + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_handlermultis._tcp.local." + name = "knownname" + name2 = "knownname2" + name3 = "knownname3" + + registration_name = "%s.%s" % (name, type_) + registration2_name = "%s.%s" % (name2, type_) + registration3_name = "%s.%s" % (name3, type_) + + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + server_name2 = "ash-3.local." + server_name3 = "ash-4.local." + + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + info3 = ServiceInfo( + type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info) + zc.registry.add(info2) + zc.registry.add(info3) + + now = current_time_millis() + _clear_cache(zc) + # Test PTR supression + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + for _ in range(1000): + # Add so many answers we end up with another packet + generated.add_answer_at_time(info.dns_pointer(), now) + generated.add_answer_at_time(info2.dns_pointer(), now) + generated.add_answer_at_time(info3.dns_pointer(), now) + packets = generated.packets() + assert len(packets) > 1 + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is None + # unregister + zc.registry.remove(info) + zc.registry.remove(info2) + zc.registry.remove(info3) zc.close() def test_known_answer_supression_service_type_enumeration_query(): zc = Zeroconf(interfaces=['127.0.0.1']) - type_ = "_knownservice._tcp.local." + type_ = "_otherknown._tcp.local." name = "knownname" registration_name = "%s.%s" % (name, type_) desc = {'path': '/~paulsm/'} @@ -581,17 +642,17 @@ def test_known_answer_supression_service_type_enumeration_query(): info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.register_service(info) + zc.registry.add(info) - type_2 = "_knownservice2._tcp.local." + type_2 = "_otherknown2._tcp.local." name = "knownname" registration_name2 = "%s.%s" % (name, type_2) desc = {'path': '/~paulsm/'} server_name2 = "ash-3.local." - info = ServiceInfo( + info2 = ServiceInfo( type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.register_service(info) + zc.registry.add(info2) now = current_time_millis() _clear_cache(zc) @@ -601,7 +662,7 @@ def test_known_answer_supression_service_type_enumeration_query(): generated.add_question(question) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert multicast_out is not None and multicast_out.answers @@ -631,11 +692,12 @@ def test_known_answer_supression_service_type_enumeration_query(): ) packets = generated.packets() unicast_out, multicast_out = zc.query_handler.response( - r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None assert not multicast_out or not multicast_out.answers # unregister - zc.unregister_service(info) + zc.registry.remove(info) + zc.registry.remove(info2) zc.close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 21bc9a6a..6aef35a3 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -24,6 +24,7 @@ import contextlib import errno import itertools +import random import socket import sys import threading @@ -71,6 +72,8 @@ _UNREGISTER_TIME, ) +_TC_DELAY_RANDOM_INTERVAL = (400, 500) + class NotifyListener: """Receive notifications Zeroconf.notify_all is called.""" @@ -302,6 +305,9 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None + self._deferred: Dict[str, List[DNSIncoming]] = {} + self._timers: Dict[str, asyncio.TimerHandle] = {} + self.start() def start(self) -> None: @@ -557,13 +563,37 @@ def handle_response(self, msg: DNSIncoming) -> None: are held in the cache, and listeners are notified.""" self.record_manager.updates_from_response(msg) - def handle_query(self, msg: DNSIncoming, addr: Optional[str], port: int) -> None: + def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: """Deal with incoming query packets. Provides a response if possible.""" - unicast_out, multicast_out = self.query_handler.response(msg, addr, port) - if unicast_out and unicast_out.answers: + if not msg.truncated: + self._respond_query(msg, addr, port) + return + + deferred = self._deferred.setdefault(addr, []) + # If we get the same packet on another iterface we ignore it + for incoming in reversed(deferred): + if incoming.data == msg.data: + return + deferred.append(msg) + delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) + assert self.loop is not None + if addr in self._timers: + self._timers.pop(addr).cancel() + self._timers[addr] = self.loop.call_later(delay, self._respond_query, None, addr, port) + + def _respond_query(self, msg: Optional[DNSIncoming], addr: str, port: int) -> None: + """Respond to a query and reassemble any truncated deferred packets.""" + if addr in self._timers: + self._timers.pop(addr).cancel() + packets = self._deferred.pop(addr, []) + if msg: + packets.append(msg) + + unicast_out, multicast_out = self.query_handler.response(packets, addr, port) + if unicast_out: self.async_send(unicast_out, addr, port) - if multicast_out and multicast_out.answers: + if multicast_out: self.async_send(multicast_out, None, _MDNS_PORT) def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f8590e86..a7e18360 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -254,14 +254,14 @@ def _answer_any_question( return self._answer_question(answers_rrset, question) def response( # pylint: disable=unused-argument - self, msg: DNSIncoming, addr: Optional[str], port: int + self, msgs: List[DNSIncoming], addr: Optional[str], port: int ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: """Deal with incoming query packets. Provides a response if possible.""" ucast_source = port != _MDNS_PORT - query_res = _QueryResponse(self.cache, msg, ucast_source) - answers_rrset = DNSRRSet(msg.answers) + query_res = _QueryResponse(self.cache, msgs[0], ucast_source) + answers_rrset = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) - for question in msg.questions: + for question in itertools.chain(*[msg.questions for msg in msgs]): all_answers = self._answer_any_question(answers_rrset, question) if not ucast_source and question.unicast: query_res.add_qu_question_response(*all_answers) From b60f307d59e342983d1baa6040c3d997f84538ab Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 11:35:01 -1000 Subject: [PATCH 348/608] Remove AA flags from handlers test (#693) - The flag was added by mistake when copying from other tests --- tests/test_handlers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index fcf094c5..8820b185 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -230,7 +230,7 @@ def test_ptr_optimization(): zc.register_service(info) # Verify we won't respond for 1s with the same multicast - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT @@ -242,7 +242,7 @@ def test_ptr_optimization(): _clear_cache(zc) # Verify we will now respond - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT @@ -340,7 +340,7 @@ def test_unicast_response(): _clear_cache(zc) # query - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) unicast_out, multicast_out = zc.query_handler.response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", 1234 @@ -413,7 +413,7 @@ def _validate_complete_response(query, out): assert has_srv and has_txt and has_a # With QU should respond to only unicast when the answer has been recently multicast - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) question.unique = True # Set the QU bit assert question.unicast is True @@ -427,7 +427,7 @@ def _validate_complete_response(query, out): _clear_cache(zc) # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) question.unique = True # Set the QU bit assert question.unicast is True @@ -439,7 +439,7 @@ def _validate_complete_response(query, out): _validate_complete_response(query, multicast_out) # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) question.unique = True # Set the QU bit assert question.unicast is True @@ -453,7 +453,7 @@ def _validate_complete_response(query, out): _inject_response(zc, r.DNSIncoming(multicast_out.packets()[0])) # With the cache repopulated; should respond to only unicast when the answer has been recently multicast - query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) question.unique = True # Set the QU bit assert question.unicast is True From 993a82e414db8aadaee0e0475e178e75df417a71 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 11:35:11 -1000 Subject: [PATCH 349/608] Move setting DNS created and ttl into its own function (#692) --- zeroconf/_dns.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index e4b1080f..b55f77f0 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -197,8 +197,12 @@ def is_recent(self, now: float) -> bool: def reset_ttl(self, other: 'DNSRecord') -> None: """Sets this record's TTL and created time to that of another record.""" - self.created = other.created - self.ttl = other.ttl + self._set_created_ttl(other.created, other.ttl) + + def _set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: + """Set the created and ttl of a record.""" + self.created = created + self.ttl = ttl self._expiration_time = None self._stale_time = None self._recent_time = None From 0cdba98e65dd3dce2db8aa607e97e3b67b97721a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 12:49:12 -1000 Subject: [PATCH 350/608] Suppress additionals when answer is suppressed (#690) --- tests/test_handlers.py | 128 +++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 167 +++++++++++++++++------------------------ 2 files changed, 198 insertions(+), 97 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 8820b185..11c077b4 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -701,3 +701,131 @@ def test_known_answer_supression_service_type_enumeration_query(): zc.registry.remove(info) zc.registry.remove(info2) zc.close() + + +def test_qu_response_only_sends_additionals_if_sends_answer(): + """Test that a QU response does not send additionals unless it sends the answer as well.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + type_ = "_addtest1._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info) + + type_2 = "_addtest2._tcp.local." + name = "knownname" + registration_name2 = "%s.%s" % (name, type_2) + desc = {'path': '/~paulsm/'} + server_name2 = "ash-3.local." + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.add(info2) + + ptr_record = info.dns_pointer() + + # Add the PTR record to the cache + zc.cache.add(ptr_record) + + # Add the A record to the cache with 50% ttl remaining + a_record = info.dns_addresses()[0] + a_record._set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl) + assert not a_record.is_recent(current_time_millis()) + zc.cache.add(a_record) + + # With QU should respond to only unicast when the answer has been recently multicast + # even if the additional has not been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out is None + assert a_record in unicast_out.additionals + assert unicast_out.answers[0][0] == ptr_record + + # Remove the 50% A record and add a 100% A record + zc.cache.remove(a_record) + a_record = info.dns_addresses()[0] + assert a_record.is_recent(current_time_millis()) + zc.cache.add(a_record) + # With QU should respond to only unicast when the answer has been recently multicast + # even if the additional has not been recently multicast + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out is None + assert a_record in unicast_out.additionals + assert unicast_out.answers[0][0] == ptr_record + + # Remove the 100% PTR record and add a 50% PTR record + zc.cache.remove(ptr_record) + ptr_record._set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl) + assert not ptr_record.is_recent(current_time_millis()) + zc.cache.add(ptr_record) + # With QU should respond to only multicast since the has less + # than 75% of its ttl remaining + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out.answers[0][0] == ptr_record + assert a_record in multicast_out.additionals + assert info.dns_text() in multicast_out.additionals + assert info.dns_service() in multicast_out.additionals + + assert unicast_out is None + + # Ask 2 QU questions, with info the PTR is at 50%, with info2 the PTR is at 100% + # We should get back a unicast reply for info2, but info should be multicasted since its within 75% of its TTL + # With QU should respond to only multicast since the has less + # than 75% of its ttl remaining + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + + question = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + question.unique = True # Set the QU bit + assert question.unicast is True + query.add_question(question) + zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache + + unicast_out, multicast_out = zc.query_handler.response( + [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + ) + assert multicast_out.answers[0][0] == info.dns_pointer() + assert info.dns_addresses()[0] in multicast_out.additionals + assert info.dns_text() in multicast_out.additionals + assert info.dns_service() in multicast_out.additionals + + assert unicast_out.answers[0][0] == info2.dns_pointer() + assert info2.dns_addresses()[0] in unicast_out.additionals + assert info2.dns_text() in unicast_out.additionals + assert info2.dns_service() in unicast_out.additionals + + # unregister + zc.registry.remove(info) + zc.close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index a7e18360..042ab2a1 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -20,7 +20,6 @@ USA """ -import enum import itertools from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union @@ -53,14 +52,7 @@ from ._core import Zeroconf # pylint: disable=cyclic-import -@enum.unique -class RecordSetKeys(enum.Enum): - Answers = 1 - Additionals = 2 - - -# Switch to a TypedDict once Python 3.8 is the minimum supported version -_RecordSetType = Dict[RecordSetKeys, Set[DNSRecord]] +_AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] class _QueryResponse: @@ -69,48 +61,39 @@ class _QueryResponse: def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> None: """Build a query response.""" self._msg = msg - self._ucast_source = ucast_source self._is_probe = msg.num_authorities > 0 + self._ucast_source = ucast_source self._now = current_time_millis() self._cache = cache - self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} - self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()} + self._additionals: _AnswerWithAdditionalsType = {} + self._ucast: Set[DNSRecord] = set() + self._mcast: Set[DNSRecord] = set() - def add_qu_question_response( - self, - answers: Set[DNSRecord], - additionals: Set[DNSRecord], - ) -> None: + def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast QU query.""" - self._add_qu_question_response_to_target(answers, RecordSetKeys.Answers) - self._add_qu_question_response_to_target(additionals, RecordSetKeys.Additionals) - - def _add_qu_question_response_to_target(self, target: Set[DNSRecord], answer_type: RecordSetKeys) -> None: - """Add part of the QU response.""" - for record in target: + for record, additionals in answers.items(): + self._additionals[record] = additionals if self._is_probe: - self._ucast[answer_type].add(record) + self._ucast.add(record) if not self._has_mcast_within_one_quarter_ttl(record): - self._mcast[answer_type].add(record) + self._mcast.add(record) elif not self._is_probe: - self._ucast[answer_type].add(record) + self._ucast.add(record) - def add_ucast_question_response(self, answers: Set[DNSRecord], additionals: Set[DNSRecord]) -> None: + def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a unicast query.""" - self._ucast[RecordSetKeys.Answers].update(answers) - self._ucast[RecordSetKeys.Additionals].update(additionals) + self._additionals.update(answers) + self._ucast.update(answers.keys()) - def add_mcast_question_response(self, answers: Set[DNSRecord], additionals: Set[DNSRecord]) -> None: + def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast query.""" - self._mcast[RecordSetKeys.Answers].update(answers) - self._mcast[RecordSetKeys.Additionals].update(additionals) + self._additionals.update(answers) + self._mcast.update(answers.keys()) def outgoing_unicast(self) -> Optional[DNSOutgoing]: """Build the outgoing unicast response.""" ucastout = self._construct_outgoing_from_record_set(self._ucast, False) - # Adding the questions back when the source is - # unicast (not MDNS port) is legacy behavior - # Is this correct? + # Adding the questions back when the source is legacy unicast behavior if ucastout and self._ucast_source: for question in self._msg.questions: ucastout.add_question(question) @@ -119,28 +102,33 @@ def outgoing_unicast(self) -> Optional[DNSOutgoing]: def outgoing_multicast(self) -> Optional[DNSOutgoing]: """Build the outgoing multicast response.""" if not self._is_probe: - self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Answers]) - self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Additionals]) + self._suppress_mcasts_from_last_second(self._mcast) return self._construct_outgoing_from_record_set(self._mcast, True) def _construct_outgoing_from_record_set( - self, rrset: _RecordSetType, multicast: bool + self, answers_rrset: Set[DNSRecord], multicast: bool ) -> Optional[DNSOutgoing]: """Add answers and additionals to a DNSOutgoing.""" - if not rrset[RecordSetKeys.Answers] and not rrset[RecordSetKeys.Additionals]: + # Find additionals and suppress any additionals that are already in answers + additionals_rrset = self._additionals_from_answers_rrset(answers_rrset) - answers_rrset + if not answers_rrset: return None - # Suppress any additionals that are already in answers - rrset[RecordSetKeys.Additionals] -= rrset[RecordSetKeys.Answers] - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=multicast, id_=self._msg.id) - for answer in rrset[RecordSetKeys.Answers]: + for answer in answers_rrset: out.add_answer_at_time(answer, 0) - for additional in rrset[RecordSetKeys.Additionals]: + for additional in additionals_rrset: out.add_additional_answer(additional) - return out + def _additionals_from_answers_rrset(self, rrset: Set[DNSRecord]) -> Set[DNSRecord]: + additionals: Set[DNSRecord] = set() + return additionals.union(*[self._additionals[record] for record in rrset]) + + def _suppress_mcasts_from_last_second(self, rrset: Set[DNSRecord]) -> None: + """Remove any records that were already sent in the last second.""" + rrset -= set(record for record in rrset if self._has_mcast_record_in_last_second(record)) + def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: """Check to see if a record has been mcasted recently. @@ -155,10 +143,6 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: maybe_entry = self._cache.get(record) return bool(maybe_entry and maybe_entry.is_recent(self._now)) - def _suppress_mcasts_from_last_second(self, records: Set[DNSRecord]) -> None: - """Remove any records that were already sent in the last second.""" - records -= set(record for record in records if self._has_mcast_record_in_last_second(record)) - def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: """Remove answers that were just broadcast Protect the network against excessive packet flooding @@ -176,101 +160,90 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: self.registry = registry self.cache = cache - def _answer_service_type_enumeration_query( - self, - answers_rrset: DNSRRSet, - ) -> Set[DNSRecord]: + def _add_service_type_enumeration_query_answers( + self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + ) -> None: """Provide an answer to a service type enumeration query. https://datatracker.ietf.org/doc/html/rfc6763#section-9 """ - records: Set[DNSRecord] = set( - DNSPointer(_SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype) - for stype in self.registry.get_types() - ) - records -= set(dns_pointer for dns_pointer in records if answers_rrset.suppresses(dns_pointer)) - return records + for stype in self.registry.get_types(): + dns_pointer = DNSPointer( + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype + ) + if not known_answers.suppresses(dns_pointer): + answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answers_rrset: DNSRRSet, answers: Set[DNSRecord], additionals: Set[DNSRecord] + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet ) -> None: """Answer PTR/ANY question.""" for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer() - if not answers_rrset.suppresses(dns_pointer): - answers.add(dns_pointer) - additionals.add(service.dns_service()) - additionals.add(service.dns_text()) - additionals.update(service.dns_addresses()) + if not known_answers.suppresses(dns_pointer): + answer_set[dns_pointer] = set( + [service.dns_service(), service.dns_text(), *service.dns_addresses()] + ) def _add_address_answers( - self, name: str, answers_rrset: DNSRRSet, answers: Set[DNSRecord], type_: int + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: int ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): - if not answers_rrset.suppresses(dns_address): - answers.add(dns_address) + if not known_answers.suppresses(dns_address): + answer_set[dns_address] = set() def _answer_question( - self, answers_rrset: DNSRRSet, question: DNSQuestion - ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: - answers: Set[DNSRecord] = set() - additionals: Set[DNSRecord] = set() + self, question: DNSQuestion, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + ) -> None: + if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: + self._add_service_type_enumeration_query_answers(answer_set, known_answers) + return + type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answers_rrset, answers, additionals) + self._add_pointer_answers(question.name, answer_set, known_answers) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answers_rrset, answers, type_) + self._add_address_answers(question.name, answer_set, known_answers, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore if service is not None: if type_ in (_TYPE_SRV, _TYPE_ANY): + # Add recommended additional answers according to + # https://tools.ietf.org/html/rfc6763#section-12.2. dns_service = service.dns_service() - if not answers_rrset.suppresses(dns_service): - # Add recommended additional answers according to - # https://tools.ietf.org/html/rfc6763#section-12.2. - answers.add(service.dns_service()) - additionals.update(service.dns_addresses()) + if not known_answers.suppresses(dns_service): + answer_set[dns_service] = set(service.dns_addresses()) if type_ in (_TYPE_TXT, _TYPE_ANY): dns_text = service.dns_text() - if not answers_rrset.suppresses(dns_text): - answers.add(service.dns_text()) - - return answers, additionals - - def _answer_any_question( - self, answers_rrset: DNSRRSet, question: DNSQuestion - ) -> Tuple[Set[DNSRecord], Set[DNSRecord]]: - if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - empty_additionals: Set[DNSRecord] = set() - return self._answer_service_type_enumeration_query(answers_rrset), empty_additionals - - return self._answer_question(answers_rrset, question) + if not known_answers.suppresses(dns_text): + answer_set[dns_text] = set() def response( # pylint: disable=unused-argument self, msgs: List[DNSIncoming], addr: Optional[str], port: int ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: """Deal with incoming query packets. Provides a response if possible.""" ucast_source = port != _MDNS_PORT + known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) - answers_rrset = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) for question in itertools.chain(*[msg.questions for msg in msgs]): - all_answers = self._answer_any_question(answers_rrset, question) + answer_set: _AnswerWithAdditionalsType = {} + self._answer_question(question, answer_set, known_answers) if not ucast_source and question.unicast: - query_res.add_qu_question_response(*all_answers) + query_res.add_qu_question_response(answer_set) else: if ucast_source: - query_res.add_ucast_question_response(*all_answers) + query_res.add_ucast_question_response(answer_set) # We always multicast as well even if its a unicast # source as long as we haven't done it recently (75% of ttl) - query_res.add_mcast_question_response(*all_answers) + query_res.add_mcast_question_response(answer_set) return query_res.outgoing_unicast(), query_res.outgoing_multicast() From 32b7dc40e2c3621fcacb2f389d51408ab35ac832 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 13:31:53 -1000 Subject: [PATCH 351/608] Fix off by 1 in test_tc_bit_defers_last_response_missing (#694) --- tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 8f459da1..4d001208 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -594,7 +594,7 @@ async def make_query(): assert timer3.cancelled() assert timer4 != timer3 - for _ in range(7): + for _ in range(8): time.sleep(0.1) if source_ip not in zc._timers: break From 5cbaa3fc02f635e6c735e1ee5f1ca19b84c0a069 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 14:04:35 -1000 Subject: [PATCH 352/608] Rollback data in one call instead of poping one byte at a time in DNSOutgoing (#696) --- zeroconf/_dns.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index b55f77f0..e285829b 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -823,7 +823,6 @@ def _write_record(self, record: DNSRecord, now: float) -> bool: else: self._write_int(record.get_remaining_ttl(now)) index = len(self.data) - self.write_short(0) # Will get replaced with the actual size record.write(self) # Adjust size for the short we will write before this record @@ -842,9 +841,7 @@ def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) return True log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) - - while len(self.data) > start_data_length: - self.data.pop() + del self.data[start_data_length:] self.size = start_size rollback_names = [name for name, idx in self.names.items() if idx >= start_size] From 767546b656d7db6df0cbf2b257953498f1bc3996 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 14:11:07 -1000 Subject: [PATCH 353/608] Use unique names in service types tests (#697) --- tests/services/test_types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 6f6645db..ba355bae 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -19,7 +19,7 @@ class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-listen-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -50,7 +50,7 @@ def test_integration_with_listener(self): @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_v6_records(self): - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-listenv6rec-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com @@ -82,7 +82,7 @@ def test_integration_with_listener_v6_records(self): @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_integration_with_listener_ipv6(self): - type_ = "_test-srvc-type._tcp.local." + type_ = "_test-listenv6ip-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) @@ -111,7 +111,7 @@ def test_integration_with_listener_ipv6(self): def test_integration_with_subtype_and_listener(self): subtype_ = "_subtype._sub" - type_ = "_type._tcp.local." + type_ = "_listen._tcp.local." name = "xxxyyy" # Note: discovery returns only DNS-SD type not subtype discovery_type = "%s.%s" % (subtype_, type_) From 26fa2fb479fff87ca5af17c2c09a557c4b6176b5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 14:25:17 -1000 Subject: [PATCH 354/608] Abstract DNSOutgoing ttl write into _write_ttl (#695) --- zeroconf/_dns.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index e285829b..1b4a9184 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -810,6 +810,10 @@ def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: else: self.write_short(record.class_) + def _write_ttl(self, record: DNSRecord, now: float) -> None: + """Write out the record ttl.""" + self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now)) + def _write_record(self, record: DNSRecord, now: float) -> bool: """Writes a record (answer, authoritative answer, additional) to the packet. Returns True on success, or False if we did not @@ -818,10 +822,7 @@ def _write_record(self, record: DNSRecord, now: float) -> bool: self.write_name(record.name) self.write_short(record.type) self._write_record_class(record) - if now == 0: - self._write_int(record.ttl) - else: - self._write_int(record.get_remaining_ttl(now)) + self._write_ttl(record, now) index = len(self.data) self.write_short(0) # Will get replaced with the actual size record.write(self) From 7e308480238fdf2cfe08474d679121e77f746fa6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 23:11:35 -1000 Subject: [PATCH 355/608] Efficiently bucket queries with known answers (#698) --- tests/test_services.py | 25 ++++++++- zeroconf/_dns.py | 29 +++++++++- zeroconf/_services/__init__.py | 96 +++++++++++++++++++++++++++------- zeroconf/aio.py | 4 +- zeroconf/const.py | 2 + 5 files changed, 133 insertions(+), 23 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 521efd1f..147c1225 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -16,7 +16,7 @@ import pytest import zeroconf as r -from zeroconf import DNSAddress, const +from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis import zeroconf._services as s from zeroconf import Zeroconf from zeroconf._services import ( @@ -1377,3 +1377,26 @@ def test_serviceinfo_accepts_bytes_or_string_dict(): addresses=addresses, ) assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + + +def test_group_ptr_queries_with_known_answers(): + questions_with_known_answers: s._QuestionWithKnownAnswers = {} + now = current_time_millis() + for i in range(120): + name = f"_hap{i}._tcp._local." + questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set( + DNSPointer( + name, + const._TYPE_PTR, + const._CLASS_IN, + 4500, + f"zoo{counter}.{name}", + ) + for counter in range(i) + ) + outs = s._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) + for out in outs: + packets = out.packets() + # If we generate multiple packets there must + # only be one question + assert len(packets) == 1 or len(out.questions) == 1 diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 1b4a9184..d6c12a71 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -34,6 +34,7 @@ _CLASSES, _CLASS_MASK, _CLASS_UNIQUE, + _DNS_PACKET_HEADER_LEN, _EXPIRE_FULL_TIME_PERCENT, _EXPIRE_STALE_TIME_PERCENT, _FLAGS_QR_MASK, @@ -54,6 +55,12 @@ _TYPE_TXT, ) +_LEN_BYTE = 1 +_LEN_SHORT = 2 +_LEN_INT = 4 + +_BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length +_NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2 if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 @@ -118,6 +125,14 @@ def answered_by(self, rec: 'DNSRecord') -> bool: and self.name == rec.name ) + def __hash__(self) -> int: + return hash((self.name, self.class_, self.type)) + + @property + def max_size(self) -> int: + """Maximum size of the question in the packet.""" + return len(self.name.encode('utf-8')) + _LEN_BYTE + _LEN_SHORT + _LEN_SHORT # type # class + @property def unicast(self) -> bool: """Returns true if the QU (not QM) is set. @@ -291,6 +306,16 @@ def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> super().__init__(name, type_, class_, ttl) self.alias = alias + @property + def max_size_compressed(self) -> int: + """Maximum size of the record in the packet assuming the name has been compressed.""" + return ( + _BASE_MAX_SIZE + + _NAME_COMPRESSION_MIN_SIZE + + (len(self.alias) - len(self.name)) + + _NAME_COMPRESSION_MIN_SIZE + ) + def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" out.write_name(self.alias) @@ -590,7 +615,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: # these 3 are per-packet -- see also _reset_for_next_packet() self.names: Dict[str, int] = {} self.data: List[bytes] = [] - self.size: int = 12 + self.size: int = _DNS_PACKET_HEADER_LEN self.allow_long: bool = True self.state = self.State.init @@ -603,7 +628,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: def _reset_for_next_packet(self) -> None: self.names = {} self.data = [] - self.size = 12 + self.size = _DNS_PACKET_HEADER_LEN self.allow_long = True def __repr__(self) -> str: diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 19fdccaf..818b3bb6 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -44,9 +44,11 @@ _CLASS_UNIQUE, _DNS_HOST_TTL, _DNS_OTHER_TTL, + _DNS_PACKET_HEADER_LEN, _EXPIRE_REFRESH_TIME_PERCENT, _FLAGS_QR_QUERY, _LISTENER_TIME, + _MAX_MSG_TYPICAL, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT, @@ -63,6 +65,9 @@ from .._core import Zeroconf # pylint: disable=cyclic-import +_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] + + @enum.unique class ServiceStateChange(enum.Enum): Added = 1 @@ -151,6 +156,67 @@ def update_records_complete(self) -> None: """ +class _DNSPointerOutgoingBucket: + """A DNSOutgoing bucket.""" + + def __init__(self, now: float, multicast: bool) -> None: + """Create a bucke to wrap a DNSOutgoing.""" + self.now = now + self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast) + self.bytes = 0 + + def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + """Add a new set of questions and known answers to the outgoing.""" + self.out.add_question(question) + for answer in answers: + self.out.add_answer_at_time(answer, self.now) + self.bytes += max_compressed_size + + +def _group_ptr_queries_with_known_answers( + now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers +) -> List[DNSOutgoing]: + """Aggregate queries so that as many known answers as possible fit in the same packet + without having known answers spill over into the next packet unless the + question and known answers are always going to exceed the packet size. + + Some responders do not implement multi-packet known answer suppression + so we try to keep all the known answers in the same packet as the + questions. + """ + # This is the maximum size the query + known answers can be with name compression. + # The actual size of the query + known answers may be a bit smaller since other + # parts may be shared when the final DNSOutgoing packets are constructed. The + # goal of this algorithm is to quickly bucket the query + known answers without + # the overhead of actually constructing the packets. + query_by_size: Dict[DNSQuestion, int] = { + question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers])) + for question, known_answers in question_with_known_answers.items() + } + max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN + query_buckets: List[_DNSPointerOutgoingBucket] = [] + for question in sorted( + query_by_size, + key=query_by_size.get, # type: ignore + reverse=True, + ): + max_compressed_size = query_by_size[question] + answers = question_with_known_answers[question] + for query_bucket in query_buckets: + if query_bucket.bytes + max_compressed_size <= max_bucket_size: + query_bucket.add(max_compressed_size, question, answers) + break + else: + # If a single question and known answers won't fit in a packet + # we will end up generating multiple packets, but there will never + # be multiple questions + query_bucket = _DNSPointerOutgoingBucket(now, multicast) + query_bucket.add(max_compressed_size, question, answers) + query_buckets.append(query_bucket) + + return [query_bucket.out for query_bucket in query_buckets] + + class _ServiceBrowserBase(RecordUpdateListener): """Base class for ServiceBrowser.""" @@ -174,9 +240,7 @@ def __init__( self.addr = addr self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - self._services = { - check_type_: {} for check_type_ in self.types - } # type: Dict[str, Dict[str, DNSRecord]] + self._services: Dict[str, Dict[str, DNSPointer]] = {check_type_: {} for check_type_ in self.types} current_time = current_time_millis() self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} @@ -317,29 +381,25 @@ def run(self) -> None: questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] self.zc.add_listener(self, questions) - def generate_ready_queries(self) -> Optional[DNSOutgoing]: + def generate_ready_queries(self) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" - out = None now = current_time_millis() if min(self._next_time.values()) > now: - return out + return [] + + questions_with_known_answers: _QuestionWithKnownAnswers = {} for type_, due in self._next_time.items(): if due > now: continue - - if out is None: - out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast) - out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)) - - for record in self._services[type_].values(): - if not record.is_stale(now): - out.add_answer_at_time(record, now) - + questions_with_known_answers[DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)] = set( + record for record in self._services[type_].values() if not record.is_stale(now) + ) self._next_time[type_] = now + self._delay[type_] self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - return out + + return _group_ptr_queries_with_known_answers(now, self.multicast, questions_with_known_answers) def _seconds_to_wait(self) -> Optional[float]: """Returns the number of seconds to wait for the next event.""" @@ -406,8 +466,8 @@ def run(self) -> None: if self.zc.done or self.done: return - out = self.generate_ready_queries() - if out: + outs = self.generate_ready_queries() + for out in outs: self.zc.send(out, addr=self.addr, port=self.port) if not self._handlers_to_call: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 01211bb4..ae57d014 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -159,8 +159,8 @@ async def async_run(self) -> None: if not self._handlers_to_call: await wait_condition_or_timeout(self.aiozc.condition, timeout) - out = self.generate_ready_queries() - if out: + outs = self.generate_ready_queries() + for out in outs: self.aiozc.zeroconf.async_send(out, addr=self.addr, port=self.port) if not self._handlers_to_call: diff --git a/zeroconf/const.py b/zeroconf/const.py index 96f536df..ba9d5309 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -47,6 +47,8 @@ _DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 _DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 +_DNS_PACKET_HEADER_LEN = 12 + _MAX_MSG_TYPICAL = 1460 # unused _MAX_MSG_ABSOLUTE = 8966 From c368e1c67c82598e920ca52b1f7a47ed6e1cf738 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Jun 2021 23:15:42 -1000 Subject: [PATCH 356/608] Update changelog (#699) --- README.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/README.rst b/README.rst index 06a6c64e..eb730cdf 100644 --- a/README.rst +++ b/README.rst @@ -201,6 +201,12 @@ Changelog * TRAFFIC REDUCTION: Avoid including additionals when the answer is suppressed by known-answer supression (#614) @bdraco +* TRAFFIC REDUCTION: Implement multi-packet known answer supression (#687) @bdraco + + Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 + +* TRAFFIC REDUCTION: Efficiently bucket queries with known answers (#698) @bdraco + * MAJOR BUG: Ensure matching PTR queries are returned with the ANY query (#618) @bdraco * MAJOR BUG: Fix lookup of uppercase names in registry (#597) @bdraco @@ -219,6 +225,16 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Abstract DNSOutgoing ttl write into _write_ttl (#695) @bdraco + +* Rollback data in one call instead of poping one byte at a time in DNS Outgoing (#696) @bdraco + +* Suppress additionals when answer is suppressed (#690) @bdraco + +* Move setting DNS created and ttl into its own function (#692) @bdraco + +* Add truncated property to DNSMessage to lookup the TC bit (#686) @bdraco + * Check if SO_REUSEPORT exists instead of using an exception catch (#682) @bdraco * Use DNSRRSet for known answer suppression (#680) @bdraco From f39bde0f6cba7a3c1b8fe8bc1a4ab4388801e486 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 09:11:34 -1000 Subject: [PATCH 357/608] Split DNSOutgoing/DNSIncoming/DNSMessage into zeroconf._protocol (#705) --- tests/test_dns.py | 695 +------------------------------ tests/test_protocol.py | 722 +++++++++++++++++++++++++++++++++ zeroconf/__init__.py | 4 +- zeroconf/_core.py | 3 +- zeroconf/_dns.py | 620 +--------------------------- zeroconf/_handlers.py | 3 +- zeroconf/_protocol.py | 644 +++++++++++++++++++++++++++++ zeroconf/_services/__init__.py | 3 +- 8 files changed, 1379 insertions(+), 1315 deletions(-) create mode 100644 tests/test_protocol.py create mode 100644 zeroconf/_protocol.py diff --git a/tests/test_dns.py b/tests/test_dns.py index 4096aa94..557802e1 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -2,19 +2,16 @@ # -*- coding: utf-8 -*- -""" Unit tests for zeroconf.py """ +""" Unit tests for zeroconf._dns. """ -import copy import logging import socket -import struct import time import unittest import unittest.mock -from typing import Dict, cast # noqa # used in type hints import zeroconf as r -from zeroconf import DNSIncoming, const, current_time_millis +from zeroconf import const, current_time_millis from zeroconf._dns import DNSRRSet from zeroconf import ( DNSHinfo, @@ -162,454 +159,6 @@ def test_dns_record_is_recent(self): assert record.is_recent(now + (8 * 1000)) is False -class PacketGeneration(unittest.TestCase): - def test_parse_own_packet_simple(self): - generated = r.DNSOutgoing(0) - r.DNSIncoming(generated.packets()[0]) - - def test_parse_own_packet_simple_unicast(self): - generated = r.DNSOutgoing(0, False) - r.DNSIncoming(generated.packets()[0]) - - def test_parse_own_packet_flags(self): - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - r.DNSIncoming(generated.packets()[0]) - - def test_parse_own_packet_question(self): - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) - r.DNSIncoming(generated.packets()[0]) - - def test_parse_own_packet_response(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSService( - "æøå.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - 0, - ) - parsed = r.DNSIncoming(generated.packets()[0]) - assert len(generated.answers) == 1 - assert len(generated.answers) == len(parsed.answers) - - def test_adding_empty_answer(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - None, - 0, - ) - generated.add_answer_at_time( - r.DNSService( - "æøå.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - 0, - ) - parsed = r.DNSIncoming(generated.packets()[0]) - assert len(generated.answers) == 1 - assert len(generated.answers) == len(parsed.answers) - - def test_adding_expired_answer(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSService( - "æøå.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - current_time_millis() + 1000000, - ) - parsed = r.DNSIncoming(generated.packets()[0]) - assert len(generated.answers) == 0 - assert len(generated.answers) == len(parsed.answers) - - def test_match_question(self): - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) - generated.add_question(question) - parsed = r.DNSIncoming(generated.packets()[0]) - assert len(generated.questions) == 1 - assert len(generated.questions) == len(parsed.questions) - assert question == parsed.questions[0] - - def test_suppress_answer(self): - query_generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) - query_generated.add_question(question) - answer1 = r.DNSService( - "testname1.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ) - staleanswer2 = r.DNSService( - "testname2.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL / 2, - 0, - 0, - 80, - "foo.local.", - ) - answer2 = r.DNSService( - "testname2.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ) - query_generated.add_answer_at_time(answer1, 0) - query_generated.add_answer_at_time(staleanswer2, 0) - query = r.DNSIncoming(query_generated.packets()[0]) - - # Should be suppressed - response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - response.add_answer(query, answer1) - assert len(response.answers) == 0 - - # Should not be suppressed, TTL in query is too short - response.add_answer(query, answer2) - assert len(response.answers) == 1 - - # Should not be suppressed, name is different - tmp = copy.copy(answer1) - tmp.key = "testname3.local." - tmp.name = "testname3.local." - response.add_answer(query, tmp) - assert len(response.answers) == 2 - - # Should not be suppressed, type is different - tmp = copy.copy(answer1) - tmp.type = const._TYPE_A - response.add_answer(query, tmp) - assert len(response.answers) == 3 - - # Should not be suppressed, class is different - tmp = copy.copy(answer1) - tmp.class_ = const._CLASS_NONE - response.add_answer(query, tmp) - assert len(response.answers) == 4 - - # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService - - def test_dns_hinfo(self): - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) - parsed = r.DNSIncoming(generated.packets()[0]) - answer = cast(r.DNSHinfo, parsed.answers[0]) - assert answer.cpu == u'cpu' - assert answer.os == u'os' - - generated = r.DNSOutgoing(0) - generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) - self.assertRaises(r.NamePartTooLongException, generated.packets) - - def test_many_questions(self): - """Test many questions get seperated into multiple packets.""" - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - questions = [] - for i in range(100): - question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) - generated.add_question(question) - questions.append(question) - assert len(generated.questions) == 100 - - packets = generated.packets() - assert len(packets) == 2 - assert len(packets[0]) < const._MAX_MSG_TYPICAL - assert len(packets[1]) < const._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert len(parsed1.questions) == 85 - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 15 - - def test_many_questions_with_many_known_answers(self): - """Test many questions and known answers get seperated into multiple packets.""" - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - questions = [] - for _ in range(30): - question = r.DNSQuestion(f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) - generated.add_question(question) - questions.append(question) - assert len(generated.questions) == 30 - now = current_time_millis() - for _ in range(200): - known_answer = r.DNSPointer( - "myservice{i}_tcp._tcp.local.", - const._TYPE_PTR, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - '123.local.', - ) - generated.add_answer_at_time(known_answer, now) - packets = generated.packets() - assert len(packets) == 3 - assert len(packets[0]) <= const._MAX_MSG_TYPICAL - assert len(packets[1]) <= const._MAX_MSG_TYPICAL - assert len(packets[2]) <= const._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert len(parsed1.questions) == 30 - assert len(parsed1.answers) == 88 - assert parsed1.truncated - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 0 - assert len(parsed2.answers) == 101 - assert parsed2.truncated - parsed3 = r.DNSIncoming(packets[2]) - assert len(parsed3.questions) == 0 - assert len(parsed3.answers) == 11 - assert not parsed3.truncated - - def test_massive_probe_packet_split(self): - """Test probe with many authorative answers.""" - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) - questions = [] - for _ in range(30): - question = r.DNSQuestion( - f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE - ) - generated.add_question(question) - questions.append(question) - assert len(generated.questions) == 30 - now = current_time_millis() - for _ in range(200): - authorative_answer = r.DNSPointer( - "myservice{i}_tcp._tcp.local.", - const._TYPE_PTR, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - '123.local.', - ) - generated.add_authorative_answer(authorative_answer) - packets = generated.packets() - assert len(packets) == 3 - assert len(packets[0]) <= const._MAX_MSG_TYPICAL - assert len(packets[1]) <= const._MAX_MSG_TYPICAL - assert len(packets[2]) <= const._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert parsed1.questions[0].unicast is True - assert len(parsed1.questions) == 30 - assert parsed1.num_authorities == 88 - assert parsed1.truncated - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 0 - assert parsed2.num_authorities == 101 - assert parsed2.truncated - parsed3 = r.DNSIncoming(packets[2]) - assert len(parsed3.questions) == 0 - assert parsed3.num_authorities == 11 - assert not parsed3.truncated - - def test_only_one_answer_can_by_large(self): - """Test that only the first answer in each packet can be large. - - https://datatracker.ietf.org/doc/html/rfc6762#section-17 - """ - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0]) - for i in range(3): - generated.add_answer( - query, - r.DNSText( - "zoom._hap._tcp.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - 1200, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, - ), - ) - generated.add_answer( - query, - r.DNSService( - "testname1.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - "foo.local.", - ), - ) - assert len(generated.answers) == 4 - - packets = generated.packets() - assert len(packets) == 4 - assert len(packets[0]) <= const._MAX_MSG_ABSOLUTE - assert len(packets[0]) > const._MAX_MSG_TYPICAL - - assert len(packets[1]) <= const._MAX_MSG_ABSOLUTE - assert len(packets[1]) > const._MAX_MSG_TYPICAL - - assert len(packets[2]) <= const._MAX_MSG_ABSOLUTE - assert len(packets[2]) > const._MAX_MSG_TYPICAL - - assert len(packets[3]) <= const._MAX_MSG_TYPICAL - - for packet in packets: - parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 1 - - def test_questions_do_not_end_up_every_packet(self): - """Test that questions are not sent again when multiple packets are needed. - - https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - Sometimes a Multicast DNS querier will already have too many answers - to fit in the Known-Answer Section of its query packets.... It MUST - immediately follow the packet with another query packet containing no - questions and as many more Known-Answer records as will fit. - """ - - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - for i in range(35): - question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) - generated.add_question(question) - answer = r.DNSService( - f"testname{i}.local.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - 80, - f"foo{i}.local.", - ) - generated.add_answer_at_time(answer, 0) - - assert len(generated.questions) == 35 - assert len(generated.answers) == 35 - - packets = generated.packets() - assert len(packets) == 2 - assert len(packets[0]) <= const._MAX_MSG_TYPICAL - assert len(packets[1]) <= const._MAX_MSG_TYPICAL - - parsed1 = r.DNSIncoming(packets[0]) - assert len(parsed1.questions) == 35 - assert len(parsed1.answers) == 33 - - parsed2 = r.DNSIncoming(packets[1]) - assert len(parsed2.questions) == 0 - assert len(parsed2.answers) == 2 - - -class PacketForm(unittest.TestCase): - def test_transaction_id(self): - """ID must be zero in a DNS-SD packet""" - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - bytes = generated.packets()[0] - id = bytes[0] << 8 | bytes[1] - assert id == 0 - - def test_setting_id(self): - """Test setting id in the constructor""" - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY, id_=4444) - assert generated.id == 4444 - - def test_query_header_bits(self): - generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) - bytes = generated.packets()[0] - flags = bytes[2] << 8 | bytes[3] - assert flags == 0x0 - - def test_response_header_bits(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - bytes = generated.packets()[0] - flags = bytes[2] << 8 | bytes[3] - assert flags == 0x8000 - - def test_numbers(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - bytes = generated.packets()[0] - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - assert num_questions == 0 - assert num_answers == 0 - assert num_authorities == 0 - assert num_additionals == 0 - - def test_numbers_questions(self): - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) - for i in range(10): - generated.add_question(question) - bytes = generated.packets()[0] - (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) - assert num_questions == 10 - assert num_answers == 0 - assert num_authorities == 0 - assert num_additionals == 0 - - -class TestDnsIncoming(unittest.TestCase): - def test_incoming_exception_handling(self): - generated = r.DNSOutgoing(0) - packet = generated.packets()[0] - packet = packet[:8] + b'deadbeef' + packet[8:] - parsed = r.DNSIncoming(packet) - parsed = r.DNSIncoming(packet) - assert parsed.valid is False - - def test_incoming_unknown_type(self): - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') - generated.add_additional_answer(answer) - packet = generated.packets()[0] - parsed = r.DNSIncoming(packet) - assert len(parsed.answers) == 0 - assert parsed.is_query() != parsed.is_response() - - def test_incoming_circular_reference(self): - assert not r.DNSIncoming( - bytes.fromhex( - '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb' - '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e' - '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966' - '792d636f6e6e656374045f746370c023' - ) - ).valid - - def test_incoming_ipv6(self): - addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com - packed = socket.inet_pton(socket.AF_INET6, addr) - generated = r.DNSOutgoing(0) - answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) - generated.add_additional_answer(answer) - packet = generated.packets()[0] - parsed = r.DNSIncoming(packet) - record = parsed.answers[0] - assert isinstance(record, r.DNSAddress) - assert record.address == packed - - class TestDNSCache(unittest.TestCase): def test_order(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') @@ -647,246 +196,6 @@ def test_cache_empty_multiple_calls_does_not_throw(self): assert 'a' not in cache.cache -def test_dns_compression_rollback_for_corruption(): - """Verify rolling back does not lead to dns compression corruption.""" - out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) - address = socket.inet_pton(socket.AF_INET, "192.168.208.5") - - additionals = [ - { - "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", - "address": address, - "port": 51832, - "text": b"\x13md=HASS Bridge" - b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=L0m/aQ==", - }, - { - "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", - "address": address, - "port": 51834, - "text": b"\x13md=HASS Bridge" - b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=b2CnzQ==", - }, - { - "name": "Master Bed TV CEDB27._hap._tcp.local.", - "address": address, - "port": 51830, - "text": b"\x10md=Master Bed" - b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=CVj1kw==", - }, - { - "name": "Living Room TV 921B77._hap._tcp.local.", - "address": address, - "port": 51833, - "text": b"\x11md=Living Room" - b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=qU77SQ==", - }, - { - "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", - "address": address, - "port": 51829, - "text": b"\x13md=HASS Bridge" - b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=b0QZlg==", - }, - { - "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", - "address": address, - "port": 51837, - "text": b"\x13md=HASS Bridge" - b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=ahAISA==", - }, - { - "name": "FrontdoorCamera 8941D1._hap._tcp.local.", - "address": address, - "port": 54898, - "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" - b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", - }, - { - "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", - "address": address, - "port": 51836, - "text": b"\x13md=HASS Bridge" - b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=6fLM5A==", - }, - { - "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", - "address": address, - "port": 51838, - "text": b"\x13md=HASS Bridge" - b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" - b"ci=2\x04sf=0\x0bsh=u3bdfw==", - }, - { - "name": "Snooze Room TV 6B89B0._hap._tcp.local.", - "address": address, - "port": 51835, - "text": b"\x11md=Snooze Room" - b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" - b"ci=31\x04sf=0\x0bsh=xNTqsg==", - }, - { - "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", - "address": address, - "port": 54811, - "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" - b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", - }, - { - "name": "HASS Bridge OS95 39C053._hap._tcp.local.", - "address": address, - "port": 51831, - "text": b"\x13md=HASS Bridge" - b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" - b"\x04sf=0\x0bsh=Xfe5LQ==", - }, - ] - - out.add_answer_at_time( - DNSText( - "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - for record in additionals: - out.add_additional_answer( - r.DNSService( - record["name"], # type: ignore - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - 0, - 0, - record["port"], # type: ignore - record["name"], # type: ignore - ) - ) - out.add_additional_answer( - r.DNSText( - record["name"], # type: ignore - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - record["text"], # type: ignore - ) - ) - out.add_additional_answer( - r.DNSAddress( - record["name"], # type: ignore - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_HOST_TTL, - record["address"], # type: ignore - ) - ) - - for packet in out.packets(): - # Verify we can process the packets we created to - # ensure there is no corruption with the dns compression - incoming = r.DNSIncoming(packet) - assert incoming.valid is True - - -def test_tc_bit_in_query_packet(): - """Verify the TC bit is set when known answers exceed the packet size.""" - out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) - type_ = "_hap._tcp.local." - out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) - - for i in range(30): - out.add_answer_at_time( - DNSText( - ("HASS Bridge W9DN %s._hap._tcp.local." % i), - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - packets = out.packets() - assert len(packets) == 3 - - first_packet = r.DNSIncoming(packets[0]) - assert first_packet.truncated - assert first_packet.valid is True - - second_packet = r.DNSIncoming(packets[1]) - assert second_packet.truncated - assert second_packet.valid is True - - third_packet = r.DNSIncoming(packets[2]) - assert not third_packet.truncated - assert third_packet.valid is True - - -def test_tc_bit_not_set_in_answer_packet(): - """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" - out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) - for i in range(30): - out.add_answer_at_time( - DNSText( - ("HASS Bridge W9DN %s._hap._tcp.local." % i), - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - const._DNS_OTHER_TTL, - b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - 0, - ) - - packets = out.packets() - assert len(packets) == 3 - - first_packet = r.DNSIncoming(packets[0]) - assert not first_packet.truncated - assert first_packet.valid is True - - second_packet = r.DNSIncoming(packets[1]) - assert not second_packet.truncated - assert second_packet.valid is True - - third_packet = r.DNSIncoming(packets[2]) - assert not third_packet.truncated - assert third_packet.valid is True - - -# 4003 15.973052 192.168.107.68 224.0.0.251 MDNS 76 Standard query 0xffc4 PTR _raop._tcp.local, "QM" question -def test_qm_packet_parser(): - """Test we can parse a query packet with the QM bit.""" - qm_packet = ( - b'\xff\xc4\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x05_raop\x04_tcp\x05local\x00\x00\x0c\x00\x01' - ) - parsed = DNSIncoming(qm_packet) - assert parsed.questions[0].unicast is False - assert ",QM," in str(parsed.questions[0]) - - -# 389951 1450.577370 192.168.107.111 224.0.0.251 MDNS 115 Standard query 0x0000 PTR _companion-link._tcp.local, "QU" question OPT -def test_qu_packet_parser(): - """Test we can parse a query packet with the QU bit.""" - qu_packet = b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x0f_companion-link\x04_tcp\x05local\x00\x00\x0c\x80\x01\x00\x00)\x05\xa0\x00\x00\x11\x94\x00\x12\x00\x04\x00\x0e\x00dz{\x8a6\x9czF\x84,\xcaQ\xff' - parsed = DNSIncoming(qu_packet) - assert parsed.questions[0].unicast is True - assert ",QU," in str(parsed.questions[0]) - - def test_dns_record_hashablity_does_not_consider_ttl(): """Test DNSRecord are hashable.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..8b4ebd04 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,722 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._protocol """ + +import copy +import logging +import socket +import struct +import unittest +import unittest.mock +from typing import cast + +import zeroconf as r +from zeroconf import DNSIncoming, const, current_time_millis +from zeroconf import ( + DNSHinfo, + DNSText, +) + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class PacketGeneration(unittest.TestCase): + def test_parse_own_packet_simple(self): + generated = r.DNSOutgoing(0) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_simple_unicast(self): + generated = r.DNSOutgoing(0, False) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_flags(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_question(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) + r.DNSIncoming(generated.packets()[0]) + + def test_parse_own_packet_response(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_adding_empty_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + None, + 0, + ) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + 0, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 1 + assert len(generated.answers) == len(parsed.answers) + + def test_adding_expired_answer(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSService( + "æøå.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + current_time_millis() + 1000000, + ) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.answers) == 0 + assert len(generated.answers) == len(parsed.answers) + + def test_match_question(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + parsed = r.DNSIncoming(generated.packets()[0]) + assert len(generated.questions) == 1 + assert len(generated.questions) == len(parsed.questions) + assert question == parsed.questions[0] + + def test_suppress_answer(self): + query_generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + query_generated.add_question(question) + answer1 = r.DNSService( + "testname1.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + staleanswer2 = r.DNSService( + "testname2.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL / 2, + 0, + 0, + 80, + "foo.local.", + ) + answer2 = r.DNSService( + "testname2.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ) + query_generated.add_answer_at_time(answer1, 0) + query_generated.add_answer_at_time(staleanswer2, 0) + query = r.DNSIncoming(query_generated.packets()[0]) + + # Should be suppressed + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer(query, answer1) + assert len(response.answers) == 0 + + # Should not be suppressed, TTL in query is too short + response.add_answer(query, answer2) + assert len(response.answers) == 1 + + # Should not be suppressed, name is different + tmp = copy.copy(answer1) + tmp.key = "testname3.local." + tmp.name = "testname3.local." + response.add_answer(query, tmp) + assert len(response.answers) == 2 + + # Should not be suppressed, type is different + tmp = copy.copy(answer1) + tmp.type = const._TYPE_A + response.add_answer(query, tmp) + assert len(response.answers) == 3 + + # Should not be suppressed, class is different + tmp = copy.copy(answer1) + tmp.class_ = const._CLASS_NONE + response.add_answer(query, tmp) + assert len(response.answers) == 4 + + # ::TODO:: could add additional tests for DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService + + def test_dns_hinfo(self): + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) + parsed = r.DNSIncoming(generated.packets()[0]) + answer = cast(r.DNSHinfo, parsed.answers[0]) + assert answer.cpu == u'cpu' + assert answer.os == u'os' + + generated = r.DNSOutgoing(0) + generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) + self.assertRaises(r.NamePartTooLongException, generated.packets) + + def test_many_questions(self): + """Test many questions get seperated into multiple packets.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + questions = [] + for i in range(100): + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 100 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) < const._MAX_MSG_TYPICAL + assert len(packets[1]) < const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 85 + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 15 + + def test_many_questions_with_many_known_answers(self): + """Test many questions and known answers get seperated into multiple packets.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + questions = [] + for _ in range(30): + question = r.DNSQuestion(f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + known_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_answer_at_time(known_answer, now) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 30 + assert len(parsed1.answers) == 88 + assert parsed1.truncated + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 101 + assert parsed2.truncated + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert len(parsed3.answers) == 11 + assert not parsed3.truncated + + def test_massive_probe_packet_split(self): + """Test probe with many authorative answers.""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + questions = [] + for _ in range(30): + question = r.DNSQuestion( + f"_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN | const._CLASS_UNIQUE + ) + generated.add_question(question) + questions.append(question) + assert len(generated.questions) == 30 + now = current_time_millis() + for _ in range(200): + authorative_answer = r.DNSPointer( + "myservice{i}_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + '123.local.', + ) + generated.add_authorative_answer(authorative_answer) + packets = generated.packets() + assert len(packets) == 3 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + assert len(packets[2]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert parsed1.questions[0].unicast is True + assert len(parsed1.questions) == 30 + assert parsed1.num_authorities == 88 + assert parsed1.truncated + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert parsed2.num_authorities == 101 + assert parsed2.truncated + parsed3 = r.DNSIncoming(packets[2]) + assert len(parsed3.questions) == 0 + assert parsed3.num_authorities == 11 + assert not parsed3.truncated + + def test_only_one_answer_can_by_large(self): + """Test that only the first answer in each packet can be large. + + https://datatracker.ietf.org/doc/html/rfc6762#section-17 + """ + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + query = r.DNSIncoming(r.DNSOutgoing(const._FLAGS_QR_QUERY).packets()[0]) + for i in range(3): + generated.add_answer( + query, + r.DNSText( + "zoom._hap._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 1200, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==' * 100, + ), + ) + generated.add_answer( + query, + r.DNSService( + "testname1.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + "foo.local.", + ), + ) + assert len(generated.answers) == 4 + + packets = generated.packets() + assert len(packets) == 4 + assert len(packets[0]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[0]) > const._MAX_MSG_TYPICAL + + assert len(packets[1]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[1]) > const._MAX_MSG_TYPICAL + + assert len(packets[2]) <= const._MAX_MSG_ABSOLUTE + assert len(packets[2]) > const._MAX_MSG_TYPICAL + + assert len(packets[3]) <= const._MAX_MSG_TYPICAL + + for packet in packets: + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + def test_questions_do_not_end_up_every_packet(self): + """Test that questions are not sent again when multiple packets are needed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + Sometimes a Multicast DNS querier will already have too many answers + to fit in the Known-Answer Section of its query packets.... It MUST + immediately follow the packet with another query packet containing no + questions and as many more Known-Answer records as will fit. + """ + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + for i in range(35): + question = r.DNSQuestion(f"testname{i}.local.", const._TYPE_SRV, const._CLASS_IN) + generated.add_question(question) + answer = r.DNSService( + f"testname{i}.local.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + 80, + f"foo{i}.local.", + ) + generated.add_answer_at_time(answer, 0) + + assert len(generated.questions) == 35 + assert len(generated.answers) == 35 + + packets = generated.packets() + assert len(packets) == 2 + assert len(packets[0]) <= const._MAX_MSG_TYPICAL + assert len(packets[1]) <= const._MAX_MSG_TYPICAL + + parsed1 = r.DNSIncoming(packets[0]) + assert len(parsed1.questions) == 35 + assert len(parsed1.answers) == 33 + + parsed2 = r.DNSIncoming(packets[1]) + assert len(parsed2.questions) == 0 + assert len(parsed2.answers) == 2 + + +class PacketForm(unittest.TestCase): + def test_transaction_id(self): + """ID must be zero in a DNS-SD packet""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + bytes = generated.packets()[0] + id = bytes[0] << 8 | bytes[1] + assert id == 0 + + def test_setting_id(self): + """Test setting id in the constructor""" + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY, id_=4444) + assert generated.id == 4444 + + def test_query_header_bits(self): + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + bytes = generated.packets()[0] + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x0 + + def test_response_header_bits(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + bytes = generated.packets()[0] + flags = bytes[2] << 8 | bytes[3] + assert flags == 0x8000 + + def test_numbers(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + bytes = generated.packets()[0] + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 0 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 + + def test_numbers_questions(self): + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + question = r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN) + for i in range(10): + generated.add_question(question) + bytes = generated.packets()[0] + (num_questions, num_answers, num_authorities, num_additionals) = struct.unpack('!4H', bytes[4:12]) + assert num_questions == 10 + assert num_answers == 0 + assert num_authorities == 0 + assert num_additionals == 0 + + +class TestDnsIncoming(unittest.TestCase): + def test_incoming_exception_handling(self): + generated = r.DNSOutgoing(0) + packet = generated.packets()[0] + packet = packet[:8] + b'deadbeef' + packet[8:] + parsed = r.DNSIncoming(packet) + parsed = r.DNSIncoming(packet) + assert parsed.valid is False + + def test_incoming_unknown_type(self): + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + generated.add_additional_answer(answer) + packet = generated.packets()[0] + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + assert parsed.is_query() != parsed.is_response() + + def test_incoming_circular_reference(self): + assert not r.DNSIncoming( + bytes.fromhex( + '01005e0000fb542a1bf0577608004500006897934000ff11d81bc0a86a31e00000fb' + '14e914e90054f9b2000084000000000100000000095f7365727669636573075f646e' + '732d7364045f756470056c6f63616c00000c0001000011940018105f73706f746966' + '792d636f6e6e656374045f746370c023' + ) + ).valid + + def test_incoming_ipv6(self): + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com + packed = socket.inet_pton(socket.AF_INET6, addr) + generated = r.DNSOutgoing(0) + answer = r.DNSAddress('domain', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) + generated.add_additional_answer(answer) + packet = generated.packets()[0] + parsed = r.DNSIncoming(packet) + record = parsed.answers[0] + assert isinstance(record, r.DNSAddress) + assert record.address == packed + + +def test_dns_compression_rollback_for_corruption(): + """Verify rolling back does not lead to dns compression corruption.""" + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + address = socket.inet_pton(socket.AF_INET, "192.168.208.5") + + additionals = [ + { + "name": "HASS Bridge ZJWH FF5137._hap._tcp.local.", + "address": address, + "port": 51832, + "text": b"\x13md=HASS Bridge" + b" ZJWH\x06pv=1.0\x14id=01:6B:30:FF:51:37\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=L0m/aQ==", + }, + { + "name": "HASS Bridge 3K9A C2582A._hap._tcp.local.", + "address": address, + "port": 51834, + "text": b"\x13md=HASS Bridge" + b" 3K9A\x06pv=1.0\x14id=E2:AA:5B:C2:58:2A\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b2CnzQ==", + }, + { + "name": "Master Bed TV CEDB27._hap._tcp.local.", + "address": address, + "port": 51830, + "text": b"\x10md=Master Bed" + b" TV\x06pv=1.0\x14id=9E:B7:44:CE:DB:27\x05c#=18\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=CVj1kw==", + }, + { + "name": "Living Room TV 921B77._hap._tcp.local.", + "address": address, + "port": 51833, + "text": b"\x11md=Living Room" + b" TV\x06pv=1.0\x14id=11:61:E7:92:1B:77\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=qU77SQ==", + }, + { + "name": "HASS Bridge ZC8X FF413D._hap._tcp.local.", + "address": address, + "port": 51829, + "text": b"\x13md=HASS Bridge" + b" ZC8X\x06pv=1.0\x14id=96:14:45:FF:41:3D\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=b0QZlg==", + }, + { + "name": "HASS Bridge WLTF 4BE61F._hap._tcp.local.", + "address": address, + "port": 51837, + "text": b"\x13md=HASS Bridge" + b" WLTF\x06pv=1.0\x14id=E0:E7:98:4B:E6:1F\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=ahAISA==", + }, + { + "name": "FrontdoorCamera 8941D1._hap._tcp.local.", + "address": address, + "port": 54898, + "text": b"\x12md=FrontdoorCamera\x06pv=1.0\x14id=9F:B7:DC:89:41:D1\x04c#=2\x04" + b"s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=0+MXmA==", + }, + { + "name": "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + "address": address, + "port": 51836, + "text": b"\x13md=HASS Bridge" + b" W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=6fLM5A==", + }, + { + "name": "HASS Bridge Y9OO EFF0A7._hap._tcp.local.", + "address": address, + "port": 51838, + "text": b"\x13md=HASS Bridge" + b" Y9OO\x06pv=1.0\x14id=D3:FE:98:EF:F0:A7\x04c#=2\x04s#=1\x04ff=0\x04" + b"ci=2\x04sf=0\x0bsh=u3bdfw==", + }, + { + "name": "Snooze Room TV 6B89B0._hap._tcp.local.", + "address": address, + "port": 51835, + "text": b"\x11md=Snooze Room" + b" TV\x06pv=1.0\x14id=5F:D5:70:6B:89:B0\x05c#=17\x04s#=1\x04ff=0\x05" + b"ci=31\x04sf=0\x0bsh=xNTqsg==", + }, + { + "name": "AlexanderHomeAssistant 74651D._hap._tcp.local.", + "address": address, + "port": 54811, + "text": b"\x19md=AlexanderHomeAssistant\x06pv=1.0\x14id=59:8A:0B:74:65:1D\x05" + b"c#=14\x04s#=1\x04ff=0\x04ci=2\x04sf=0\x0bsh=ccZLPA==", + }, + { + "name": "HASS Bridge OS95 39C053._hap._tcp.local.", + "address": address, + "port": 51831, + "text": b"\x13md=HASS Bridge" + b" OS95\x06pv=1.0\x14id=7E:8C:E6:39:C0:53\x05c#=12\x04s#=1\x04ff=0\x04ci=2" + b"\x04sf=0\x0bsh=Xfe5LQ==", + }, + ] + + out.add_answer_at_time( + DNSText( + "HASS Bridge W9DN 5B5CC5._hap._tcp.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for record in additionals: + out.add_additional_answer( + r.DNSService( + record["name"], # type: ignore + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + 0, + 0, + record["port"], # type: ignore + record["name"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSText( + record["name"], # type: ignore + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + record["text"], # type: ignore + ) + ) + out.add_additional_answer( + r.DNSAddress( + record["name"], # type: ignore + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_HOST_TTL, + record["address"], # type: ignore + ) + ) + + for packet in out.packets(): + # Verify we can process the packets we created to + # ensure there is no corruption with the dns compression + incoming = r.DNSIncoming(packet) + assert incoming.valid is True + + +def test_tc_bit_in_query_packet(): + """Verify the TC bit is set when known answers exceed the packet size.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert first_packet.truncated + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert second_packet.truncated + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert not third_packet.truncated + assert third_packet.valid is True + + +def test_tc_bit_not_set_in_answer_packet(): + """Verify the TC bit is not set when there are no questions and answers exceed the packet size.""" + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + packets = out.packets() + assert len(packets) == 3 + + first_packet = r.DNSIncoming(packets[0]) + assert not first_packet.truncated + assert first_packet.valid is True + + second_packet = r.DNSIncoming(packets[1]) + assert not second_packet.truncated + assert second_packet.valid is True + + third_packet = r.DNSIncoming(packets[2]) + assert not third_packet.truncated + assert third_packet.valid is True + + +# 4003 15.973052 192.168.107.68 224.0.0.251 MDNS 76 Standard query 0xffc4 PTR _raop._tcp.local, "QM" question +def test_qm_packet_parser(): + """Test we can parse a query packet with the QM bit.""" + qm_packet = ( + b'\xff\xc4\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x05_raop\x04_tcp\x05local\x00\x00\x0c\x00\x01' + ) + parsed = DNSIncoming(qm_packet) + assert parsed.questions[0].unicast is False + assert ",QM," in str(parsed.questions[0]) + + +# 389951 1450.577370 192.168.107.111 224.0.0.251 MDNS 115 Standard query 0x0000 PTR _companion-link._tcp.local, "QU" question OPT +def test_qu_packet_parser(): + """Test we can parse a query packet with the QU bit.""" + qu_packet = b'\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x01\x0f_companion-link\x04_tcp\x05local\x00\x00\x0c\x80\x01\x00\x00)\x05\xa0\x00\x00\x11\x94\x00\x12\x00\x04\x00\x0e\x00dz{\x8a6\x9czF\x84,\xcaQ\xff' + parsed = DNSIncoming(qu_packet) + assert parsed.questions[0].unicast is True + assert ",QU," in str(parsed.questions[0]) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 02d2afa1..ab2b0993 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -28,8 +28,6 @@ DNSAddress, DNSEntry, DNSHinfo, - DNSIncoming, - DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, @@ -46,6 +44,7 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) +from ._protocol import DNSIncoming, DNSOutgoing # noqa # import needed for backwards compat from ._services import ( # noqa # import needed for backwards compat instance_name_from_service_info, Signal, @@ -81,6 +80,7 @@ __all__ = [ "__version__", + "DNSOutgoing", "Zeroconf", "ServiceInfo", "ServiceBrowser", diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 6aef35a3..f4fb647a 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -32,10 +32,11 @@ from typing import Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache -from ._dns import DNSIncoming, DNSOutgoing, DNSQuestion +from ._dns import DNSQuestion from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log +from ._protocol import DNSIncoming, DNSOutgoing from ._services import ( RecordUpdateListener, ServiceBrowser, diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index d6c12a71..9a6ef73d 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -20,39 +20,21 @@ USA """ -import enum import socket -import struct -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union, cast -from ._exceptions import AbstractMethodException, IncomingDecodeError, NamePartTooLongException -from ._logger import QuietLogger, log +from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address -from ._utils.struct import int2byte from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CLASSES, _CLASS_MASK, _CLASS_UNIQUE, - _DNS_PACKET_HEADER_LEN, _EXPIRE_FULL_TIME_PERCENT, _EXPIRE_STALE_TIME_PERCENT, - _FLAGS_QR_MASK, - _FLAGS_QR_QUERY, - _FLAGS_QR_RESPONSE, - _FLAGS_TC, - _MAX_MSG_ABSOLUTE, - _MAX_MSG_TYPICAL, _RECENT_TIME_PERCENT, _TYPES, - _TYPE_A, - _TYPE_AAAA, _TYPE_ANY, - _TYPE_CNAME, - _TYPE_HINFO, - _TYPE_PTR, - _TYPE_SRV, - _TYPE_TXT, ) _LEN_BYTE = 1 @@ -64,7 +46,7 @@ if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 - from ._cache import DNSCache # pylint: disable=cyclic-import + from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import class DNSEntry: @@ -409,602 +391,6 @@ def __repr__(self) -> str: return self.to_string("%s:%s" % (self.server, self.port)) -class DNSMessage: - """A base class for DNS messages.""" - - def __init__(self, flags: int) -> None: - """Construct a DNS message.""" - self.flags = flags - - def is_query(self) -> bool: - """Returns true if this is a query.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY - - def is_response(self) -> bool: - """Returns true if this is a response.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE - - @property - def truncated(self) -> bool: - """Returns true if this is a truncated.""" - return (self.flags & _FLAGS_TC) == _FLAGS_TC - - -class DNSIncoming(DNSMessage, QuietLogger): - - """Object representation of an incoming DNS packet""" - - def __init__(self, data: bytes) -> None: - """Constructor from string holding bytes of packet""" - super().__init__(0) - self.offset = 0 - self.data = data - self.questions = [] # type: List[DNSQuestion] - self.answers = [] # type: List[DNSRecord] - self.id = 0 - self.num_questions = 0 - self.num_answers = 0 - self.num_authorities = 0 - self.num_additionals = 0 - self.valid = False - - try: - self.read_header() - self.read_questions() - self.read_others() - self.valid = True - - except (IndexError, struct.error, IncomingDecodeError): - self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'id=%s' % self.id, - 'flags=%s' % self.flags, - 'truncated=%s' % self.truncated, - 'n_q=%s' % self.num_questions, - 'n_ans=%s' % self.num_answers, - 'n_auth=%s' % self.num_authorities, - 'n_add=%s' % self.num_additionals, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - ] - ) - - def unpack(self, format_: bytes) -> tuple: - length = struct.calcsize(format_) - info = struct.unpack(format_, self.data[self.offset : self.offset + length]) - self.offset += length - return info - - def read_header(self) -> None: - """Reads header portion of packet""" - ( - self.id, - self.flags, - self.num_questions, - self.num_answers, - self.num_authorities, - self.num_additionals, - ) = self.unpack(b'!6H') - - def read_questions(self) -> None: - """Reads questions section of packet""" - for _ in range(self.num_questions): - name = self.read_name() - type_, class_ = self.unpack(b'!HH') - - question = DNSQuestion(name, type_, class_) - self.questions.append(question) - - # def read_int(self): - # """Reads an integer from the packet""" - # return self.unpack(b'!I')[0] - - def read_character_string(self) -> bytes: - """Reads a character string from the packet""" - length = self.data[self.offset] - self.offset += 1 - return self.read_string(length) - - def read_string(self, length: int) -> bytes: - """Reads a string of a given length from the packet""" - info = self.data[self.offset : self.offset + length] - self.offset += length - return info - - def read_unsigned_short(self) -> int: - """Reads an unsigned short from the packet""" - return cast(int, self.unpack(b'!H')[0]) - - def read_others(self) -> None: - """Reads the answers, authorities and additionals section of the - packet""" - n = self.num_answers + self.num_authorities + self.num_additionals - for _ in range(n): - domain = self.read_name() - type_, class_, ttl, length = self.unpack(b'!HHiH') - rec = None # type: Optional[DNSRecord] - if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) - elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) - elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) - elif type_ == _TYPE_SRV: - rec = DNSService( - domain, - type_, - class_, - ttl, - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_name(), - ) - elif type_ == _TYPE_HINFO: - rec = DNSHinfo( - domain, - type_, - class_, - ttl, - self.read_character_string().decode('utf-8'), - self.read_character_string().decode('utf-8'), - ) - elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) - else: - # Try to ignore types we don't know about - # Skip the payload for the resource record so the next - # records can be parsed correctly - self.offset += length - - if rec is not None: - self.answers.append(rec) - - def read_utf(self, offset: int, length: int) -> str: - """Reads a UTF-8 string of a given length from the packet""" - return str(self.data[offset : offset + length], 'utf-8', 'replace') - - def read_name(self) -> str: - """Reads a domain name from the packet""" - result = '' - off = self.offset - next_ = -1 - first = off - - while True: - length = self.data[off] - off += 1 - if length == 0: - break - t = length & 0xC0 - if t == 0x00: - result += self.read_utf(off, length) + '.' - off += length - elif t == 0xC0: - if next_ < 0: - next_ = off + 1 - off = ((length & 0x3F) << 8) | self.data[off] - if off >= first: - raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) - first = off - else: - raise IncomingDecodeError("Bad domain name at %s" % (off,)) - - if next_ >= 0: - self.offset = next_ - else: - self.offset = off - - return result - - -class DNSOutgoing(DNSMessage): - - """Object representation of an outgoing packet""" - - def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: - super().__init__(flags) - self.finished = False - self.id = id_ - self.multicast = multicast - self.packets_data: List[bytes] = [] - - # these 3 are per-packet -- see also _reset_for_next_packet() - self.names: Dict[str, int] = {} - self.data: List[bytes] = [] - self.size: int = _DNS_PACKET_HEADER_LEN - self.allow_long: bool = True - - self.state = self.State.init - - self.questions: List[DNSQuestion] = [] - self.answers: List[Tuple[DNSRecord, float]] = [] - self.authorities: List[DNSPointer] = [] - self.additionals: List[DNSRecord] = [] - - def _reset_for_next_packet(self) -> None: - self.names = {} - self.data = [] - self.size = _DNS_PACKET_HEADER_LEN - self.allow_long = True - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'multicast=%s' % self.multicast, - 'flags=%s' % self.flags, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - 'authorities=%s' % self.authorities, - 'additionals=%s' % self.additionals, - ] - ) - - class State(enum.Enum): - init = 0 - finished = 1 - - def add_question(self, record: DNSQuestion) -> None: - """Adds a question""" - self.questions.append(record) - - def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: - """Adds an answer""" - if not record.suppressed_by(inp): - self.add_answer_at_time(record, 0) - - def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: - """Adds an answer if it does not expire by a certain time""" - if record is not None and (now == 0 or not record.is_expired(now)): - self.answers.append((record, now)) - - def add_authorative_answer(self, record: DNSPointer) -> None: - """Adds an authoritative answer""" - self.authorities.append(record) - - def add_additional_answer(self, record: DNSRecord) -> None: - """Adds an additional answer - - From: RFC 6763, DNS-Based Service Discovery, February 2013 - - 12. DNS Additional Record Generation - - DNS has an efficiency feature whereby a DNS server may place - additional records in the additional section of the DNS message. - These additional records are records that the client did not - explicitly request, but the server has reasonable grounds to expect - that the client might request them shortly, so including them can - save the client from having to issue additional queries. - - This section recommends which additional records SHOULD be generated - to improve network efficiency, for both Unicast and Multicast DNS-SD - responses. - - 12.1. PTR Records - - When including a DNS-SD Service Instance Enumeration or Selective - Instance Enumeration (subtype) PTR record in a response packet, the - server/responder SHOULD include the following additional records: - - o The SRV record(s) named in the PTR rdata. - o The TXT record(s) named in the PTR rdata. - o All address records (type "A" and "AAAA") named in the SRV rdata. - - 12.2. SRV Records - - When including an SRV record in a response packet, the - server/responder SHOULD include the following additional records: - - o All address records (type "A" and "AAAA") named in the SRV rdata. - - """ - self.additionals.append(record) - - def add_question_or_one_cache( - self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int - ) -> None: - """Add a question if it is not already cached.""" - cached_entry = cache.get_by_details(name, type_, class_) - if not cached_entry: - self.add_question(DNSQuestion(name, type_, class_)) - else: - self.add_answer_at_time(cached_entry, now) - - def add_question_or_all_cache( - self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int - ) -> None: - """Add a question if it is not already cached. - This is currently only used for IPv6 addresses. - """ - cached_entries = cache.get_all_by_details(name, type_, class_) - if not cached_entries: - self.add_question(DNSQuestion(name, type_, class_)) - return - for cached_entry in cached_entries: - self.add_answer_at_time(cached_entry, now) - - def _pack(self, format_: Union[bytes, str], value: Any) -> None: - self.data.append(struct.pack(format_, value)) - self.size += struct.calcsize(format_) - - def _write_byte(self, value: int) -> None: - """Writes a single byte to the packet""" - self._pack(b'!c', int2byte(value)) - - def _insert_short_at_start(self, value: int) -> None: - """Inserts an unsigned short at the start of the packet""" - self.data.insert(0, struct.pack(b'!H', value)) - - def _replace_short(self, index: int, value: int) -> None: - """Replaces an unsigned short in a certain position in the packet""" - self.data[index] = struct.pack(b'!H', value) - - def write_short(self, value: int) -> None: - """Writes an unsigned short to the packet""" - self._pack(b'!H', value) - - def _write_int(self, value: Union[float, int]) -> None: - """Writes an unsigned integer to the packet""" - self._pack(b'!I', int(value)) - - def write_string(self, value: bytes) -> None: - """Writes a string to the packet""" - assert isinstance(value, bytes) - self.data.append(value) - self.size += len(value) - - def _write_utf(self, s: str) -> None: - """Writes a UTF-8 string of a given length to the packet""" - utfstr = s.encode('utf-8') - length = len(utfstr) - if length > 64: - raise NamePartTooLongException - self._write_byte(length) - self.write_string(utfstr) - - def write_character_string(self, value: bytes) -> None: - assert isinstance(value, bytes) - length = len(value) - if length > 256: - raise NamePartTooLongException - self._write_byte(length) - self.write_string(value) - - def write_name(self, name: str) -> None: - """ - Write names to packet - - 18.14. Name Compression - - When generating Multicast DNS messages, implementations SHOULD use - name compression wherever possible to compress the names of resource - records, by replacing some or all of the resource record name with a - compact two-byte reference to an appearance of that data somewhere - earlier in the message [RFC1035]. - """ - - # split name into each label - parts = name.split('.') - if not parts[-1]: - parts.pop() - - # construct each suffix - name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] - - # look for an existing name or suffix - for count, sub_name in enumerate(name_suffices): - if sub_name in self.names: - break - else: - count = len(name_suffices) - - # note the new names we are saving into the packet - name_length = len(name.encode('utf-8')) - for suffix in name_suffices[:count]: - self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 - - # write the new names out. - for part in parts[:count]: - self._write_utf(part) - - # if we wrote part of the name, create a pointer to the rest - if count != len(name_suffices): - # Found substring in packet, create pointer - index = self.names[name_suffices[count]] - self._write_byte((index >> 8) | 0xC0) - self._write_byte(index & 0xFF) - else: - # this is the end of a name - self._write_byte(0) - - def _write_question(self, question: DNSQuestion) -> bool: - """Writes a question to the packet""" - start_data_length, start_size = len(self.data), self.size - self.write_name(question.name) - self.write_short(question.type) - self._write_record_class(question) - return self._check_data_limit_or_rollback(start_data_length, start_size) - - def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: - """Write out the record class including the unique/unicast (QU) bit.""" - if record.unique and self.multicast: - self.write_short(record.class_ | _CLASS_UNIQUE) - else: - self.write_short(record.class_) - - def _write_ttl(self, record: DNSRecord, now: float) -> None: - """Write out the record ttl.""" - self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now)) - - def _write_record(self, record: DNSRecord, now: float) -> bool: - """Writes a record (answer, authoritative answer, additional) to - the packet. Returns True on success, or False if we did not - because the packet because the record does not fit.""" - start_data_length, start_size = len(self.data), self.size - self.write_name(record.name) - self.write_short(record.type) - self._write_record_class(record) - self._write_ttl(record, now) - index = len(self.data) - self.write_short(0) # Will get replaced with the actual size - record.write(self) - # Adjust size for the short we will write before this record - length = sum((len(d) for d in self.data[index + 1 :])) - # Here we replace the 0 length short we wrote - # before with the actual length - self._replace_short(index, length) - return self._check_data_limit_or_rollback(start_data_length, start_size) - - def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: - """Check data limit, if we go over, then rollback and return False.""" - len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL - self.allow_long = False - - if self.size <= len_limit: - return True - - log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) - del self.data[start_data_length:] - self.size = start_size - - rollback_names = [name for name, idx in self.names.items() if idx >= start_size] - for name in rollback_names: - del self.names[name] - return False - - def _write_questions_from_offset(self, questions_offset: int) -> int: - questions_written = 0 - for question in self.questions[questions_offset:]: - if not self._write_question(question): - break - questions_written += 1 - return questions_written - - def _write_answers_from_offset(self, answer_offset: int) -> int: - answers_written = 0 - for answer, time_ in self.answers[answer_offset:]: - if not self._write_record(answer, time_): - break - answers_written += 1 - return answers_written - - def _write_authorities_from_offset(self, authority_offset: int) -> int: - authorities_written = 0 - for authority in self.authorities[authority_offset:]: - if not self._write_record(authority, 0): - break - authorities_written += 1 - return authorities_written - - def _write_additionals_from_offset(self, additional_offset: int) -> int: - additionals_written = 0 - for additional in self.additionals[additional_offset:]: - if not self._write_record(additional, 0): - break - additionals_written += 1 - return additionals_written - - def _has_more_to_add( - self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int - ) -> bool: - """Check if all questions, answers, authority, and additionals have been written to the packet.""" - return ( - questions_offset < len(self.questions) - or answer_offset < len(self.answers) - or authority_offset < len(self.authorities) - or additional_offset < len(self.additionals) - ) - - def packets(self) -> List[bytes]: - """Returns a list of bytestrings containing the packets' bytes - - No further parts should be added to the packet once this - is done. The packets are each restricted to _MAX_MSG_TYPICAL - or less in length, except for the case of a single answer which - will be written out to a single oversized packet no more than - _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP - fragmentation potentially).""" - - if self.state == self.State.finished: - return self.packets_data - - questions_offset = 0 - answer_offset = 0 - authority_offset = 0 - additional_offset = 0 - # we have to at least write out the question - first_time = True - - while first_time or self._has_more_to_add( - questions_offset, answer_offset, authority_offset, additional_offset - ): - first_time = False - log.debug( - "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", - questions_offset, - answer_offset, - authority_offset, - additional_offset, - ) - log.debug( - "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", - len(self.questions), - len(self.answers), - len(self.authorities), - len(self.additionals), - ) - - questions_written = self._write_questions_from_offset(questions_offset) - answers_written = self._write_answers_from_offset(answer_offset) - authorities_written = self._write_authorities_from_offset(authority_offset) - additionals_written = self._write_additionals_from_offset(additional_offset) - - self._insert_short_at_start(additionals_written) - self._insert_short_at_start(authorities_written) - self._insert_short_at_start(answers_written) - self._insert_short_at_start(questions_written) - - questions_offset += questions_written - answer_offset += answers_written - authority_offset += authorities_written - additional_offset += additionals_written - log.debug( - "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", - questions_offset, - answer_offset, - authority_offset, - additional_offset, - ) - - if self.is_query() and self._has_more_to_add( - questions_offset, answer_offset, authority_offset, additional_offset - ): - # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 - log.debug("Setting TC flag") - self._insert_short_at_start(self.flags | _FLAGS_TC) - else: - self._insert_short_at_start(self.flags) - - if self.multicast: - self._insert_short_at_start(0) - else: - self._insert_short_at_start(self.id) - - self.packets_data.append(b''.join(self.data)) - self._reset_for_next_packet() - - if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( - len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) - ) > 0: - log.warning("packets() made no progress adding records; returning") - break - self.state = self.State.finished - return self.packets_data - - class DNSRRSet: """A set of dns records independent of the ttl.""" diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 042ab2a1..ad6f54fb 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -24,8 +24,9 @@ from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union from ._cache import DNSCache -from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord +from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log +from ._protocol import DNSIncoming, DNSOutgoing from ._services import RecordUpdateListener from ._services.registry import ServiceRegistry from ._utils.net import IPVersion diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py new file mode 100644 index 00000000..64c65b96 --- /dev/null +++ b/zeroconf/_protocol.py @@ -0,0 +1,644 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import enum +import struct +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast + +from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from ._exceptions import IncomingDecodeError, NamePartTooLongException +from ._logger import QuietLogger, log +from ._utils.struct import int2byte +from .const import ( + _CLASS_UNIQUE, + _DNS_PACKET_HEADER_LEN, + _FLAGS_QR_MASK, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _FLAGS_TC, + _MAX_MSG_ABSOLUTE, + _MAX_MSG_TYPICAL, + _TYPE_A, + _TYPE_AAAA, + _TYPE_CNAME, + _TYPE_HINFO, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from ._cache import DNSCache # pylint: disable=cyclic-import + + +class DNSMessage: + """A base class for DNS messages.""" + + def __init__(self, flags: int) -> None: + """Construct a DNS message.""" + self.flags = flags + + def is_query(self) -> bool: + """Returns true if this is a query.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self) -> bool: + """Returns true if this is a response.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + @property + def truncated(self) -> bool: + """Returns true if this is a truncated.""" + return (self.flags & _FLAGS_TC) == _FLAGS_TC + + +class DNSIncoming(DNSMessage, QuietLogger): + + """Object representation of an incoming DNS packet""" + + def __init__(self, data: bytes) -> None: + """Constructor from string holding bytes of packet""" + super().__init__(0) + self.offset = 0 + self.data = data + self.questions: List[DNSQuestion] = [] + self.answers: List[DNSRecord] = [] + self.id = 0 + self.num_questions = 0 + self.num_answers = 0 + self.num_authorities = 0 + self.num_additionals = 0 + self.valid = False + + try: + self.read_header() + self.read_questions() + self.read_others() + self.valid = True + + except (IndexError, struct.error, IncomingDecodeError): + self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'id=%s' % self.id, + 'flags=%s' % self.flags, + 'truncated=%s' % self.truncated, + 'n_q=%s' % self.num_questions, + 'n_ans=%s' % self.num_answers, + 'n_auth=%s' % self.num_authorities, + 'n_add=%s' % self.num_additionals, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + ] + ) + + def unpack(self, format_: bytes) -> tuple: + length = struct.calcsize(format_) + info = struct.unpack(format_, self.data[self.offset : self.offset + length]) + self.offset += length + return info + + def read_header(self) -> None: + """Reads header portion of packet""" + ( + self.id, + self.flags, + self.num_questions, + self.num_answers, + self.num_authorities, + self.num_additionals, + ) = self.unpack(b'!6H') + + def read_questions(self) -> None: + """Reads questions section of packet""" + for _ in range(self.num_questions): + name = self.read_name() + type_, class_ = self.unpack(b'!HH') + + question = DNSQuestion(name, type_, class_) + self.questions.append(question) + + def read_character_string(self) -> bytes: + """Reads a character string from the packet""" + length = self.data[self.offset] + self.offset += 1 + return self.read_string(length) + + def read_string(self, length: int) -> bytes: + """Reads a string of a given length from the packet""" + info = self.data[self.offset : self.offset + length] + self.offset += length + return info + + def read_unsigned_short(self) -> int: + """Reads an unsigned short from the packet""" + return cast(int, self.unpack(b'!H')[0]) + + def read_others(self) -> None: + """Reads the answers, authorities and additionals section of the + packet""" + n = self.num_answers + self.num_authorities + self.num_additionals + for _ in range(n): + domain = self.read_name() + type_, class_, ttl, length = self.unpack(b'!HHiH') + rec: Optional[DNSRecord] = None + if type_ == _TYPE_A: + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) + elif type_ in (_TYPE_CNAME, _TYPE_PTR): + rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) + elif type_ == _TYPE_TXT: + rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) + elif type_ == _TYPE_SRV: + rec = DNSService( + domain, + type_, + class_, + ttl, + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_name(), + ) + elif type_ == _TYPE_HINFO: + rec = DNSHinfo( + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), + ) + elif type_ == _TYPE_AAAA: + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) + else: + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + + if rec is not None: + self.answers.append(rec) + + def read_utf(self, offset: int, length: int) -> str: + """Reads a UTF-8 string of a given length from the packet""" + return str(self.data[offset : offset + length], 'utf-8', 'replace') + + def read_name(self) -> str: + """Reads a domain name from the packet""" + result = '' + off = self.offset + next_ = -1 + first = off + + while True: + length = self.data[off] + off += 1 + if length == 0: + break + t = length & 0xC0 + if t == 0x00: + result += self.read_utf(off, length) + '.' + off += length + elif t == 0xC0: + if next_ < 0: + next_ = off + 1 + off = ((length & 0x3F) << 8) | self.data[off] + if off >= first: + raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) + first = off + else: + raise IncomingDecodeError("Bad domain name at %s" % (off,)) + + if next_ >= 0: + self.offset = next_ + else: + self.offset = off + + return result + + +class DNSOutgoing(DNSMessage): + + """Object representation of an outgoing packet""" + + def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: + super().__init__(flags) + self.finished = False + self.id = id_ + self.multicast = multicast + self.packets_data: List[bytes] = [] + + # these 3 are per-packet -- see also _reset_for_next_packet() + self.names: Dict[str, int] = {} + self.data: List[bytes] = [] + self.size: int = _DNS_PACKET_HEADER_LEN + self.allow_long: bool = True + + self.state = self.State.init + + self.questions: List[DNSQuestion] = [] + self.answers: List[Tuple[DNSRecord, float]] = [] + self.authorities: List[DNSPointer] = [] + self.additionals: List[DNSRecord] = [] + + def _reset_for_next_packet(self) -> None: + self.names = {} + self.data = [] + self.size = _DNS_PACKET_HEADER_LEN + self.allow_long = True + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'multicast=%s' % self.multicast, + 'flags=%s' % self.flags, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + 'authorities=%s' % self.authorities, + 'additionals=%s' % self.additionals, + ] + ) + + class State(enum.Enum): + init = 0 + finished = 1 + + def add_question(self, record: DNSQuestion) -> None: + """Adds a question""" + self.questions.append(record) + + def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None: + """Adds an answer""" + if not record.suppressed_by(inp): + self.add_answer_at_time(record, 0) + + def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None: + """Adds an answer if it does not expire by a certain time""" + if record is not None and (now == 0 or not record.is_expired(now)): + self.answers.append((record, now)) + + def add_authorative_answer(self, record: DNSPointer) -> None: + """Adds an authoritative answer""" + self.authorities.append(record) + + def add_additional_answer(self, record: DNSRecord) -> None: + """Adds an additional answer + + From: RFC 6763, DNS-Based Service Discovery, February 2013 + + 12. DNS Additional Record Generation + + DNS has an efficiency feature whereby a DNS server may place + additional records in the additional section of the DNS message. + These additional records are records that the client did not + explicitly request, but the server has reasonable grounds to expect + that the client might request them shortly, so including them can + save the client from having to issue additional queries. + + This section recommends which additional records SHOULD be generated + to improve network efficiency, for both Unicast and Multicast DNS-SD + responses. + + 12.1. PTR Records + + When including a DNS-SD Service Instance Enumeration or Selective + Instance Enumeration (subtype) PTR record in a response packet, the + server/responder SHOULD include the following additional records: + + o The SRV record(s) named in the PTR rdata. + o The TXT record(s) named in the PTR rdata. + o All address records (type "A" and "AAAA") named in the SRV rdata. + + 12.2. SRV Records + + When including an SRV record in a response packet, the + server/responder SHOULD include the following additional records: + + o All address records (type "A" and "AAAA") named in the SRV rdata. + + """ + self.additionals.append(record) + + def add_question_or_one_cache( + self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached.""" + cached_entry = cache.get_by_details(name, type_, class_) + if not cached_entry: + self.add_question(DNSQuestion(name, type_, class_)) + else: + self.add_answer_at_time(cached_entry, now) + + def add_question_or_all_cache( + self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + ) -> None: + """Add a question if it is not already cached. + This is currently only used for IPv6 addresses. + """ + cached_entries = cache.get_all_by_details(name, type_, class_) + if not cached_entries: + self.add_question(DNSQuestion(name, type_, class_)) + return + for cached_entry in cached_entries: + self.add_answer_at_time(cached_entry, now) + + def _pack(self, format_: Union[bytes, str], value: Any) -> None: + self.data.append(struct.pack(format_, value)) + self.size += struct.calcsize(format_) + + def _write_byte(self, value: int) -> None: + """Writes a single byte to the packet""" + self._pack(b'!c', int2byte(value)) + + def _insert_short_at_start(self, value: int) -> None: + """Inserts an unsigned short at the start of the packet""" + self.data.insert(0, struct.pack(b'!H', value)) + + def _replace_short(self, index: int, value: int) -> None: + """Replaces an unsigned short in a certain position in the packet""" + self.data[index] = struct.pack(b'!H', value) + + def write_short(self, value: int) -> None: + """Writes an unsigned short to the packet""" + self._pack(b'!H', value) + + def _write_int(self, value: Union[float, int]) -> None: + """Writes an unsigned integer to the packet""" + self._pack(b'!I', int(value)) + + def write_string(self, value: bytes) -> None: + """Writes a string to the packet""" + assert isinstance(value, bytes) + self.data.append(value) + self.size += len(value) + + def _write_utf(self, s: str) -> None: + """Writes a UTF-8 string of a given length to the packet""" + utfstr = s.encode('utf-8') + length = len(utfstr) + if length > 64: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(utfstr) + + def write_character_string(self, value: bytes) -> None: + assert isinstance(value, bytes) + length = len(value) + if length > 256: + raise NamePartTooLongException + self._write_byte(length) + self.write_string(value) + + def write_name(self, name: str) -> None: + """ + Write names to packet + + 18.14. Name Compression + + When generating Multicast DNS messages, implementations SHOULD use + name compression wherever possible to compress the names of resource + records, by replacing some or all of the resource record name with a + compact two-byte reference to an appearance of that data somewhere + earlier in the message [RFC1035]. + """ + + # split name into each label + parts = name.split('.') + if not parts[-1]: + parts.pop() + + # construct each suffix + name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] + + # look for an existing name or suffix + for count, sub_name in enumerate(name_suffices): + if sub_name in self.names: + break + else: + count = len(name_suffices) + + # note the new names we are saving into the packet + name_length = len(name.encode('utf-8')) + for suffix in name_suffices[:count]: + self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 + + # write the new names out. + for part in parts[:count]: + self._write_utf(part) + + # if we wrote part of the name, create a pointer to the rest + if count != len(name_suffices): + # Found substring in packet, create pointer + index = self.names[name_suffices[count]] + self._write_byte((index >> 8) | 0xC0) + self._write_byte(index & 0xFF) + else: + # this is the end of a name + self._write_byte(0) + + def _write_question(self, question: DNSQuestion) -> bool: + """Writes a question to the packet""" + start_data_length, start_size = len(self.data), self.size + self.write_name(question.name) + self.write_short(question.type) + self._write_record_class(question) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _write_record_class(self, record: Union[DNSQuestion, DNSRecord]) -> None: + """Write out the record class including the unique/unicast (QU) bit.""" + if record.unique and self.multicast: + self.write_short(record.class_ | _CLASS_UNIQUE) + else: + self.write_short(record.class_) + + def _write_ttl(self, record: DNSRecord, now: float) -> None: + """Write out the record ttl.""" + self._write_int(record.ttl if now == 0 else record.get_remaining_ttl(now)) + + def _write_record(self, record: DNSRecord, now: float) -> bool: + """Writes a record (answer, authoritative answer, additional) to + the packet. Returns True on success, or False if we did not + because the packet because the record does not fit.""" + start_data_length, start_size = len(self.data), self.size + self.write_name(record.name) + self.write_short(record.type) + self._write_record_class(record) + self._write_ttl(record, now) + index = len(self.data) + self.write_short(0) # Will get replaced with the actual size + record.write(self) + # Adjust size for the short we will write before this record + length = sum((len(d) for d in self.data[index + 1 :])) + # Here we replace the 0 length short we wrote + # before with the actual length + self._replace_short(index, length) + return self._check_data_limit_or_rollback(start_data_length, start_size) + + def _check_data_limit_or_rollback(self, start_data_length: int, start_size: int) -> bool: + """Check data limit, if we go over, then rollback and return False.""" + len_limit = _MAX_MSG_ABSOLUTE if self.allow_long else _MAX_MSG_TYPICAL + self.allow_long = False + + if self.size <= len_limit: + return True + + log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit) + del self.data[start_data_length:] + self.size = start_size + + rollback_names = [name for name, idx in self.names.items() if idx >= start_size] + for name in rollback_names: + del self.names[name] + return False + + def _write_questions_from_offset(self, questions_offset: int) -> int: + questions_written = 0 + for question in self.questions[questions_offset:]: + if not self._write_question(question): + break + questions_written += 1 + return questions_written + + def _write_answers_from_offset(self, answer_offset: int) -> int: + answers_written = 0 + for answer, time_ in self.answers[answer_offset:]: + if not self._write_record(answer, time_): + break + answers_written += 1 + return answers_written + + def _write_authorities_from_offset(self, authority_offset: int) -> int: + authorities_written = 0 + for authority in self.authorities[authority_offset:]: + if not self._write_record(authority, 0): + break + authorities_written += 1 + return authorities_written + + def _write_additionals_from_offset(self, additional_offset: int) -> int: + additionals_written = 0 + for additional in self.additionals[additional_offset:]: + if not self._write_record(additional, 0): + break + additionals_written += 1 + return additionals_written + + def _has_more_to_add( + self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int + ) -> bool: + """Check if all questions, answers, authority, and additionals have been written to the packet.""" + return ( + questions_offset < len(self.questions) + or answer_offset < len(self.answers) + or authority_offset < len(self.authorities) + or additional_offset < len(self.additionals) + ) + + def packets(self) -> List[bytes]: + """Returns a list of bytestrings containing the packets' bytes + + No further parts should be added to the packet once this + is done. The packets are each restricted to _MAX_MSG_TYPICAL + or less in length, except for the case of a single answer which + will be written out to a single oversized packet no more than + _MAX_MSG_ABSOLUTE in length (and hence will be subject to IP + fragmentation potentially).""" + + if self.state == self.State.finished: + return self.packets_data + + questions_offset = 0 + answer_offset = 0 + authority_offset = 0 + additional_offset = 0 + # we have to at least write out the question + first_time = True + + while first_time or self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + first_time = False + log.debug( + "offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + log.debug( + "lengths = questions=%d, answers=%d, authorities=%d, additionals=%d", + len(self.questions), + len(self.answers), + len(self.authorities), + len(self.additionals), + ) + + questions_written = self._write_questions_from_offset(questions_offset) + answers_written = self._write_answers_from_offset(answer_offset) + authorities_written = self._write_authorities_from_offset(authority_offset) + additionals_written = self._write_additionals_from_offset(additional_offset) + + self._insert_short_at_start(additionals_written) + self._insert_short_at_start(authorities_written) + self._insert_short_at_start(answers_written) + self._insert_short_at_start(questions_written) + + questions_offset += questions_written + answer_offset += answers_written + authority_offset += authorities_written + additional_offset += additionals_written + log.debug( + "now offsets = questions=%d, answers=%d, authorities=%d, additionals=%d", + questions_offset, + answer_offset, + authority_offset, + additional_offset, + ) + + if self.is_query() and self._has_more_to_add( + questions_offset, answer_offset, authority_offset, additional_offset + ): + # https://datatracker.ietf.org/doc/html/rfc6762#section-7.2 + log.debug("Setting TC flag") + self._insert_short_at_start(self.flags | _FLAGS_TC) + else: + self._insert_short_at_start(self.flags) + + if self.multicast: + self._insert_short_at_start(0) + else: + self._insert_short_at_start(self.id) + + self.packets_data.append(b''.join(self.data)) + self._reset_for_next_packet() + + if (questions_written + answers_written + authorities_written + additionals_written) == 0 and ( + len(self.questions) + len(self.answers) + len(self.authorities) + len(self.additionals) + ) > 0: + log.warning("packets() made no progress adding records; returning") + break + self.state = self.State.finished + return self.packets_data diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 818b3bb6..70558c82 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -27,8 +27,9 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from .._dns import DNSAddress, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException +from .._protocol import DNSOutgoing from .._utils.name import service_type_name from .._utils.net import ( IPVersion, From dc0c6137742edf97626c972e5c9191dfbffaecdc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 10:42:38 -1000 Subject: [PATCH 358/608] Fix thread safety in _ServiceBrowser.update_records_complete (#708) --- zeroconf/_services/__init__.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 70558c82..b6481eb3 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -369,8 +369,15 @@ def update_records_complete(self) -> None: At this point the cache will have the new records. """ - self._handlers_to_call.update(self._pending_handlers) - self._pending_handlers.clear() + # Cannot use .update here since PyPy can fail with + # RuntimeError: dictionary changed size during iteration + # for threaded ServiceBrowsers + while self._pending_handlers: + try: + (name_type, state_change) = self._pending_handlers.popitem(False) + except KeyError: + return + self._handlers_to_call[name_type] = state_change def cancel(self) -> None: """Cancel the browser.""" From f3eeecd84413b510b9b8e05e2d1f6ad99d0dc37d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 10:42:50 -1000 Subject: [PATCH 359/608] Set stale unique records to expire 1s in the future instead of instant removal (#706) - Fixes #475 - https://tools.ietf.org/html/rfc6762#section-10.2 Queriers receiving a Multicast DNS response with a TTL of zero SHOULD NOT immediately delete the record from the cache, but instead record a TTL of 1 and then delete the record one second later. In the case of multiple Multicast DNS responders on the network described in Section 6.6 above, if one of the responders shuts down and incorrectly sends goodbye packets for its records, it gives the other cooperating responders one second to send out their own response to "rescue" the records before they expire and are deleted. --- tests/test_core.py | 4 ++-- tests/test_handlers.py | 4 ++-- zeroconf/_dns.py | 4 ++-- zeroconf/_handlers.py | 3 ++- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 4d001208..f8577a6b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -17,7 +17,7 @@ from typing import cast import zeroconf as r -from zeroconf import _core, const, ServiceBrowser, Zeroconf +from zeroconf import _core, const, ServiceBrowser, Zeroconf, current_time_millis from . import has_working_ipv6, _clear_cache, _inject_response @@ -241,7 +241,7 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS # service removed _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) dns_text = zeroconf.cache.get_by_details(service_name, const._TYPE_TXT, const._CLASS_IN) - assert dns_text is None + assert dns_text.is_expired(current_time_millis() + 1000) finally: zeroconf.close() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 11c077b4..50256361 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -735,7 +735,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # Add the A record to the cache with 50% ttl remaining a_record = info.dns_addresses()[0] - a_record._set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl) + a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl) assert not a_record.is_recent(current_time_millis()) zc.cache.add(a_record) @@ -776,7 +776,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # Remove the 100% PTR record and add a 50% PTR record zc.cache.remove(ptr_record) - ptr_record._set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl) + ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl) assert not ptr_record.is_recent(current_time_millis()) zc.cache.add(ptr_record) # With QU should respond to only multicast since the has less diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 9a6ef73d..b5b2bb79 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -194,9 +194,9 @@ def is_recent(self, now: float) -> bool: def reset_ttl(self, other: 'DNSRecord') -> None: """Sets this record's TTL and created time to that of another record.""" - self._set_created_ttl(other.created, other.ttl) + self.set_created_ttl(other.created, other.ttl) - def _set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: + def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: """Set the created and ttl of a record.""" self.created = created self.ttl = ttl diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index ad6f54fb..f8e9c6df 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -298,7 +298,8 @@ def updates_from_response(self, msg: DNSIncoming) -> None: if entry == record: updated = False if record.created - entry.created > 1000 and entry not in msg.answers: - removes.append(entry) + # Expire in 1s + entry.set_created_ttl(now, 1) expired = record.is_expired(now) maybe_entry = self.cache.get(record) From c366c8cc45f565c4066fc72b481c6a960bac1cb9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 11:13:53 -1000 Subject: [PATCH 360/608] Synchronize created time for incoming and outgoing queries (#709) --- tests/test_protocol.py | 26 ++++++++++++++++++++ zeroconf/_core.py | 11 +++++---- zeroconf/_dns.py | 33 ++++++++++++++++--------- zeroconf/_handlers.py | 44 ++++++++++++++++++++++------------ zeroconf/_protocol.py | 13 ++++++---- zeroconf/_services/__init__.py | 17 +++++++++---- 6 files changed, 104 insertions(+), 40 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8b4ebd04..ebdb7110 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -720,3 +720,29 @@ def test_qu_packet_parser(): parsed = DNSIncoming(qu_packet) assert parsed.questions[0].unicast is True assert ",QU," in str(parsed.questions[0]) + + +def test_records_same_packet_share_fate(): + """Test records in the same packet all have the same created time.""" + out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) + type_ = "_hap._tcp.local." + out.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)) + + for i in range(30): + out.add_answer_at_time( + DNSText( + ("HASS Bridge W9DN %s._hap._tcp.local." % i), + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b'\x13md=HASS Bridge W9DN\x06pv=1.0\x14id=11:8E:DB:5B:5C:C5\x05c#=12\x04s#=1' + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + 0, + ) + + for packet in out.packets(): + dnsin = DNSIncoming(packet) + first_time = dnsin.answers[0].created + for answer in dnsin.answers: + assert answer.created == first_time diff --git a/zeroconf/_core.py b/zeroconf/_core.py index f4fb647a..12f3c5f4 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -465,19 +465,20 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d # # _CLASS_UNIQUE is the "QU" bit out.add_question(DNSQuestion(info.type, _TYPE_PTR, _CLASS_IN | _CLASS_UNIQUE)) - out.add_authorative_answer(info.dns_pointer()) + out.add_authorative_answer(info.dns_pointer(created=current_time_millis())) return out def _add_broadcast_answer( # pylint: disable=no-self-use self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] ) -> None: """Add answers to broadcast a service.""" + now = current_time_millis() other_ttl = info.other_ttl if override_ttl is None else override_ttl host_ttl = info.host_ttl if override_ttl is None else override_ttl - out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl), 0) - out.add_answer_at_time(info.dns_service(override_ttl=host_ttl), 0) - out.add_answer_at_time(info.dns_text(override_ttl=other_ttl), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl): + out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0) + out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0) + out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0) + for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index b5b2bb79..c6d0108e 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -139,10 +139,12 @@ class DNSRecord(DNSEntry): """A DNS record - like a DNS entry, but has a TTL""" # TODO: Switch to just int ttl - def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None + ) -> None: super().__init__(name, type_, class_) self.ttl = ttl - self.created = current_time_millis() + self.created = created or current_time_millis() self._expiration_time: Optional[float] = None self._stale_time: Optional[float] = None self._recent_time: Optional[float] = None @@ -218,8 +220,10 @@ class DNSAddress(DNSRecord): """A DNS address record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.address = address def write(self, out: 'DNSOutgoing') -> None: @@ -252,8 +256,10 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.cpu = cpu self.os = os @@ -284,8 +290,10 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None: - super().__init__(name, type_, class_, ttl) + def __init__( + self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None + ) -> None: + super().__init__(name, type_, class_, ttl, created) self.alias = alias @property @@ -319,9 +327,11 @@ class DNSText(DNSRecord): """A DNS text record""" - def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None: + def __init__( + self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None + ) -> None: assert isinstance(text, (bytes, type(None))) - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.text = text def write(self, out: 'DNSOutgoing') -> None: @@ -357,8 +367,9 @@ def __init__( weight: int, port: int, server: str, + created: Optional[float] = None, ) -> None: - super().__init__(name, type_, class_, ttl) + super().__init__(name, type_, class_, ttl, created) self.priority = priority self.weight = weight self.port = port diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f8e9c6df..1d6cac4c 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -162,7 +162,7 @@ def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: self.cache = cache def _add_service_type_enumeration_query_answers( - self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Provide an answer to a service type enumeration query. @@ -170,47 +170,60 @@ def _add_service_type_enumeration_query_answers( """ for stype in self.registry.get_types(): dns_pointer = DNSPointer( - _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype + _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now ) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set() def _add_pointer_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" for service in self.registry.get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. - dns_pointer = service.dns_pointer() + dns_pointer = service.dns_pointer(created=now) if not known_answers.suppresses(dns_pointer): answer_set[dns_pointer] = set( - [service.dns_service(), service.dns_text(), *service.dns_addresses()] + [ + service.dns_service(created=now), + service.dns_text(created=now), + *service.dns_addresses(created=now), + ] ) def _add_address_answers( - self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, type_: int + self, + name: str, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, + type_: int, ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): - for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]): + for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_], created=now): if not known_answers.suppresses(dns_address): answer_set[dns_address] = set() def _answer_question( - self, question: DNSQuestion, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet + self, + question: DNSQuestion, + answer_set: _AnswerWithAdditionalsType, + known_answers: DNSRRSet, + now: float, ) -> None: if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: - self._add_service_type_enumeration_query_answers(answer_set, known_answers) + self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) return type_ = question.type if type_ in (_TYPE_PTR, _TYPE_ANY): - self._add_pointer_answers(question.name, answer_set, known_answers) + self._add_pointer_answers(question.name, answer_set, known_answers, now) if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY): - self._add_address_answers(question.name, answer_set, known_answers, type_) + self._add_address_answers(question.name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): service = self.registry.get_info_name(question.name) # type: ignore @@ -218,11 +231,11 @@ def _answer_question( if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.2. - dns_service = service.dns_service() + dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = set(service.dns_addresses()) + answer_set[dns_service] = set(service.dns_addresses(created=now)) if type_ in (_TYPE_TXT, _TYPE_ANY): - dns_text = service.dns_text() + dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text): answer_set[dns_text] = set() @@ -233,10 +246,11 @@ def response( # pylint: disable=unused-argument ucast_source = port != _MDNS_PORT known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) + now = current_time_millis() for question in itertools.chain(*[msg.questions for msg in msgs]): answer_set: _AnswerWithAdditionalsType = {} - self._answer_question(question, answer_set, known_answers) + self._answer_question(question, answer_set, known_answers, now) if not ucast_source and question.unicast: query_res.add_qu_question_response(answer_set) else: diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 64c65b96..80ca7b88 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -24,10 +24,12 @@ import struct from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast + from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ._exceptions import IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log from ._utils.struct import int2byte +from ._utils.time import current_time_millis from .const import ( _CLASS_UNIQUE, _DNS_PACKET_HEADER_LEN, @@ -90,6 +92,7 @@ def __init__(self, data: bytes) -> None: self.num_authorities = 0 self.num_additionals = 0 self.valid = False + self.now = current_time_millis() try: self.read_header() @@ -166,11 +169,11 @@ def read_others(self) -> None: type_, class_, ttl, length = self.unpack(b'!HHiH') rec: Optional[DNSRecord] = None if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), self.now) elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name()) + rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length)) + rec = DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) elif type_ == _TYPE_SRV: rec = DNSService( domain, @@ -181,6 +184,7 @@ def read_others(self) -> None: self.read_unsigned_short(), self.read_unsigned_short(), self.read_name(), + self.now, ) elif type_ == _TYPE_HINFO: rec = DNSHinfo( @@ -190,9 +194,10 @@ def read_others(self) -> None: ttl, self.read_character_string().decode('utf-8'), self.read_character_string().decode('utf-8'), + self.now, ) elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16)) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16), self.now) else: # Try to ignore types we don't know about # Skip the payload for the resource record so the next diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index b6481eb3..c6efcc8e 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -369,7 +369,7 @@ def update_records_complete(self) -> None: At this point the cache will have the new records. """ - # Cannot use .update here since PyPy can fail with + # Cannot use .update here since can fail with # RuntimeError: dictionary changed size during iteration # for threaded ServiceBrowsers while self._pending_handlers: @@ -722,7 +722,10 @@ def _process_record(self, record: DNSRecord, now: float) -> None: self._set_text(record.text) def dns_addresses( - self, override_ttl: Optional[int] = None, version: IPVersion = IPVersion.All + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, ) -> List[DNSAddress]: """Return matching DNSAddress from ServiceInfo.""" return [ @@ -732,11 +735,12 @@ def dns_addresses( _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, address, + created, ) for address in self.addresses_by_version(version) ] - def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: """Return DNSPointer from ServiceInfo.""" return DNSPointer( self.type, @@ -744,9 +748,10 @@ def dns_pointer(self, override_ttl: Optional[int] = None) -> DNSPointer: _CLASS_IN, override_ttl if override_ttl is not None else self.other_ttl, self.name, + created, ) - def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: """Return DNSService from ServiceInfo.""" return DNSService( self.name, @@ -757,9 +762,10 @@ def dns_service(self, override_ttl: Optional[int] = None) -> DNSService: self.weight, cast(int, self.port), self.server, + created, ) - def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: """Return DNSText from ServiceInfo.""" return DNSText( self.name, @@ -767,6 +773,7 @@ def dns_text(self, override_ttl: Optional[int] = None) -> DNSText: _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.other_ttl, self.text, + created, ) def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: From aeb1b23defa2d5956a6f19acca4ce410d6a04cc9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 11:15:03 -1000 Subject: [PATCH 361/608] Add setter for DNSQuestion to easily make a QU question (#710) Closes #703 --- tests/test_handlers.py | 18 +++++++++--------- zeroconf/_dns.py | 5 +++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 50256361..c62f6a11 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -415,7 +415,7 @@ def _validate_complete_response(query, out): # With QU should respond to only unicast when the answer has been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) @@ -429,7 +429,7 @@ def _validate_complete_response(query, out): # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) unicast_out, multicast_out = zc.query_handler.response( @@ -441,7 +441,7 @@ def _validate_complete_response(query, out): # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) query.add_authorative_answer(info2.dns_pointer()) @@ -455,7 +455,7 @@ def _validate_complete_response(query, out): # With the cache repopulated; should respond to only unicast when the answer has been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) unicast_out, multicast_out = zc.query_handler.response( @@ -743,7 +743,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # even if the additional has not been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) @@ -763,7 +763,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # even if the additional has not been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) @@ -783,7 +783,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # than 75% of its ttl remaining query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) @@ -803,12 +803,12 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # than 75% of its ttl remaining query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) question = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) - question.unique = True # Set the QU bit + question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index c6d0108e..5b7fe70f 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -124,6 +124,11 @@ def unicast(self) -> bool: """ return self.unique + @unicast.setter + def unicast(self, value: bool) -> None: + """Sets the QU bit (not QM).""" + self.unique = value + def __repr__(self) -> str: """String representation""" return "%s[question,%s,%s,%s]" % ( From 6b923deb3682088d0fe9182377b5603d0ade1e1a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:03:45 -1000 Subject: [PATCH 362/608] Cleanup typing in zeroconf._services.registry (#712) --- zeroconf/_services/registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 8ec34120..20584b3a 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -40,9 +40,9 @@ def __init__( self, ) -> None: """Create the ServiceRegistry class.""" - self._services = {} # type: Dict[str, ServiceInfo] - self.types = {} # type: Dict[str, List] - self.servers = {} # type: Dict[str, List] + self._services: Dict[str, ServiceInfo] = {} + self.types: Dict[str, List] = {} + self.servers: Dict[str, List] = {} self._lock = threading.Lock() # add and remove services thread safe def add(self, info: ServiceInfo) -> None: From a42512ca6a6a4c15f37ab623a96deb2aa06dd053 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:04:29 -1000 Subject: [PATCH 363/608] Cleanup typing in zeroconf._services (#711) --- zeroconf/_services/__init__.py | 63 +++++++++++++++++----------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index c6efcc8e..c71aaed4 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -99,7 +99,7 @@ def update_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: class Signal: def __init__(self) -> None: - self._handlers = [] # type: List[Callable[..., None]] + self._handlers: List[Callable[..., None]] = [] def fire(self, **kwargs: Any) -> None: for h in list(self._handlers): @@ -233,7 +233,7 @@ def __init__( ) -> None: """Creates a browser for a specific type""" assert handlers or listener, 'You need to specify at least one handler' - self.types = set(type_ if isinstance(type_, list) else [type_]) # type: Set[str] + self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) for check_type_ in self.types: # Will generate BadTypeInNameException on a bad name service_type_name(check_type_, strict=False) @@ -245,9 +245,8 @@ def __init__( current_time = current_time_millis() self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} - self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] - self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange] - + self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() + self._handlers_to_call: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._service_state_changed = Signal() self.done = False @@ -552,13 +551,13 @@ def __init__( self.port = port self.weight = weight self.priority = priority - if server: - self.server = server - else: - self.server = name + self.server = server if server else name self.server_key = self.server.lower() - self._properties = {} # type: Dict - self._set_properties(properties) + self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + if isinstance(properties, bytes): + self._set_text(properties) + else: + self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl @@ -618,33 +617,33 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: for addr in result ] - def _set_properties(self, properties: Union[bytes, Dict]) -> None: + def _set_properties(self, properties: Dict) -> None: """Sets properties and text of this info from a dictionary""" - if isinstance(properties, dict): - self._properties = properties - list_ = [] - result = b'' - for key, value in properties.items(): - if isinstance(key, str): - key = key.encode('utf-8') - - record = key - if value is not None: - if not isinstance(value, bytes): - value = str(value).encode('utf-8') - record += b'=' + value - list_.append(record) - for item in list_: - result = b''.join((result, int2byte(len(item)), item)) - self.text = result - else: - self.text = properties + self._properties = properties + list_ = [] + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, int2byte(len(item)), item)) + self.text = result def _set_text(self, text: bytes) -> None: """Sets properties and text given a text field""" self.text = text - result = {} # type: Dict end = len(text) + if end == 0: + self._properties = {} + return + result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} index = 0 strs = [] while index < end: From a50b3eeda5f275c31b36cdc1c8312f61599e72bf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:04:38 -1000 Subject: [PATCH 364/608] Cleanup typing in zeroconf._utils.net (#713) --- zeroconf/_utils/net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 8d1b60bc..19500e0f 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -135,7 +135,7 @@ def normalize_interface_choice( :param ip_address: IP version to use (ignored if `choice` is a list). :returns: List of IP addresses (for IPv4) and indexes (for IPv6). """ - result = [] # type: List[Union[str, Tuple[Tuple[str, int, int], int]]] + result: List[Union[str, Tuple[Tuple[str, int, int], int]]] = [] if choice is InterfaceChoice.Default: if ip_version != IPVersion.V4Only: # IPv6 multicast uses interface 0 to mean the default From 3fcdcfd9a3efc56a34f0334ffb8706613e07d19d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:15:22 -1000 Subject: [PATCH 365/608] Cleanup typing in zeroconf._logger (#715) --- zeroconf/_logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py index b7cb745a..3577bb05 100644 --- a/zeroconf/_logger.py +++ b/zeroconf/_logger.py @@ -32,7 +32,7 @@ class QuietLogger: - _seen_logs = {} # type: Dict[str, Union[int, tuple]] + _seen_logs: Dict[str, Union[int, tuple]] = {} @classmethod def log_exception_warning(cls, *logger_data: Any) -> None: From 0f2f4e207cb5007112ba09e87a332b1a46cd1577 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:15:32 -1000 Subject: [PATCH 366/608] Update README (#716) --- README.rst | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index eb730cdf..b6793690 100644 --- a/README.rst +++ b/README.rst @@ -37,7 +37,7 @@ Compared to some other Zeroconf/Bonjour/Avahi Python packages, python-zeroconf: * isn't tied to Bonjour or Avahi * doesn't use D-Bus -* doesn't force you to use particular event loop or Twisted +* doesn't force you to use particular event loop or Twisted (asyncio is used under the hood but not required) * is pip-installable * has PyPI distribution @@ -59,8 +59,14 @@ This project's versions follow the following pattern: MAJOR.MINOR.PATCH. Status ------ -There are some people using this package. I don't actively use it and as such -any help I can offer with regard to any issues is very limited. +This project is actively maintained. + +Traffic Reduction +----------------- + +Before version 0.32, most traffic reduction techniques described in https://datatracker.ietf.org/doc/html/rfc6762#section-7 +where not implemented which could lead to excessive network traffic. It is highly recommended that version 0.32 or later +is used if this is a concern. IPv6 support ------------ From 818364008e911757fca24e41a4eb36e0eef49bfa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:34:30 -1000 Subject: [PATCH 367/608] Cleanup typing in zero._core and document ignores (#714) --- zeroconf/_core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 12f3c5f4..553b349d 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -190,7 +190,7 @@ class AsyncListener(asyncio.Protocol, QuietLogger): def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc - self.data = None # type: Optional[bytes] + self.data: Optional[bytes] = None self.transport: Optional[asyncio.DatagramTransport] = None super().__init__() @@ -199,8 +199,10 @@ def datagram_received( ) -> None: assert self.transport is not None if len(addrs) == 2: + # https://github.com/python/mypy/issues/1178 addr, port = addrs # type: ignore elif len(addrs) == 4: + # https://github.com/python/mypy/issues/1178 addr, port, _flow, _scope = addrs # type: ignore else: return From 1ab685960bc0e412d36baf6794fde06350998474 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 12:34:38 -1000 Subject: [PATCH 368/608] Update changelog (#717) --- README.rst | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/README.rst b/README.rst index b6793690..5e5c7ea7 100644 --- a/README.rst +++ b/README.rst @@ -231,6 +231,34 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Cleanup typing in zero._core and document ignores (#714) @bdraco + +* Cleanup typing in zeroconf._logger (#715) @bdraco + +* Cleanup typing in zeroconf._utils.net (#713) @bdraco + +* Cleanup typing in zeroconf._services (#711) @bdraco + +* Cleanup typing in zeroconf._services.registry (#712) @bdraco + +* Add setter for DNSQuestion to easily make a QU question (#710) @bdraco + +* Set stale unique records to expire 1s in the future instead of instant removal (#706) @bdraco + + tools.ietf.org/html/rfc6762#section-10.2 + Queriers receiving a Multicast DNS response with a TTL of zero SHOULD + NOT immediately delete the record from the cache, but instead record + a TTL of 1 and then delete the record one second later. In the case + of multiple Multicast DNS responders on the network described in + Section 6.6 above, if one of the responders shuts down and + incorrectly sends goodbye packets for its records, it gives the other + cooperating responders one second to send out their own response to + "rescue" the records before they expire and are deleted. + +* Fix thread safety in _ServiceBrowser.update_records_complete (#708) @bdraco + +* Split DNSOutgoing/DNSIncoming/DNSMessage into zeroconf._protocol (#705) @bdraco + * Abstract DNSOutgoing ttl write into _write_ttl (#695) @bdraco * Rollback data in one call instead of poping one byte at a time in DNS Outgoing (#696) @bdraco From 18ddb8dbeef3edad3bb97131803dfecde4355467 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 17 Jun 2021 13:58:07 -1000 Subject: [PATCH 369/608] Synchronize time for fate sharing (#718) --- zeroconf/_handlers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 1d6cac4c..c42641f4 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -246,14 +246,14 @@ def response( # pylint: disable=unused-argument ucast_source = port != _MDNS_PORT known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) - now = current_time_millis() - - for question in itertools.chain(*[msg.questions for msg in msgs]): - answer_set: _AnswerWithAdditionalsType = {} - self._answer_question(question, answer_set, known_answers, now) - if not ucast_source and question.unicast: - query_res.add_qu_question_response(answer_set) - else: + + for msg in msgs: + for question in msg.questions: + answer_set: _AnswerWithAdditionalsType = {} + self._answer_question(question, answer_set, known_answers, msg.now) + if not ucast_source and question.unicast: + query_res.add_qu_question_response(answer_set) + continue if ucast_source: query_res.add_ucast_question_response(answer_set) # We always multicast as well even if its a unicast @@ -298,7 +298,7 @@ def updates_from_response(self, msg: DNSIncoming) -> None: address_adds: List[DNSAddress] = [] other_adds: List[DNSRecord] = [] removes: List[DNSRecord] = [] - now = current_time_millis() + now = msg.now for record in msg.answers: updated = True From e2d4d98db70b376c53883367b3a24c1d2510c2b5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 09:26:49 -1000 Subject: [PATCH 370/608] Relocate cache tests to tests/test_cache.py (#722) --- tests/test_cache.py | 63 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_dns.py | 37 -------------------------- 2 files changed, 63 insertions(+), 37 deletions(-) create mode 100644 tests/test_cache.py diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 00000000..8580b366 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._cache. """ + +import logging +import unittest +import unittest.mock + +import zeroconf as r +from zeroconf import const + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestDNSCache(unittest.TestCase): + def test_order(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) + cached_record = cache.get(entry) + assert cached_record == record2 + + def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache + + def test_cache_empty_multiple_calls_does_not_throw(self): + record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + assert 'a' in cache.cache + cache.remove(record1) + cache.remove(record2) + # Ensure multiple removes does not throw + cache.remove(record1) + cache.remove(record2) + assert 'a' not in cache.cache diff --git a/tests/test_dns.py b/tests/test_dns.py index 557802e1..19735706 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -159,43 +159,6 @@ def test_dns_record_is_recent(self): assert record.is_recent(now + (8 * 1000)) is False -class TestDNSCache(unittest.TestCase): - def test_order(self): - record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) - cached_record = cache.get(entry) - assert cached_record == record2 - - def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): - record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - assert 'a' not in cache.cache - - def test_cache_empty_multiple_calls_does_not_throw(self): - record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') - record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') - cache = r.DNSCache() - cache.add(record1) - cache.add(record2) - assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - # Ensure multiple removes does not throw - cache.remove(record1) - cache.remove(record2) - assert 'a' not in cache.cache - - def test_dns_record_hashablity_does_not_consider_ttl(): """Test DNSRecord are hashable.""" From 33385948da9123bc9348374edce7502abd898e82 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 11:26:15 -1000 Subject: [PATCH 371/608] Fix ServiceInfo with multiple A records (#725) --- tests/test_services.py | 23 +++++++++++++++++++++++ zeroconf/_services/__init__.py | 12 +++++------- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 147c1225..9b1f205e 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -24,6 +24,7 @@ ServiceInfo, ServiceStateChange, ) +from zeroconf.aio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response @@ -947,6 +948,28 @@ def test_multiple_addresses(): assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_multiple_a_addresses(): + type_ = "_http._tcp.local." + registration_name = "multiarec.%s" % type_ + desc = {'path': '/~paulsm/'} + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + cache = aiozc.zeroconf.cache + host = "multahost.local." + record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a') + record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b') + cache.add(record1) + cache.add(record2) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) + info.load_from_cache(aiozc.zeroconf) + assert set(info.addresses) == set([b'a', b'b']) + await aiozc.async_close() + + def test_backoff(): got_query = Event() diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index c71aaed4..b2005d77 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -777,12 +777,10 @@ def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: """Get the address records from the cache.""" - address_records = [] - cached_a_record = zc.cache.get_by_details(self.server, _TYPE_A, _CLASS_IN) - if cached_a_record: - address_records.append(cached_a_record) - address_records.extend(zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN)) - return address_records + return [ + *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), + *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), + ] def load_from_cache(self, zc: 'Zeroconf') -> bool: """Populate the service info from the cache.""" @@ -844,7 +842,7 @@ def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: out = DNSOutgoing(_FLAGS_QR_QUERY) out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) return out From f91af79c8779ac235598f5584f439c78b3bdcca2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 11:36:54 -1000 Subject: [PATCH 372/608] Rename handlers and internals to make it clear what is threadsafe (#726) - It was too easy to get confused about what was threadsafe and what was not threadsafe which lead to unexpected failures. Rename functions to make it clear what will be run in the event loop and what is expected to be threadsafe --- tests/test_handlers.py | 52 +++++++++++++++++----------------- tests/test_services.py | 4 +-- zeroconf/_core.py | 8 +++--- zeroconf/_handlers.py | 36 +++++++++++++++-------- zeroconf/_services/__init__.py | 44 ++++++++++++++++++++-------- 5 files changed, 88 insertions(+), 56 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index c62f6a11..0645a24b 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -96,7 +96,7 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - multicast_out = zc.query_handler.response( + multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT )[1] _process_outgoing_packet(multicast_out) @@ -134,7 +134,7 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) _process_outgoing_packet( - zc.query_handler.response( + zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT )[1] ) @@ -232,7 +232,7 @@ def test_ptr_optimization(): # Verify we won't respond for 1s with the same multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT ) assert unicast_out is None @@ -244,7 +244,7 @@ def test_ptr_optimization(): # Verify we will now respond query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT ) assert multicast_out.id == query.id @@ -287,7 +287,7 @@ def test_any_query_for_ptr(): question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.response( + _, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert multicast_out.answers[0][0].name == type_ @@ -313,7 +313,7 @@ def test_aaaa_query(): question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.response( + _, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert multicast_out.answers[0][0].address == ipv6_address @@ -342,7 +342,7 @@ def test_unicast_response(): # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", 1234 ) for out in (unicast_out, multicast_out): @@ -419,7 +419,7 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None @@ -432,7 +432,7 @@ def _validate_complete_response(query, out): question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -445,7 +445,7 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) query.add_authorative_answer(info2.dns_pointer()) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) _validate_complete_response(query, unicast_out) @@ -458,7 +458,7 @@ def _validate_complete_response(query, out): question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None @@ -487,7 +487,7 @@ def test_known_answer_supression(): question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -498,7 +498,7 @@ def test_known_answer_supression(): generated.add_question(question) generated.add_answer_at_time(info.dns_pointer(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -510,7 +510,7 @@ def test_known_answer_supression(): question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -522,7 +522,7 @@ def test_known_answer_supression(): for dns_address in info.dns_addresses(): generated.add_answer_at_time(dns_address, now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -533,7 +533,7 @@ def test_known_answer_supression(): question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -544,7 +544,7 @@ def test_known_answer_supression(): generated.add_question(question) generated.add_answer_at_time(info.dns_service(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -556,7 +556,7 @@ def test_known_answer_supression(): question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -567,7 +567,7 @@ def test_known_answer_supression(): generated.add_question(question) generated.add_answer_at_time(info.dns_text(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -620,7 +620,7 @@ def test_multi_packet_known_answer_supression(): generated.add_answer_at_time(info3.dns_pointer(), now) packets = generated.packets() assert len(packets) > 1 - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -661,7 +661,7 @@ def test_known_answer_supression_service_type_enumeration_query(): question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -691,7 +691,7 @@ def test_known_answer_supression_service_type_enumeration_query(): now, ) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT ) assert unicast_out is None @@ -747,7 +747,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None @@ -767,7 +767,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out is None @@ -787,7 +787,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out.answers[0][0] == ptr_record @@ -813,7 +813,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): query.add_question(question) zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache - unicast_out, multicast_out = zc.query_handler.response( + unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT ) assert multicast_out.answers[0][0] == info.dns_pointer() diff --git a/tests/test_services.py b/tests/test_services.py index 9b1f205e..f49535e5 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1242,7 +1242,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming: zc, mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), ) - zc.wait(100) + time.sleep(0.1) assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')] assert zc.get_service_info(type_, registration_name).port == 80 @@ -1252,7 +1252,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming: zc, mock_incoming_msg([info.dns_service()]), ) - zc.wait(100) + time.sleep(0.1) assert callbacks == [ ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'), diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 553b349d..240270a9 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,8 +146,8 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() - self.zc.record_manager.updates(now, list(self.zc.cache.expire(now))) - self.zc.record_manager.updates_complete() + self.zc.record_manager.async_updates(now, list(self.zc.cache.expire(now))) + self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) async def _async_close(self) -> None: @@ -565,7 +565,7 @@ def remove_listener(self, listener: RecordUpdateListener) -> None: def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers are held in the cache, and listeners are notified.""" - self.record_manager.updates_from_response(msg) + self.record_manager.async_updates_from_response(msg) def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: """Deal with incoming query packets. Provides a response if @@ -594,7 +594,7 @@ def _respond_query(self, msg: Optional[DNSIncoming], addr: str, port: int) -> No if msg: packets.append(msg) - unicast_out, multicast_out = self.query_handler.response(packets, addr, port) + unicast_out, multicast_out = self.query_handler.async_response(packets, addr, port) if unicast_out: self.async_send(unicast_out, addr, port) if multicast_out: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index c42641f4..b5279654 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -239,10 +239,14 @@ def _answer_question( if not known_answers.suppresses(dns_text): answer_set[dns_text] = set() - def response( # pylint: disable=unused-argument + def async_response( # pylint: disable=unused-argument self, msgs: List[DNSIncoming], addr: Optional[str], port: int ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: - """Deal with incoming query packets. Provides a response if possible.""" + """Deal with incoming query packets. Provides a response if possible. + + This function must be run in the event loop as it is not + threadsafe. + """ ucast_source = port != _MDNS_PORT known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) @@ -272,28 +276,36 @@ def __init__(self, zeroconf: 'Zeroconf') -> None: self.cache = zeroconf.cache self.listeners: List[RecordUpdateListener] = [] - def updates(self, now: float, rec: List[DNSRecord]) -> None: + def async_updates(self, now: float, rec: List[DNSRecord]) -> None: """Used to notify listeners of new information that has updated a record. This method must be called before the cache is updated. + + This method will be run in the event loop. """ for listener in self.listeners: - listener.update_records(self.zc, now, rec) + listener.async_update_records(self.zc, now, rec) - def updates_complete(self) -> None: + def async_updates_complete(self) -> None: """Used to notify listeners of new information that has updated a record. This method must be called after the cache is updated. + + This method will be run in the event loop. """ for listener in self.listeners: - listener.update_records_complete() + listener.async_update_records_complete() self.zc.notify_all() - def updates_from_response(self, msg: DNSIncoming) -> None: + def async_updates_from_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers - are held in the cache, and listeners are notified.""" + are held in the cache, and listeners are notified. + + This function must be run in the event loop as it is not + threadsafe. + """ updates: List[DNSRecord] = [] address_adds: List[DNSAddress] = [] other_adds: List[DNSRecord] = [] @@ -334,7 +346,7 @@ def updates_from_response(self, msg: DNSIncoming) -> None: if not updates and not address_adds and not other_adds and not removes: return - self.updates(now, updates) + self.async_updates(now, updates) # The cache adds must be processed AFTER we trigger # the updates since we compare existing data # with the new data and updating the cache @@ -355,7 +367,7 @@ def updates_from_response(self, msg: DNSIncoming) -> None: # ServiceInfo could generate an un-needed query # because the data was not yet populated. self.cache.remove_records(removes) - self.updates_complete() + self.async_updates_complete() def add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] @@ -374,8 +386,8 @@ def add_listener( if single_question.answered_by(record) and not record.is_expired(now): records.append(record) if records: - listener.update_records(self.zc, now, records) - listener.update_records_complete() + listener.async_update_records(self.zc, now, records) + listener.async_update_records_complete() self.zc.notify_all() diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index b2005d77..80fdd5f7 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -134,7 +134,7 @@ def update_record( # pylint: disable=no-self-use """ raise RuntimeError("update_record is deprecated and will be removed in a future version.") - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: """Update multiple records in one shot. All records that are received in a single packet are passed @@ -146,14 +146,18 @@ def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) - NotImplementedError in a future version. At this point the cache will not have the new records + + This method will be run in the event loop. """ for record in records: self.update_record(zc, now, record) - def update_records_complete(self) -> None: + def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers. At this point the cache will have the new records. + + This method will be run in the event loop. """ @@ -353,20 +357,24 @@ def _process_record_update(self, now: float, record: DNSRecord) -> None: if type_: self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: """Callback invoked by Zeroconf when new information arrives. Updates information required by browser in the Zeroconf cache. Ensures that there is are no unecessary duplicates in the list. + + This method will be run in the event loop. """ for record in records: self._process_record_update(now, record) - def update_records_complete(self) -> None: + def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers. At this point the cache will have the new records. + + This method will be run in the event loop. """ # Cannot use .update here since can fail with # RuntimeError: dictionary changed size during iteration @@ -677,26 +685,35 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) This method is deprecated and will be removed in a future version. update_records should be implemented instead. + + This method will be run in the event loop. """ if record is not None: - self.update_records(zc, now, [record]) + self._process_records_threadsafe(zc, now, [record]) - def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Updates service information from a DNS record.""" + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method will be run in the event loop. + """ + self._process_records_threadsafe(zc, now, records) + + def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Thread safe record updating.""" update_addresses = False for record in records: if isinstance(record, DNSService): update_addresses = True - self._process_record(record, now) + self._process_record_threadsafe(record, now) # Only update addresses if the DNSService (.server) has changed if not update_addresses: return for record in self._get_address_records_from_cache(zc): - self._process_record(record, now) + self._process_record_threadsafe(record, now) - def _process_record(self, record: DNSRecord, now: float) -> None: + def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: if record.is_expired(now): return @@ -783,7 +800,10 @@ def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: ] def load_from_cache(self, zc: 'Zeroconf') -> bool: - """Populate the service info from the cache.""" + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ now = current_time_millis() record_updates = [] cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) @@ -796,7 +816,7 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: record_updates.append(cached_txt_record) - self.update_records(zc, now, record_updates) + self._process_records_threadsafe(zc, now, record_updates) return self._is_complete @property From 9cc834d501fa5e582adeb4468b02775288e1fa11 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 11:37:30 -1000 Subject: [PATCH 373/608] Update changelog (#727) --- README.rst | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 5e5c7ea7..5e69f696 100644 --- a/README.rst +++ b/README.rst @@ -154,25 +154,31 @@ Changelog Python version eariler then 3.6 were likely broken with zeroconf already, however the version is now explictly checked. -* BREAKING CHANGE: RecordUpdateListener now uses update_records instead of update_record (#419) @bdraco +* BREAKING CHANGE: RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco This allows the listener to receive all the records that have been updated in a single transaction such as a packet or cache expiry. - update_record has been deprecated in favor of update_records + update_record has been deprecated in favor of async_update_records A compatibility shim exists to ensure classes that use RecordUpdateListener as a base class continue to have update_record called, however they should be updated as soon as possible. - A new method update_records_complete is now called on each + A new method async_update_records_complete is now called on each listener when all listeners have completed processing updates and the cache has been updated. This allows ServiceBrowsers to delay calling handlers until they are sure the cache has been updated as its a common pattern to call for ServiceInfo when a ServiceBrowser handler fires. + The async_ prefix was choosen to make it clear that these + functions run in the eventloop and should never do blocking + I/O. Before 0.32+ these functions ran in a select() loop and + should not have been doing any blocking I/O, but it was not + clear to implementors that I/O would block the loop. + * BREAKING CHANGE: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco When manually creating a zeroconf.Engine object, it is no longer started automatically. @@ -231,6 +237,10 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Fix ServiceInfo with multiple A records (#725) @bdraco + +* Synchronize time for fate sharing (#718) @bdraco + * Cleanup typing in zero._core and document ignores (#714) @bdraco * Cleanup typing in zeroconf._logger (#715) @bdraco From ceb79bd7f7bdad434cbe5b4846492cd434ea883b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:00:43 -1000 Subject: [PATCH 374/608] Add tests for the DNSCache class (#728) - There is currently a bug in the implementation where an entry can exist in two places in the cache with different TTLs. Since a known answer cannot be both expired and expired at the same time, this is a bug that needs to be fixed. --- tests/test_cache.py | 102 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/tests/test_cache.py b/tests/test_cache.py index 8580b366..aa6acf6c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -37,6 +37,38 @@ def test_order(self): cached_record = cache.get(entry) assert cached_record == record2 + def test_adding_same_record_to_cache_different_ttls(self): + """We should always get back the last entry we added if there are different TTLs. + + This ensures we only have one source of truth for TTLs as a record cannot + be both expired and not expired. + """ + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + entry = r.DNSEntry(record2) + cached_record = cache.get(entry) + assert cached_record == record2 + + @unittest.skip('This bug in the implementation needs to be fixed.') + def test_adding_same_record_to_cache_different_ttls(self): + """Verify we only get one record back. + + The last record added should replace the previous since two + records with different ttls are __eq__. This ensures we + only have one source of truth for TTLs as a record cannot + be both expired and not expired. + """ + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') + cache = r.DNSCache() + cache.add(record1) + cache.add(record2) + cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) + assert cached_records == [record2] + def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') @@ -61,3 +93,73 @@ def test_cache_empty_multiple_calls_does_not_throw(self): cache.remove(record1) cache.remove(record2) assert 'a' not in cache.cache + + +# These functions have been seen in other projects so +# we try to maintain a stable API for all the threadsafe getters +class TestDNSCacheAPI(unittest.TestCase): + def test_get(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert cache.get(record1) == record1 + assert cache.get(record2) == record2 + + def test_get_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2 + + def test_get_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) + + def test_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert set(cache.entries_with_server('ab')) == set([record1, record2]) + + def test_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert set(cache.entries_with_name('irrelevant')) == set([record1, record2]) + + def test_current_entry_with_name_and_alias(self): + record1 = r.DNSPointer( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'x.irrelevant' + ) + record2 = r.DNSPointer( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant' + ) + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1 + + def test_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.add_records([record1, record2]) + assert cache.names() == ['irrelevant'] From 88aa610274bf79aef6c74998f2bfca8c8de0dccb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:14:48 -1000 Subject: [PATCH 375/608] Fix cache handling of records with different TTLs (#729) - There should only be one unique record in the cache at a time as having multiple unique records will different TTLs in the cache can result in unexpected behavior since some functions returned all matching records and some fetched from the right side of the list to return the newest record. Intead we now store the records in a dict to ensure that the newest record always replaces the same unique record and we never have a source of truth problem determining the TTL of a record from the cache. --- tests/test_cache.py | 27 ++++-------- zeroconf/_cache.py | 103 ++++++++++++++++++++++++++++++-------------- 2 files changed, 79 insertions(+), 51 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index aa6acf6c..19033b5c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -31,8 +31,7 @@ def test_order(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) cached_record = cache.get(entry) assert cached_record == record2 @@ -46,13 +45,11 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) entry = r.DNSEntry(record2) cached_record = cache.get(entry) assert cached_record == record2 - @unittest.skip('This bug in the implementation needs to be fixed.') def test_adding_same_record_to_cache_different_ttls(self): """Verify we only get one record back. @@ -64,8 +61,7 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) assert cached_records == [record2] @@ -73,25 +69,18 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache - def test_cache_empty_multiple_calls_does_not_throw(self): + def test_cache_empty_multiple_calls(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add(record1) - cache.add(record2) + cache.add_records([record1, record2]) assert 'a' in cache.cache - cache.remove(record1) - cache.remove(record2) - # Ensure multiple removes does not throw - cache.remove(record1) - cache.remove(record2) + cache.remove_records([record1, record2]) assert 'a' not in cache.cache diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 135b1884..2e07a7a4 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -27,46 +27,83 @@ from .const import _TYPE_PTR +_DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] + + +def _remove_key(cache: _DNSRecordCacheType, key: str, entry: DNSRecord) -> None: + """Remove a key from a DNSRecord cache + + This function must be run in from event loop. + """ + del cache[key][entry] + if not cache[key]: + del cache[key] + + class DNSCache: """A cache of DNS entries.""" def __init__(self) -> None: - self.cache: Dict[str, List[DNSRecord]] = {} - self.service_cache: Dict[str, List[DNSRecord]] = {} + self.cache: _DNSRecordCacheType = {} + self.service_cache: _DNSRecordCacheType = {} + + # Functions prefixed with are NOT threadsafe and must + # be run in the event loop. def add(self, entry: DNSRecord) -> None: - """Adds an entry""" - # Insert last in list, get will return newest entry - # iteration will result in last update winning - self.cache.setdefault(entry.key, []).append(entry) + """Adds an entry. + + This function must be run in from event loop. + """ + # Previously storage of records was implemented as a list + # instead a dict. Since DNSRecords are now hashable, the implementation + # uses a dict to ensure that adding a new record to the cache + # replaces any existing records that are __eq__ to each other which + # removes the risk that accessing the cache from the wrong + # direction would return the old incorrect entry. + self.cache.setdefault(entry.key, {})[entry] = entry if isinstance(entry, DNSService): - self.service_cache.setdefault(entry.server, []).append(entry) + self.service_cache.setdefault(entry.server, {})[entry] = entry def add_records(self, entries: Iterable[DNSRecord]) -> None: - """Add multiple records.""" + """Add multiple records. + + This function must be run in from event loop. + """ for entry in entries: self.add(entry) def remove(self, entry: DNSRecord) -> None: - """Removes an entry.""" + """Removes an entry. + + This function must be run in from event loop. + """ if isinstance(entry, DNSService): - DNSCache.remove_key(self.service_cache, entry.server, entry) - DNSCache.remove_key(self.cache, entry.key, entry) + _remove_key(self.service_cache, entry.server, entry) + _remove_key(self.cache, entry.key, entry) def remove_records(self, entries: Iterable[DNSRecord]) -> None: - """Remove multiple records.""" + """Remove multiple records. + + This function must be run in from event loop. + """ for entry in entries: self.remove(entry) - @staticmethod - def remove_key(cache: dict, key: str, entry: DNSRecord) -> None: - """Forgiving remove of a cache key.""" - try: - cache[key].remove(entry) - if not cache[key]: - del cache[key] - except (KeyError, ValueError): - pass + def expire(self, now: float) -> Iterable[DNSRecord]: + """Purge expired entries from the cache. + + This function must be run in from event loop. + """ + for name in self.names(): + for record in self.entries_with_name(name): + if record.is_expired(now): + self.remove(record) + yield record + + # The below functions are threadsafe and do not need to be run in the + # event loop, however they all make copies so they significantly + # inefficent def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no @@ -77,7 +114,17 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]: return None def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSRecord]: - """Gets the first matching entry by details. Returns None if no entries match.""" + """Gets the first matching entry by details. Returns None if no entries match. + + Calling this function is not recommended as it will only + return one record even if there are multiple entries. + + For example if there are multiple A or AAAA addresses this + function will return the last one that was added to the cache + which may not be the one you expect. + + Use get_all_by_details instead. + """ return self.get(DNSEntry(name, type_, class_)) def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: @@ -87,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" - return self.service_cache.get(server, [])[:] + return list(self.service_cache.get(server, {})) def entries_with_name(self, name: str) -> List[DNSRecord]: """Returns a list of entries whose key matches the name.""" - return self.cache.get(name.lower(), [])[:] + return list(self.cache.get(name.lower(), {})) def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() @@ -107,11 +154,3 @@ def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[D def names(self) -> List[str]: """Return a copy of the list of current cache names.""" return list(self.cache) - - def expire(self, now: float) -> Iterable[DNSRecord]: - """Purge expired entries from the cache.""" - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self.remove(record) - yield record From 3503e7614fc31bbfe2c919f13689468cc73179fd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:23:26 -1000 Subject: [PATCH 376/608] Prefix cache functions that are non threadsafe with async_ (#724) --- tests/__init__.py | 4 +--- tests/test_cache.py | 28 ++++++++++++++-------------- tests/test_core.py | 22 ++++++++++++---------- tests/test_handlers.py | 25 +++++++++++++++---------- tests/test_services.py | 3 +-- zeroconf/_cache.py | 18 +++++++++--------- zeroconf/_core.py | 2 +- zeroconf/_handlers.py | 4 ++-- 8 files changed, 55 insertions(+), 51 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index 3439a044..86d7e199 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -64,6 +64,4 @@ def has_working_ipv6(): def _clear_cache(zc): - for name in zc.cache.names(): - for record in zc.cache.entries_with_name(name): - zc.cache.remove(record) + zc.cache.cache.clear() diff --git a/tests/test_cache.py b/tests/test_cache.py index 19033b5c..98da9dbb 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -31,7 +31,7 @@ def test_order(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN) cached_record = cache.get(entry) assert cached_record == record2 @@ -45,7 +45,7 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) entry = r.DNSEntry(record2) cached_record = cache.get(entry) assert cached_record == record2 @@ -61,7 +61,7 @@ def test_adding_same_record_to_cache_different_ttls(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN) assert cached_records == [record2] @@ -69,18 +69,18 @@ def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert 'a' in cache.cache - cache.remove_records([record1, record2]) + cache.async_remove_records([record1, record2]) assert 'a' not in cache.cache def test_cache_empty_multiple_calls(self): record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert 'a' in cache.cache - cache.remove_records([record1, record2]) + cache.async_remove_records([record1, record2]) assert 'a' not in cache.cache @@ -91,7 +91,7 @@ def test_get(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert cache.get(record1) == record1 assert cache.get(record2) == record2 @@ -99,14 +99,14 @@ def test_get_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2 def test_get_all_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) def test_entries_with_server(self): @@ -117,7 +117,7 @@ def test_entries_with_server(self): 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' ) cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert set(cache.entries_with_server('ab')) == set([record1, record2]) def test_entries_with_name(self): @@ -128,7 +128,7 @@ def test_entries_with_name(self): 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' ) cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert set(cache.entries_with_name('irrelevant')) == set([record1, record2]) def test_current_entry_with_name_and_alias(self): @@ -139,7 +139,7 @@ def test_current_entry_with_name_and_alias(self): 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant' ) cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1 def test_entries_with_name(self): @@ -150,5 +150,5 @@ def test_entries_with_name(self): 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' ) cache = r.DNSCache() - cache.add_records([record1, record2]) + cache.async_add_records([record1, record2]) assert cache.names() == ['irrelevant'] diff --git a/tests/test_core.py b/tests/test_core.py index f8577a6b..819bbe68 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,6 +18,7 @@ import zeroconf as r from zeroconf import _core, const, ServiceBrowser, Zeroconf, current_time_millis +from zeroconf.aio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response @@ -36,22 +37,23 @@ def teardown_module(): log.setLevel(original_logging_level) -class TestReaper(unittest.TestCase): - @unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10) - def test_reaper(self): - zeroconf = _core.Zeroconf(interfaces=['127.0.0.1']) +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_reaper(): + with unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): + assert _core._CACHE_CLEANUP_INTERVAL == 10 + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf = aiozc.zeroconf cache = zeroconf.cache original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') - zeroconf.cache.add(record_with_10s_ttl) - zeroconf.cache.add(record_with_1s_ttl) + zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - time.sleep(1) - zeroconf.notify_all() - time.sleep(0.1) + await asyncio.sleep(1.2) entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) - zeroconf.close() + await aiozc.async_close() assert entries != original_entries assert entries_with_cache != original_entries assert record_with_10s_ttl in entries diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 0645a24b..f9e7639e 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,6 +14,7 @@ import zeroconf as r from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const +from zeroconf.aio import AsyncZeroconf from . import _clear_cache, _inject_response @@ -703,10 +704,14 @@ def test_known_answer_supression_service_type_enumeration_query(): zc.close() -def test_qu_response_only_sends_additionals_if_sends_answer(): +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_qu_response_only_sends_additionals_if_sends_answer(): """Test that a QU response does not send additionals unless it sends the answer as well.""" # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf type_ = "_addtest1._tcp.local." name = "knownname" @@ -731,13 +736,13 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): ptr_record = info.dns_pointer() # Add the PTR record to the cache - zc.cache.add(ptr_record) + zc.cache.async_add_records([ptr_record]) # Add the A record to the cache with 50% ttl remaining a_record = info.dns_addresses()[0] a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl) assert not a_record.is_recent(current_time_millis()) - zc.cache.add(a_record) + zc.cache.async_add_records([a_record]) # With QU should respond to only unicast when the answer has been recently multicast # even if the additional has not been recently multicast @@ -755,10 +760,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): assert unicast_out.answers[0][0] == ptr_record # Remove the 50% A record and add a 100% A record - zc.cache.remove(a_record) + zc.cache.async_remove_records([a_record]) a_record = info.dns_addresses()[0] assert a_record.is_recent(current_time_millis()) - zc.cache.add(a_record) + zc.cache.async_add_records([a_record]) # With QU should respond to only unicast when the answer has been recently multicast # even if the additional has not been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -775,10 +780,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): assert unicast_out.answers[0][0] == ptr_record # Remove the 100% PTR record and add a 50% PTR record - zc.cache.remove(ptr_record) + zc.cache.async_remove_records([ptr_record]) ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl) assert not ptr_record.is_recent(current_time_millis()) - zc.cache.add(ptr_record) + zc.cache.async_add_records([ptr_record]) # With QU should respond to only multicast since the has less # than 75% of its ttl remaining query = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -811,7 +816,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) - zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache + zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache unicast_out, multicast_out = zc.query_handler.async_response( [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT @@ -828,4 +833,4 @@ def test_qu_response_only_sends_additionals_if_sends_answer(): # unregister zc.registry.remove(info) - zc.close() + await aiozc.async_close() diff --git a/tests/test_services.py b/tests/test_services.py index f49535e5..305c3d96 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -960,8 +960,7 @@ async def test_multiple_a_addresses(): host = "multahost.local." record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a') record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b') - cache.add(record1) - cache.add(record2) + cache.async_add_records([record1, record2]) # New kwarg way info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 2e07a7a4..e3dd4593 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -47,10 +47,10 @@ def __init__(self) -> None: self.cache: _DNSRecordCacheType = {} self.service_cache: _DNSRecordCacheType = {} - # Functions prefixed with are NOT threadsafe and must + # Functions prefixed with async_ are NOT threadsafe and must # be run in the event loop. - def add(self, entry: DNSRecord) -> None: + def _async_add(self, entry: DNSRecord) -> None: """Adds an entry. This function must be run in from event loop. @@ -65,15 +65,15 @@ def add(self, entry: DNSRecord) -> None: if isinstance(entry, DNSService): self.service_cache.setdefault(entry.server, {})[entry] = entry - def add_records(self, entries: Iterable[DNSRecord]) -> None: + def async_add_records(self, entries: Iterable[DNSRecord]) -> None: """Add multiple records. This function must be run in from event loop. """ for entry in entries: - self.add(entry) + self._async_add(entry) - def remove(self, entry: DNSRecord) -> None: + def _async_remove(self, entry: DNSRecord) -> None: """Removes an entry. This function must be run in from event loop. @@ -82,15 +82,15 @@ def remove(self, entry: DNSRecord) -> None: _remove_key(self.service_cache, entry.server, entry) _remove_key(self.cache, entry.key, entry) - def remove_records(self, entries: Iterable[DNSRecord]) -> None: + def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: """Remove multiple records. This function must be run in from event loop. """ for entry in entries: - self.remove(entry) + self._async_remove(entry) - def expire(self, now: float) -> Iterable[DNSRecord]: + def async_expire(self, now: float) -> Iterable[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. @@ -98,7 +98,7 @@ def expire(self, now: float) -> Iterable[DNSRecord]: for name in self.names(): for record in self.entries_with_name(name): if record.is_expired(now): - self.remove(record) + self._async_remove(record) yield record # The below functions are threadsafe and do not need to be run in the diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 240270a9..e5c92ce3 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() - self.zc.record_manager.async_updates(now, list(self.zc.cache.expire(now))) + self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now))) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index b5279654..dd1a9ca0 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -362,11 +362,11 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # zc.get_service_info will see the cached value # but ONLY after all the record updates have been # processsed. - self.cache.add_records(itertools.chain(address_adds, other_adds)) + self.cache.async_add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. - self.cache.remove_records(removes) + self.cache.async_remove_records(removes) self.async_updates_complete() def add_listener( From 733f79d28c7dd4500a1598b279ee638ead8bdd55 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:23:35 -1000 Subject: [PATCH 377/608] Update changelog (#730) --- README.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.rst b/README.rst index 5e69f696..3e3d0e64 100644 --- a/README.rst +++ b/README.rst @@ -143,6 +143,10 @@ Changelog 0.32.0 (Unreleased) =================== +Documentation for breaking changes era on the side of the caution and likely +overstates the risk on many of these. If you are not accessing zeroconf internals, +you can likely not be concerned with the breaking changes below: + * BREAKING CHANGE: zeroconf.asyncio has been renamed zeroconf.aio (#503) @bdraco The asyncio name could shadow system asyncio in some cases. If @@ -201,6 +205,19 @@ Changelog These functions are not intended to be used by external callers and the API is not likely to be stable in the future +* BREAKING CHANGE: Prefix cache functions that are non threadsafe with async_ (#724) @bdraco + + Adding (`zc.cache.add` -> `zc.cache.async_add_records`), removing (`zc.cache.remove` -> + `zc.cache.async_remove_records`), and expiring the cache (`zc.cache.expire` -> + `zc.cache.async_expire`) the cache is not threadsafe and must be called from the + event loop (previously the Engine select loop before 0.32) + + These functions should only be run from the event loop as they are NOT thread safe. + + We never expect these functions will be called externally, however it was possible so this + is documented as a breaking change. It is highly recommended that external callers do not + modify the cache directly. + * TRAFFIC REDUCTION: Add support for handling QU questions (#621) @bdraco Implements RFC 6762 sec 5.4: @@ -237,6 +254,10 @@ Changelog * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Fix cache handling of records with different TTLs (#729) @bdraco + +* Rename handlers and internals to make it clear what is threadsafe (#726) @bdraco + * Fix ServiceInfo with multiple A records (#725) @bdraco * Synchronize time for fate sharing (#718) @bdraco From 3ee9b650bedbe61d59838897f653ad43a6d51910 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 12:45:56 -1000 Subject: [PATCH 378/608] Fix server cache to be case-insensitive (#731) --- tests/test_cache.py | 4 +++- tests/test_services.py | 2 ++ zeroconf/_cache.py | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 98da9dbb..7c75866b 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -119,6 +119,7 @@ def test_entries_with_server(self): cache = r.DNSCache() cache.async_add_records([record1, record2]) assert set(cache.entries_with_server('ab')) == set([record1, record2]) + assert set(cache.entries_with_server('AB')) == set([record1, record2]) def test_entries_with_name(self): record1 = r.DNSService( @@ -130,6 +131,7 @@ def test_entries_with_name(self): cache = r.DNSCache() cache.async_add_records([record1, record2]) assert set(cache.entries_with_name('irrelevant')) == set([record1, record2]) + assert set(cache.entries_with_name('Irrelevant')) == set([record1, record2]) def test_current_entry_with_name_and_alias(self): record1 = r.DNSPointer( @@ -142,7 +144,7 @@ def test_current_entry_with_name_and_alias(self): cache.async_add_records([record1, record2]) assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1 - def test_entries_with_name(self): + def test_name(self): record1 = r.DNSService( 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' ) diff --git a/tests/test_services.py b/tests/test_services.py index 305c3d96..5d72a1aa 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -857,6 +857,8 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi # service A updated service_updated_event.clear() service_address = '10.0.1.3' + # Verify we match on uppercase + service_server = service_server.upper() _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) service_updated_event.wait(wait_time) assert service_added_count == 1 diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index e3dd4593..12e4aa64 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -134,11 +134,11 @@ def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSReco def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" - return list(self.service_cache.get(server, {})) + return list(self.service_cache.get(server.lower(), [])) def entries_with_name(self, name: str) -> List[DNSRecord]: """Returns a list of entries whose key matches the name.""" - return list(self.cache.get(name.lower(), {})) + return list(self.cache.get(name.lower(), [])) def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]: now = current_time_millis() From 50af94493ff6bf5d21445eaa80d3a96f348b0d11 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 17:09:05 -1000 Subject: [PATCH 379/608] Add test coverage to ensure the cache flush bit is properly handled (#734) --- tests/test_handlers.py | 81 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index f9e7639e..92d95fa2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._handlers """ +import asyncio import logging import pytest import socket @@ -834,3 +835,83 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): # unregister zc.registry.remove(info) await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_cache_flush_bit(): + """Test that the cache flush bit sets the TTL to one for matching records.""" + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + zc.cache.async_add_records([info.dns_pointer(), a_record, info.dns_text(), info.dns_service()]) + + info.addresses = [socket.inet_aton("10.0.1.5"), socket.inet_aton("10.0.1.6")] + new_records = info.dns_addresses() + for new_record in new_records: + assert new_record.unique is True + + original_a_record = zc.cache.get(a_record) + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert zc.cache.get(a_record) is original_a_record + assert original_a_record.ttl != 1 + for record in new_records: + assert zc.cache.get(record) is not None + + original_a_record.created = current_time_millis() - 1001 + + # Do the run within 1s to verify the original record is not going to be expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in new_records: + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + assert original_a_record.ttl == 1 + for record in new_records: + assert zc.cache.get(record) is not None + + cached_records = [zc.cache.get(record) for record in new_records] + for record in cached_records: + record.created = current_time_millis() - 1001 + + fresh_address = socket.inet_aton("4.4.4.4") + info.addresses = [fresh_address] + # Do the run within 1s to verify the two new records get marked as expired + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) + for answer in info.dns_addresses(): + out.add_answer_at_time(answer, 0) + for packet in out.packets(): + zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) + for record in cached_records: + assert record.ttl == 1 + + for entry in zc.cache.get_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + if entry.address == fresh_address: + assert entry.ttl > 1 + else: + assert entry.ttl == 1 + + # Wait for the ttl 1 records to expire + await asyncio.sleep(1.01) + + loaded_info = r.ServiceInfo(type_, registration_name) + loaded_info.load_from_cache(zc) + assert loaded_info.addresses == info.addresses + + await aiozc.async_close() From c035925f47732a889c76a2ff0989b92c6687c950 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 17:26:37 -1000 Subject: [PATCH 380/608] Switch to using DNSRRSet in RecordManager (#735) --- zeroconf/_dns.py | 7 ++++++ zeroconf/_handlers.py | 56 ++++++++++++++++++++++++------------------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 5b7fe70f..66892d52 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -423,3 +423,10 @@ def suppresses(self, record: DNSRecord) -> bool: self._lookup = {record: record for record in self._records} other = self._lookup.get(record) return bool(other and other.ttl > (record.ttl / 2)) + + def __contains__(self, record: DNSRecord) -> bool: + """Returns true if the rrset contains the record.""" + if self._lookup is None: + # Build the hash table so we can lookup the record independent of the ttl + self._lookup = {record: record for record in self._records} + return record in self._lookup diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index dd1a9ca0..03495b4e 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -311,25 +311,14 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: other_adds: List[DNSRecord] = [] removes: List[DNSRecord] = [] now = msg.now - for record in msg.answers: - - updated = True + unique_types: Set[Tuple[str, int, int]] = set() + for record in msg.answers: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 - # rfc6762#section-10.2 para 2 - # Since unique is set, all old records with that name, rrtype, - # and rrclass that were received more than one second ago are declared - # invalid, and marked to expire from the cache in one second. - for entry in self.cache.get_all_by_details(record.name, record.type, record.class_): - if entry == record: - updated = False - if record.created - entry.created > 1000 and entry not in msg.answers: - # Expire in 1s - entry.set_created_ttl(now, 1) - - expired = record.is_expired(now) + unique_types.add((record.name, record.type, record.class_)) + maybe_entry = self.cache.get(record) - if not expired: + if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) else: @@ -337,16 +326,18 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: address_adds.append(record) else: other_adds.append(record) - if updated: - updates.append(record) + updates.append(record) + # This is likely a goodbye since the record is + # expired and exists in the cache elif maybe_entry is not None: updates.append(record) removes.append(record) - if not updates and not address_adds and not other_adds and not removes: - return + if unique_types: + self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) - self.async_updates(now, updates) + if updates: + self.async_updates(now, updates) # The cache adds must be processed AFTER we trigger # the updates since we compare existing data # with the new data and updating the cache @@ -362,12 +353,29 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # zc.get_service_info will see the cached value # but ONLY after all the record updates have been # processsed. - self.cache.async_add_records(itertools.chain(address_adds, other_adds)) + if other_adds or address_adds: + self.cache.async_add_records(itertools.chain(address_adds, other_adds)) # Removes are processed last since # ServiceInfo could generate an un-needed query # because the data was not yet populated. - self.cache.async_remove_records(removes) - self.async_updates_complete() + if removes: + self.cache.async_remove_records(removes) + if updates: + self.async_updates_complete() + + def _async_mark_unique_cached_records_older_than_1s_to_expire( + self, unique_types: Set[Tuple[str, int, int]], answers: List[DNSRecord], now: float + ) -> None: + # rfc6762#section-10.2 para 2 + # Since unique is set, all old records with that name, rrtype, + # and rrclass that were received more than one second ago are declared + # invalid, and marked to expire from the cache in one second. + answers_rrset = DNSRRSet(answers) + for name, type_, class_ in unique_types: + for entry in self.cache.get_all_by_details(name, type_, class_): + if (now - entry.created > 1000) and entry not in answers_rrset: + # Expire in 1s + entry.set_created_ttl(now, 1) def add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] From 9d31245f9ed4f6b1f7d9d7c51daf0ca394fd208f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 17:44:05 -1000 Subject: [PATCH 381/608] Add fast cache lookup functions (#732) --- tests/test_cache.py | 47 +++++++++++++++++++- tests/test_handlers.py | 12 ++--- zeroconf/_cache.py | 81 ++++++++++++++++++++++++++++------ zeroconf/_core.py | 2 +- zeroconf/_dns.py | 11 +++-- zeroconf/_handlers.py | 12 ++--- zeroconf/_services/__init__.py | 9 ++-- 7 files changed, 136 insertions(+), 38 deletions(-) diff --git a/tests/test_cache.py b/tests/test_cache.py index 7c75866b..4b3a8a18 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -84,16 +84,61 @@ def test_cache_empty_multiple_calls(self): assert 'a' not in cache.cache +class TestDNSAsyncCacheAPI(unittest.TestCase): + def test_async_get_unique(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert cache.async_get_unique(record1) == record1 + assert cache.async_get_unique(record2) == record2 + + def test_async_all_by_details(self): + record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') + record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) + + def test_async_entries_with_server(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_server('ab')) == set([record1, record2]) + assert set(cache.async_entries_with_server('AB')) == set([record1, record2]) + + def test_async_entries_with_name(self): + record1 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 85, 'ab' + ) + record2 = r.DNSService( + 'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab' + ) + cache = r.DNSCache() + cache.async_add_records([record1, record2]) + assert set(cache.async_entries_with_name('irrelevant')) == set([record1, record2]) + assert set(cache.async_entries_with_name('Irrelevant')) == set([record1, record2]) + + # These functions have been seen in other projects so # we try to maintain a stable API for all the threadsafe getters class TestDNSCacheAPI(unittest.TestCase): def test_get(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') + record3 = r.DNSAddress('a', const._TYPE_AAAA, const._CLASS_IN, 1, b'ipv6') cache = r.DNSCache() - cache.async_add_records([record1, record2]) + cache.async_add_records([record1, record2, record3]) assert cache.get(record1) == record1 assert cache.get(record2) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_A, const._CLASS_IN)) == record2 + assert cache.get(r.DNSEntry('a', const._TYPE_AAAA, const._CLASS_IN)) == record3 + assert cache.get(r.DNSEntry('notthere', const._TYPE_A, const._CLASS_IN)) is None def test_get_by_details(self): record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a') diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 92d95fa2..ddd8ffa4 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -862,17 +862,17 @@ async def test_cache_flush_bit(): for new_record in new_records: assert new_record.unique is True - original_a_record = zc.cache.get(a_record) + original_a_record = zc.cache.async_get_unique(a_record) # Do the run within 1s to verify the original record is not going to be expired out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA, multicast=True) for answer in new_records: out.add_answer_at_time(answer, 0) for packet in out.packets(): zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) - assert zc.cache.get(a_record) is original_a_record + assert zc.cache.async_get_unique(a_record) is original_a_record assert original_a_record.ttl != 1 for record in new_records: - assert zc.cache.get(record) is not None + assert zc.cache.async_get_unique(record) is not None original_a_record.created = current_time_millis() - 1001 @@ -884,9 +884,9 @@ async def test_cache_flush_bit(): zc.record_manager.async_updates_from_response(r.DNSIncoming(packet)) assert original_a_record.ttl == 1 for record in new_records: - assert zc.cache.get(record) is not None + assert zc.cache.async_get_unique(record) is not None - cached_records = [zc.cache.get(record) for record in new_records] + cached_records = [zc.cache.async_get_unique(record) for record in new_records] for record in cached_records: record.created = current_time_millis() - 1001 @@ -901,7 +901,7 @@ async def test_cache_flush_bit(): for record in cached_records: assert record.ttl == 1 - for entry in zc.cache.get_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): + for entry in zc.cache.async_all_by_details(server_name, const._TYPE_A, const._CLASS_IN): if entry.address == fresh_address: assert entry.ttl > 1 else: diff --git a/zeroconf/_cache.py b/zeroconf/_cache.py index 12e4aa64..24b6a233 100644 --- a/zeroconf/_cache.py +++ b/zeroconf/_cache.py @@ -20,13 +20,24 @@ USA """ -from typing import Dict, Iterable, List, Optional, cast - -from ._dns import DNSEntry, DNSPointer, DNSRecord, DNSService +import itertools +from typing import Dict, Iterable, Iterator, List, Optional, Union, cast + +from ._dns import ( + DNSAddress, + DNSEntry, + DNSHinfo, + DNSPointer, + DNSRecord, + DNSService, + DNSText, + dns_entry_matches, +) from ._utils.time import current_time_millis from .const import _TYPE_PTR - +_UNIQUE_RECORD_TYPES = (DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService) +_UniqueRecordsType = Union[DNSAddress, DNSHinfo, DNSPointer, DNSText, DNSService] _DNSRecordCacheType = Dict[str, Dict[DNSRecord, DNSRecord]] @@ -90,16 +101,50 @@ def async_remove_records(self, entries: Iterable[DNSRecord]) -> None: for entry in entries: self._async_remove(entry) - def async_expire(self, now: float) -> Iterable[DNSRecord]: + def async_expire(self, now: float) -> List[DNSRecord]: """Purge expired entries from the cache. This function must be run in from event loop. """ - for name in self.names(): - for record in self.entries_with_name(name): - if record.is_expired(now): - self._async_remove(record) - yield record + expired = [record for record in itertools.chain(*self.cache.values()) if record.is_expired(now)] + self.async_remove_records(expired) + return expired + + def async_get_unique(self, entry: _UniqueRecordsType) -> Optional[DNSRecord]: + """Gets a unique entry by key. Will return None if there is no + matching entry. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(entry.key, {}).get(entry) + + def async_all_by_details(self, name: str, type_: int, class_: int) -> Iterator[DNSRecord]: + """Gets all matching entries by details. + + This function is not threadsafe and must be called from + the event loop. + """ + key = name.lower() + for entry in self.cache.get(key, []): + if dns_entry_matches(entry, key, type_, class_): + yield entry + + def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the name. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.cache.get(name.lower(), {}) + + def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]: + """Returns a dict of entries whose key matches the server. + + This function is not threadsafe and must be called from + the event loop. + """ + return self.service_cache.get(name.lower(), {}) # The below functions are threadsafe and do not need to be run in the # event loop, however they all make copies so they significantly @@ -108,7 +153,9 @@ def async_expire(self, now: float) -> Iterable[DNSRecord]: def get(self, entry: DNSEntry) -> Optional[DNSRecord]: """Gets an entry by key. Will return None if there is no matching entry.""" - for cached_entry in reversed(self.entries_with_name(entry.key)): + if isinstance(entry, _UNIQUE_RECORD_TYPES): + return self.cache.get(entry.key, {}).get(entry) + for cached_entry in reversed(list(self.cache.get(entry.key, []))): if entry.__eq__(cached_entry): return cached_entry return None @@ -125,12 +172,18 @@ def get_by_details(self, name: str, type_: int, class_: int) -> Optional[DNSReco Use get_all_by_details instead. """ - return self.get(DNSEntry(name, type_, class_)) + key = name.lower() + for cached_entry in reversed(list(self.cache.get(key, []))): + if dns_entry_matches(cached_entry, key, type_, class_): + return cached_entry + return None def get_all_by_details(self, name: str, type_: int, class_: int) -> List[DNSRecord]: """Gets all matching entries by details.""" - match_entry = DNSEntry(name, type_, class_) - return [entry for entry in self.entries_with_name(name) if match_entry.__eq__(entry)] + key = name.lower() + return [ + entry for entry in list(self.cache.get(key, [])) if dns_entry_matches(entry, key, type_, class_) + ] def entries_with_server(self, server: str) -> List[DNSRecord]: """Returns a list of entries whose server matches the name.""" diff --git a/zeroconf/_core.py b/zeroconf/_core.py index e5c92ce3..a7910591 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() - self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now))) + self.zc.record_manager.async_updates(now, self.zc.cache.async_expire(now)) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 66892d52..e656bc51 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -49,6 +49,10 @@ from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import +def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: + return key == record.key and type_ == record.type and class_ == record.class_ + + class DNSEntry: """A DNS entry""" @@ -66,12 +70,7 @@ def _entry_tuple(self) -> Tuple[str, int, int]: def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" - return ( - self.key == other.key - and self.type == other.type - and self.class_ == other.class_ - and isinstance(other, DNSEntry) - ) + return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) @staticmethod def get_class_(class_: int) -> str: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 03495b4e..66b8862f 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,9 +21,9 @@ """ import itertools -from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union +from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from ._cache import DNSCache +from ._cache import DNSCache, _UniqueRecordsType from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing @@ -141,7 +141,7 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: SHOULD instead multicast the response so as to keep all the peer caches up to date """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: @@ -149,7 +149,7 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ - maybe_entry = self._cache.get(record) + maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) return bool(maybe_entry and self._now - maybe_entry.created < 1000) @@ -317,7 +317,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 unique_types.add((record.name, record.type, record.class_)) - maybe_entry = self.cache.get(record) + maybe_entry = self.cache.async_get_unique(cast(_UniqueRecordsType, record)) if not record.is_expired(now): if maybe_entry is not None: maybe_entry.reset_ttl(record) @@ -372,7 +372,7 @@ def _async_mark_unique_cached_records_older_than_1s_to_expire( # invalid, and marked to expire from the cache in one second. answers_rrset = DNSRRSet(answers) for name, type_, class_ in unique_types: - for entry in self.cache.get_all_by_details(name, type_, class_): + for entry in self.cache.async_all_by_details(name, type_, class_): if (now - entry.created > 1000) and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 80fdd5f7..306b6999 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -27,6 +27,7 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from .._cache import _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing @@ -316,7 +317,7 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def _process_record_update(self, now: float, record: DNSRecord) -> None: + def _async_process_record_update(self, now: float, record: DNSRecord) -> None: """Process a single record update from a batch of updates.""" expired = record.is_expired(now) @@ -340,12 +341,12 @@ def _process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.get(record): + if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): return if isinstance(record, DNSAddress): # Iterate through the DNSCache and callback any services that use this address - for service in self.zc.cache.entries_with_server(record.name): + for service in self.zc.cache.async_entries_with_server(record.name): type_ = self._record_matching_type(service) if type_: self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) @@ -367,7 +368,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSReco This method will be run in the event loop. """ for record in records: - self._process_record_update(now, record) + self._async_process_record_update(now, record) def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers. From 35ac7a39d1fab00898ed6075e7e930424716b627 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 18:12:00 -1000 Subject: [PATCH 382/608] Breakout ServiceBrowser handler from listener creation (#736) --- tests/test_services.py | 123 +++++++++++++++++++++++++++++++++ zeroconf/_services/__init__.py | 49 +++++++------ 2 files changed, 149 insertions(+), 23 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 5d72a1aa..f972f9d2 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -515,6 +515,11 @@ def _mock_get_expiration_time(self, percent): assert service_added_count == 3 assert service_removed_count == 0 + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0), + ) + # all three services removed _inject_response( zeroconf, @@ -1265,6 +1270,124 @@ def mock_incoming_msg(records) -> r.DNSIncoming: zc.close() +def test_service_browser_listeners_update_service(): + """Test that the ServiceBrowser ServiceListener that implements update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ] + browser.cancel() + + zc.close() + + +def test_service_browser_listeners_no_update_service(): + """Test that the ServiceBrowser ServiceListener that does not implement update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener: + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ] + browser.cancel() + + zc.close() + + def test_changing_name_updates_serviceinfo_key(): """Verify a name change will adjust the underlying key value.""" type_ = "_homeassistant._tcp.local." diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 306b6999..20ed6651 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -223,6 +223,31 @@ def _group_ptr_queries_with_known_answers( return [query_bucket.out for query_bucket in query_buckets] +def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: + """Generate a service_state_changed handlers from a listener.""" + + def on_change( + zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange + ) -> None: + assert listener is not None + args = (zeroconf, service_type, name) + if state_change is ServiceStateChange.Added: + listener.add_service(*args) + elif state_change is ServiceStateChange.Removed: + listener.remove_service(*args) + elif state_change is ServiceStateChange.Updated: + if hasattr(listener, 'update_service'): + listener.update_service(*args) + else: + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) + + return on_change + + class _ServiceBrowserBase(RecordUpdateListener): """Base class for ServiceBrowser.""" @@ -263,29 +288,7 @@ def __init__( handlers = cast(List[Callable[..., None]], handlers or []) if listener: - - def on_change( - zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange - ) -> None: - assert listener is not None - args = (zeroconf, service_type, name) - if state_change is ServiceStateChange.Added: - listener.add_service(*args) - elif state_change is ServiceStateChange.Removed: - listener.remove_service(*args) - elif state_change is ServiceStateChange.Updated: - if hasattr(listener, 'update_service'): - listener.update_service(*args) - else: - warnings.warn( - "%r has no update_service method. Provide one (it can be empty if you " - "don't care about the updates), it'll become mandatory." % (listener,), - FutureWarning, - ) - else: - raise NotImplementedError(state_change) - - handlers.append(on_change) + handlers.append(_service_state_changed_from_listener(listener)) for h in handlers: self.service_state_changed.register_handler(h) From 5feda7e318f7d164d2b04b2d243a804372517da6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 18:48:26 -1000 Subject: [PATCH 383/608] Remove second level caching from ServiceBrowsers (#737) --- zeroconf/_services/__init__.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 20ed6651..9334bafc 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -223,6 +223,21 @@ def _group_ptr_queries_with_known_answers( return [query_bucket.out for query_bucket in query_buckets] +def generate_service_query( + zc: 'Zeroconf', now: float, types_: List[str], multicast: bool = True +) -> List[DNSOutgoing]: + """Generate a service query for sending with zeroconf.send.""" + questions_with_known_answers: _QuestionWithKnownAnswers = {} + for type_ in types_: + question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) + questions_with_known_answers[question] = set( + cast(DNSPointer, record) + for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) + if not record.is_stale(now) + ) + return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) + + def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: """Generate a service_state_changed handlers from a listener.""" @@ -271,7 +286,6 @@ def __init__( self.addr = addr self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - self._services: Dict[str, Dict[str, DNSPointer]] = {check_type_: {} for check_type_ in self.types} current_time = current_time_millis() self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} @@ -327,17 +341,14 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None: if isinstance(record, DNSPointer): if record.name not in self.types: return - service_key = record.alias.lower() - services_by_type = self._services[record.name] - old_record = services_by_type.get(service_key) + old_record = self.zc.cache.async_get_unique( + DNSPointer(record.name, _TYPE_PTR, _CLASS_IN, 0, record.alias) + ) if old_record is None: - services_by_type[service_key] = record self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) elif expired: - del services_by_type[service_key] self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: - old_record.reset_ttl(record) expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) if expires < self._next_time[record.name]: self._next_time[record.name] = expires @@ -407,18 +418,17 @@ def generate_ready_queries(self) -> List[DNSOutgoing]: if min(self._next_time.values()) > now: return [] - questions_with_known_answers: _QuestionWithKnownAnswers = {} + ready_types = [] for type_, due in self._next_time.items(): if due > now: continue - questions_with_known_answers[DNSQuestion(type_, _TYPE_PTR, _CLASS_IN)] = set( - record for record in self._services[type_].values() if not record.is_stale(now) - ) + + ready_types.append(type_) self._next_time[type_] = now + self._delay[type_] self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - return _group_ptr_queries_with_known_answers(now, self.multicast, questions_with_known_answers) + return generate_service_query(self.zc, now, ready_types, self.multicast) def _seconds_to_wait(self) -> Optional[float]: """Returns the number of seconds to wait for the next event.""" From e227d6e4c337ef9d5aa626c41587a8046313e416 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 20:30:30 -1000 Subject: [PATCH 384/608] Fix flakey cache bit flush test (#739) --- tests/test_handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ddd8ffa4..b86d253d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -908,7 +908,7 @@ async def test_cache_flush_bit(): assert entry.ttl == 1 # Wait for the ttl 1 records to expire - await asyncio.sleep(1.01) + await asyncio.sleep(1.1) loaded_info = r.ServiceInfo(type_, registration_name) loaded_info.load_from_cache(zc) From c8e15dd2bb5f6d2eb3a8ef5f26ad044517b70c47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 21:24:23 -1000 Subject: [PATCH 385/608] Run question answer callbacks from add_listener in the event loop (#740) --- tests/test_core.py | 12 +++++++++--- zeroconf/_handlers.py | 35 +++++++++++++++++++++++----------- zeroconf/_services/__init__.py | 8 ++------ zeroconf/aio.py | 1 - 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 819bbe68..1f0884f0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -249,10 +249,14 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS zeroconf.close() -def test_notify_listeners(): +# This test uses asyncio because it needs to verify the listeners +# run in the event loop +@pytest.mark.asyncio +async def test_notify_listeners(): """Test adding and removing notify listeners.""" # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf notify_called = 0 class TestNotifyListener(r.NotifyListener): @@ -274,6 +278,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) browser.cancel() + await asyncio.sleep(0) # flush out any call_soon_threadsafe assert notify_called zc.remove_notify_listener(notify_listener) @@ -281,10 +286,11 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # start a browser browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) browser.cancel() + await asyncio.sleep(0) # flush out any call_soon_threadsafe assert not notify_called - zc.close() + await aiozc.async_close() def test_generate_service_query_set_qu_bit(): diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 66b8862f..476217f6 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -385,18 +385,31 @@ def add_listener( answer the question(s).""" self.listeners.append(listener) - if question is not None: - now = current_time_millis() - records = [] - questions = [question] if isinstance(question, DNSQuestion) else question - for single_question in questions: - for record in self.cache.entries_with_name(single_question.name): - if single_question.answered_by(record) and not record.is_expired(now): - records.append(record) - if records: - listener.async_update_records(self.zc, now, records) - listener.async_update_records_complete() + if question is None: + self.zc.notify_all() + return + questions = [question] if isinstance(question, DNSQuestion) else question + assert self.zc.loop is not None + self.zc.loop.call_soon_threadsafe(self._async_update_matching_records, listener, questions) + + def _async_update_matching_records( + self, listener: RecordUpdateListener, questions: List[DNSQuestion] + ) -> None: + """Calls back any existing entries in the cache that answer the question. + + This function must be run from the event loop. + """ + now = current_time_millis() + records = [] + for question in questions: + for record in self.cache.async_entries_with_name(question.name): + if not record.is_expired(now) and question.answered_by(record): + records.append(record) + if not records: + return + listener.async_update_records(self.zc, now, records) + listener.async_update_records_complete() self.zc.notify_all() def remove_listener(self, listener: RecordUpdateListener) -> None: diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 9334bafc..04288a69 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -307,6 +307,8 @@ def __init__( for h in handlers: self.service_state_changed.register_handler(h) + self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + @property def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface @@ -406,11 +408,6 @@ def cancel(self) -> None: self.done = True self.zc.remove_listener(self) - def run(self) -> None: - """Run the browser.""" - questions = [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types] - self.zc.add_listener(self, questions) - def generate_ready_queries(self) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" now = current_time_millis() @@ -480,7 +477,6 @@ def cancel(self) -> None: def run(self) -> None: """Run the browser thread.""" - super().run() while True: timeout = self._seconds_to_wait() if timeout: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index ae57d014..d5414d13 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -146,7 +146,6 @@ async def async_cancel(self) -> None: async def async_run(self) -> None: """Run the browser task.""" - self.run() await self.aiozc.zeroconf.async_wait_for_start() while True: timeout = self._seconds_to_wait() From f0d727bd9addd6dab373b75008f04a6f8547928b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 21:30:53 -1000 Subject: [PATCH 386/608] Relocate ServiceInfo to zeroconf._services.info (#741) --- tests/test_aio.py | 3 +- tests/test_services.py | 7 +- zeroconf/__init__.py | 6 +- zeroconf/_core.py | 9 +- zeroconf/_services/__init__.py | 421 +----------------------------- zeroconf/_services/info.py | 458 +++++++++++++++++++++++++++++++++ zeroconf/_services/registry.py | 2 +- zeroconf/aio.py | 3 +- 8 files changed, 472 insertions(+), 437 deletions(-) create mode 100644 zeroconf/_services/info.py diff --git a/tests/test_aio.py b/tests/test_aio.py index 47c1e2d9..e4144250 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -16,7 +16,8 @@ from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered -from zeroconf._services import ServiceInfo, ServiceListener +from zeroconf._services import ServiceListener +from zeroconf._services.info import ServiceInfo from zeroconf._utils.time import current_time_millis from . import _clear_cache diff --git a/tests/test_services.py b/tests/test_services.py index f972f9d2..2a077329 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -19,11 +19,8 @@ from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis import zeroconf._services as s from zeroconf import Zeroconf -from zeroconf._services import ( - ServiceBrowser, - ServiceInfo, - ServiceStateChange, -) +from zeroconf._services import ServiceBrowser, ServiceStateChange +from zeroconf._services.info import ServiceInfo from zeroconf.aio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index ab2b0993..e61a7119 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -46,15 +46,17 @@ ) from ._protocol import DNSIncoming, DNSOutgoing # noqa # import needed for backwards compat from ._services import ( # noqa # import needed for backwards compat - instance_name_from_service_info, Signal, SignalRegistrationInterface, RecordUpdateListener, ServiceBrowser, - ServiceInfo, ServiceListener, ServiceStateChange, ) +from ._services.info import ( # noqa # import needed for backwards compat + instance_name_from_service_info, + ServiceInfo, +) from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat from ._services.types import ZeroconfServiceTypes from ._utils.name import service_type_name # noqa # import needed for backwards compat diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a7910591..675d7169 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -37,13 +37,8 @@ from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing -from ._services import ( - RecordUpdateListener, - ServiceBrowser, - ServiceInfo, - ServiceListener, - instance_name_from_service_info, -) +from ._services import RecordUpdateListener, ServiceBrowser, ServiceListener +from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._utils.aio import get_running_loop from ._utils.name import service_type_name diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 04288a69..111ea448 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -21,44 +21,28 @@ """ import enum -import socket import threading import warnings from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from .._cache import _UniqueRecordsType -from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText -from .._exceptions import BadTypeInNameException +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord from .._protocol import DNSOutgoing from .._utils.name import service_type_name -from .._utils.net import ( - IPVersion, - _encode_address, - _is_v6_address, -) -from .._utils.struct import int2byte from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, _CLASS_IN, - _CLASS_UNIQUE, - _DNS_HOST_TTL, - _DNS_OTHER_TTL, _DNS_PACKET_HEADER_LEN, _EXPIRE_REFRESH_TIME_PERCENT, _FLAGS_QR_QUERY, - _LISTENER_TIME, _MAX_MSG_TYPICAL, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT, - _TYPE_A, - _TYPE_AAAA, _TYPE_PTR, - _TYPE_SRV, - _TYPE_TXT, ) @@ -77,16 +61,6 @@ class ServiceStateChange(enum.Enum): Updated = 3 -def instance_name_from_service_info(info: "ServiceInfo") -> str: - """Calculate the instance name from the ServiceInfo.""" - # This is kind of funky because of the subtype based tests - # need to make subtypes a first class citizen - service_name = service_type_name(info.name) - if not info.type.endswith(service_name): - raise BadTypeInNameException - return info.name[: -len(service_name) - 1] - - class ServiceListener: def add_service(self, zc: 'Zeroconf', type_: str, name: str) -> None: raise NotImplementedError() @@ -505,396 +479,3 @@ def run(self) -> None: name=name_type[0], state_change=state_change, ) - - -class ServiceInfo(RecordUpdateListener): - """Service information. - - Constructor parameters are as follows: - - * `type_`: fully qualified service type name - * `name`: fully qualified service name - * `port`: port that the service runs on - * `weight`: weight of the service - * `priority`: priority of the service - * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). - converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to - value-less attributes. - * `server`: fully qualified name for service host (defaults to name) - * `host_ttl`: ttl used for A/SRV records - * `other_ttl`: ttl used for PTR/TXT records - * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, - or in parsed form as text; at most one of those parameters can be provided) - - """ - - text = b'' - - def __init__( - self, - type_: str, - name: str, - port: Optional[int] = None, - weight: int = 0, - priority: int = 0, - properties: Union[bytes, Dict] = b'', - server: Optional[str] = None, - host_ttl: int = _DNS_HOST_TTL, - other_ttl: int = _DNS_OTHER_TTL, - *, - addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None - ) -> None: - # Accept both none, or one, but not both. - if addresses is not None and parsed_addresses is not None: - raise TypeError("addresses and parsed_addresses cannot be provided together") - if not type_.endswith(service_type_name(name, strict=False)): - raise BadTypeInNameException - self.type = type_ - self._name = name - self.key = name.lower() - if addresses is not None: - self._addresses = addresses - elif parsed_addresses is not None: - self._addresses = [_encode_address(a) for a in parsed_addresses] - else: - self._addresses = [] - # This results in an ugly error when registering, better check now - invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] - if invalid: - raise TypeError( - 'Addresses must be bytes, got %s. Hint: convert string addresses ' - 'with socket.inet_pton' % invalid - ) - self.port = port - self.weight = weight - self.priority = priority - self.server = server if server else name - self.server_key = self.server.lower() - self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} - if isinstance(properties, bytes): - self._set_text(properties) - else: - self._set_properties(properties) - self.host_ttl = host_ttl - self.other_ttl = other_ttl - - @property - def name(self) -> str: - """The name of the service.""" - return self._name - - @name.setter - def name(self, name: str) -> None: - """Replace the the name and reset the key.""" - self._name = name - self.key = name.lower() - - @property - def addresses(self) -> List[bytes]: - """IPv4 addresses of this service. - - Only IPv4 addresses are returned for backward compatibility. - Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to - include IPv6 addresses as well. - """ - return self.addresses_by_version(IPVersion.V4Only) - - @addresses.setter - def addresses(self, value: List[bytes]) -> None: - """Replace the addresses list. - - This replaces all currently stored addresses, both IPv4 and IPv6. - """ - self._addresses = value - - @property - def properties(self) -> Dict: - """If properties were set in the constructor this property returns the original dictionary - of type `Dict[Union[bytes, str], Any]`. - - If properties are coming from the network, after decoding a TXT record, the keys are always - bytes and the values are either bytes, if there was a value, even empty, or `None`, if there - was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. - """ - return self._properties - - def addresses_by_version(self, version: IPVersion) -> List[bytes]: - """List addresses matching IP version.""" - if version == IPVersion.V4Only: - return [addr for addr in self._addresses if not _is_v6_address(addr)] - if version == IPVersion.V6Only: - return list(filter(_is_v6_address, self._addresses)) - return self._addresses - - def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: - """List addresses in their parsed string form.""" - result = self.addresses_by_version(version) - return [ - socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) - for addr in result - ] - - def _set_properties(self, properties: Dict) -> None: - """Sets properties and text of this info from a dictionary""" - self._properties = properties - list_ = [] - result = b'' - for key, value in properties.items(): - if isinstance(key, str): - key = key.encode('utf-8') - - record = key - if value is not None: - if not isinstance(value, bytes): - value = str(value).encode('utf-8') - record += b'=' + value - list_.append(record) - for item in list_: - result = b''.join((result, int2byte(len(item)), item)) - self.text = result - - def _set_text(self, text: bytes) -> None: - """Sets properties and text given a text field""" - self.text = text - end = len(text) - if end == 0: - self._properties = {} - return - result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} - index = 0 - strs = [] - while index < end: - length = text[index] - index += 1 - strs.append(text[index : index + length]) - index += length - - key: bytes - value: Optional[bytes] - for s in strs: - try: - key, value = s.split(b'=', 1) - except ValueError: - # No equals sign at all - key = s - value = None - - # Only update non-existent properties - if key and result.get(key) is None: - result[key] = value - - self._properties = result - - def get_name(self) -> str: - """Name accessor""" - return self.name[: len(self.name) - len(self.type) - 1] - - def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: - """Updates service information from a DNS record. - - This method is deprecated and will be removed in a future version. - update_records should be implemented instead. - - This method will be run in the event loop. - """ - if record is not None: - self._process_records_threadsafe(zc, now, [record]) - - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Updates service information from a DNS record. - - This method will be run in the event loop. - """ - self._process_records_threadsafe(zc, now, records) - - def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Thread safe record updating.""" - update_addresses = False - for record in records: - if isinstance(record, DNSService): - update_addresses = True - self._process_record_threadsafe(record, now) - - # Only update addresses if the DNSService (.server) has changed - if not update_addresses: - return - - for record in self._get_address_records_from_cache(zc): - self._process_record_threadsafe(record, now) - - def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: - if record.is_expired(now): - return - - if isinstance(record, DNSAddress): - if record.key == self.server_key and record.address not in self._addresses: - self._addresses.append(record.address) - return - - if isinstance(record, DNSService): - if record.key != self.key: - return - self.name = record.name - self.server = record.server - self.server_key = record.server.lower() - self.port = record.port - self.weight = record.weight - self.priority = record.priority - return - - if isinstance(record, DNSText): - if record.key == self.key: - self._set_text(record.text) - - def dns_addresses( - self, - override_ttl: Optional[int] = None, - version: IPVersion = IPVersion.All, - created: Optional[float] = None, - ) -> List[DNSAddress]: - """Return matching DNSAddress from ServiceInfo.""" - return [ - DNSAddress( - self.server, - _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - address, - created, - ) - for address in self.addresses_by_version(version) - ] - - def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: - """Return DNSPointer from ServiceInfo.""" - return DNSPointer( - self.type, - _TYPE_PTR, - _CLASS_IN, - override_ttl if override_ttl is not None else self.other_ttl, - self.name, - created, - ) - - def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: - """Return DNSService from ServiceInfo.""" - return DNSService( - self.name, - _TYPE_SRV, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.host_ttl, - self.priority, - self.weight, - cast(int, self.port), - self.server, - created, - ) - - def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: - """Return DNSText from ServiceInfo.""" - return DNSText( - self.name, - _TYPE_TXT, - _CLASS_IN | _CLASS_UNIQUE, - override_ttl if override_ttl is not None else self.other_ttl, - self.text, - created, - ) - - def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: - """Get the address records from the cache.""" - return [ - *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), - *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), - ] - - def load_from_cache(self, zc: 'Zeroconf') -> bool: - """Populate the service info from the cache. - - This method is designed to be threadsafe. - """ - now = current_time_millis() - record_updates = [] - cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) - if cached_srv_record: - # If there is a srv record, A and AAAA will already - # be called and we do not want to do it twice - record_updates.append(cached_srv_record) - else: - record_updates.extend(self._get_address_records_from_cache(zc)) - cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) - if cached_txt_record: - record_updates.append(cached_txt_record) - self._process_records_threadsafe(zc, now, record_updates) - return self._is_complete - - @property - def _is_complete(self) -> bool: - """The ServiceInfo has all expected properties.""" - return not (self.text is None or not self._addresses) - - def request(self, zc: 'Zeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - if self.load_from_cache(zc): - return True - - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - try: - # Do not set a question on the listener to preload from cache - # since we just checked it above in load_from_cache - zc.add_listener(self, None) - while not self._is_complete: - if last <= now: - return False - if next_ <= now: - out = self.generate_request_query(zc, now) - if not out.questions: - return True - zc.send(out) - next_ = now + delay - delay *= 2 - - zc.wait(min(next_, last) - now) - now = current_time_millis() - finally: - zc.remove_listener(self) - - return True - - def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: - """Generate the request query.""" - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) - out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) - out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) - return out - - def __eq__(self, other: object) -> bool: - """Tests equality of service name""" - return isinstance(other, ServiceInfo) and other.name == self.name - - def __repr__(self) -> str: - """String representation""" - return '%s(%s)' % ( - type(self).__name__, - ', '.join( - '%s=%r' % (name, getattr(self, name)) - for name in ( - 'type', - 'name', - 'addresses', - 'port', - 'weight', - 'priority', - 'server', - 'properties', - ) - ), - ) diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py new file mode 100644 index 00000000..a3536ed1 --- /dev/null +++ b/zeroconf/_services/info.py @@ -0,0 +1,458 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import socket +from typing import Dict, List, Optional, TYPE_CHECKING, Union, cast + +from .._dns import DNSAddress, DNSPointer, DNSRecord, DNSService, DNSText +from .._exceptions import BadTypeInNameException +from .._protocol import DNSOutgoing +from .._services import RecordUpdateListener +from .._utils.name import service_type_name +from .._utils.net import ( + IPVersion, + _encode_address, + _is_v6_address, +) +from .._utils.struct import int2byte +from .._utils.time import current_time_millis +from ..const import ( + _CLASS_IN, + _CLASS_UNIQUE, + _DNS_HOST_TTL, + _DNS_OTHER_TTL, + _FLAGS_QR_QUERY, + _LISTENER_TIME, + _TYPE_A, + _TYPE_AAAA, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from .._core import Zeroconf # pylint: disable=cyclic-import + + +def instance_name_from_service_info(info: "ServiceInfo") -> str: + """Calculate the instance name from the ServiceInfo.""" + # This is kind of funky because of the subtype based tests + # need to make subtypes a first class citizen + service_name = service_type_name(info.name) + if not info.type.endswith(service_name): + raise BadTypeInNameException + return info.name[: -len(service_name) - 1] + + +class ServiceInfo(RecordUpdateListener): + """Service information. + + Constructor parameters are as follows: + + * `type_`: fully qualified service type name + * `name`: fully qualified service name + * `port`: port that the service runs on + * `weight`: weight of the service + * `priority`: priority of the service + * `properties`: dictionary of properties (or a bytes object holding the contents of the `text` field). + converted to str and then encoded to bytes using UTF-8. Keys with `None` values are converted to + value-less attributes. + * `server`: fully qualified name for service host (defaults to name) + * `host_ttl`: ttl used for A/SRV records + * `other_ttl`: ttl used for PTR/TXT records + * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, + or in parsed form as text; at most one of those parameters can be provided) + + """ + + text = b'' + + def __init__( + self, + type_: str, + name: str, + port: Optional[int] = None, + weight: int = 0, + priority: int = 0, + properties: Union[bytes, Dict] = b'', + server: Optional[str] = None, + host_ttl: int = _DNS_HOST_TTL, + other_ttl: int = _DNS_OTHER_TTL, + *, + addresses: Optional[List[bytes]] = None, + parsed_addresses: Optional[List[str]] = None + ) -> None: + # Accept both none, or one, but not both. + if addresses is not None and parsed_addresses is not None: + raise TypeError("addresses and parsed_addresses cannot be provided together") + if not type_.endswith(service_type_name(name, strict=False)): + raise BadTypeInNameException + self.type = type_ + self._name = name + self.key = name.lower() + if addresses is not None: + self._addresses = addresses + elif parsed_addresses is not None: + self._addresses = [_encode_address(a) for a in parsed_addresses] + else: + self._addresses = [] + # This results in an ugly error when registering, better check now + invalid = [a for a in self._addresses if not isinstance(a, bytes) or len(a) not in (4, 16)] + if invalid: + raise TypeError( + 'Addresses must be bytes, got %s. Hint: convert string addresses ' + 'with socket.inet_pton' % invalid + ) + self.port = port + self.weight = weight + self.priority = priority + self.server = server if server else name + self.server_key = self.server.lower() + self._properties: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + if isinstance(properties, bytes): + self._set_text(properties) + else: + self._set_properties(properties) + self.host_ttl = host_ttl + self.other_ttl = other_ttl + + @property + def name(self) -> str: + """The name of the service.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + """Replace the the name and reset the key.""" + self._name = name + self.key = name.lower() + + @property + def addresses(self) -> List[bytes]: + """IPv4 addresses of this service. + + Only IPv4 addresses are returned for backward compatibility. + Use :meth:`addresses_by_version` or :meth:`parsed_addresses` to + include IPv6 addresses as well. + """ + return self.addresses_by_version(IPVersion.V4Only) + + @addresses.setter + def addresses(self, value: List[bytes]) -> None: + """Replace the addresses list. + + This replaces all currently stored addresses, both IPv4 and IPv6. + """ + self._addresses = value + + @property + def properties(self) -> Dict: + """If properties were set in the constructor this property returns the original dictionary + of type `Dict[Union[bytes, str], Any]`. + + If properties are coming from the network, after decoding a TXT record, the keys are always + bytes and the values are either bytes, if there was a value, even empty, or `None`, if there + was none. No further decoding is attempted. The type returned is `Dict[bytes, Optional[bytes]]`. + """ + return self._properties + + def addresses_by_version(self, version: IPVersion) -> List[bytes]: + """List addresses matching IP version.""" + if version == IPVersion.V4Only: + return [addr for addr in self._addresses if not _is_v6_address(addr)] + if version == IPVersion.V6Only: + return list(filter(_is_v6_address, self._addresses)) + return self._addresses + + def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """List addresses in their parsed string form.""" + result = self.addresses_by_version(version) + return [ + socket.inet_ntop(socket.AF_INET6 if _is_v6_address(addr) else socket.AF_INET, addr) + for addr in result + ] + + def _set_properties(self, properties: Dict) -> None: + """Sets properties and text of this info from a dictionary""" + self._properties = properties + list_ = [] + result = b'' + for key, value in properties.items(): + if isinstance(key, str): + key = key.encode('utf-8') + + record = key + if value is not None: + if not isinstance(value, bytes): + value = str(value).encode('utf-8') + record += b'=' + value + list_.append(record) + for item in list_: + result = b''.join((result, int2byte(len(item)), item)) + self.text = result + + def _set_text(self, text: bytes) -> None: + """Sets properties and text given a text field""" + self.text = text + end = len(text) + if end == 0: + self._properties = {} + return + result: Dict[Union[str, bytes], Optional[Union[str, bytes]]] = {} + index = 0 + strs = [] + while index < end: + length = text[index] + index += 1 + strs.append(text[index : index + length]) + index += length + + key: bytes + value: Optional[bytes] + for s in strs: + try: + key, value = s.split(b'=', 1) + except ValueError: + # No equals sign at all + key = s + value = None + + # Only update non-existent properties + if key and result.get(key) is None: + result[key] = value + + self._properties = result + + def get_name(self) -> str: + """Name accessor""" + return self.name[: len(self.name) - len(self.type) - 1] + + def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + + This method will be run in the event loop. + """ + if record is not None: + self._process_records_threadsafe(zc, now, [record]) + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Updates service information from a DNS record. + + This method will be run in the event loop. + """ + self._process_records_threadsafe(zc, now, records) + + def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Thread safe record updating.""" + update_addresses = False + for record in records: + if isinstance(record, DNSService): + update_addresses = True + self._process_record_threadsafe(record, now) + + # Only update addresses if the DNSService (.server) has changed + if not update_addresses: + return + + for record in self._get_address_records_from_cache(zc): + self._process_record_threadsafe(record, now) + + def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: + if record.is_expired(now): + return + + if isinstance(record, DNSAddress): + if record.key == self.server_key and record.address not in self._addresses: + self._addresses.append(record.address) + return + + if isinstance(record, DNSService): + if record.key != self.key: + return + self.name = record.name + self.server = record.server + self.server_key = record.server.lower() + self.port = record.port + self.weight = record.weight + self.priority = record.priority + return + + if isinstance(record, DNSText): + if record.key == self.key: + self._set_text(record.text) + + def dns_addresses( + self, + override_ttl: Optional[int] = None, + version: IPVersion = IPVersion.All, + created: Optional[float] = None, + ) -> List[DNSAddress]: + """Return matching DNSAddress from ServiceInfo.""" + return [ + DNSAddress( + self.server, + _TYPE_AAAA if _is_v6_address(address) else _TYPE_A, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + address, + created, + ) + for address in self.addresses_by_version(version) + ] + + def dns_pointer(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSPointer: + """Return DNSPointer from ServiceInfo.""" + return DNSPointer( + self.type, + _TYPE_PTR, + _CLASS_IN, + override_ttl if override_ttl is not None else self.other_ttl, + self.name, + created, + ) + + def dns_service(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSService: + """Return DNSService from ServiceInfo.""" + return DNSService( + self.name, + _TYPE_SRV, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.host_ttl, + self.priority, + self.weight, + cast(int, self.port), + self.server, + created, + ) + + def dns_text(self, override_ttl: Optional[int] = None, created: Optional[float] = None) -> DNSText: + """Return DNSText from ServiceInfo.""" + return DNSText( + self.name, + _TYPE_TXT, + _CLASS_IN | _CLASS_UNIQUE, + override_ttl if override_ttl is not None else self.other_ttl, + self.text, + created, + ) + + def _get_address_records_from_cache(self, zc: 'Zeroconf') -> List[DNSRecord]: + """Get the address records from the cache.""" + return [ + *zc.cache.get_all_by_details(self.server, _TYPE_A, _CLASS_IN), + *zc.cache.get_all_by_details(self.server, _TYPE_AAAA, _CLASS_IN), + ] + + def load_from_cache(self, zc: 'Zeroconf') -> bool: + """Populate the service info from the cache. + + This method is designed to be threadsafe. + """ + now = current_time_millis() + record_updates = [] + cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) + if cached_srv_record: + # If there is a srv record, A and AAAA will already + # be called and we do not want to do it twice + record_updates.append(cached_srv_record) + else: + record_updates.extend(self._get_address_records_from_cache(zc)) + cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) + if cached_txt_record: + record_updates.append(cached_txt_record) + self._process_records_threadsafe(zc, now, record_updates) + return self._is_complete + + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return not (self.text is None or not self._addresses) + + def request(self, zc: 'Zeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + if self.load_from_cache(zc): + return True + + now = current_time_millis() + delay = _LISTENER_TIME + next_ = now + last = now + timeout + try: + # Do not set a question on the listener to preload from cache + # since we just checked it above in load_from_cache + zc.add_listener(self, None) + while not self._is_complete: + if last <= now: + return False + if next_ <= now: + out = self.generate_request_query(zc, now) + if not out.questions: + return True + zc.send(out) + next_ = now + delay + delay *= 2 + + zc.wait(min(next_, last) - now) + now = current_time_millis() + finally: + zc.remove_listener(self) + + return True + + def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: + """Generate the request query.""" + out = DNSOutgoing(_FLAGS_QR_QUERY) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) + out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) + out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) + return out + + def __eq__(self, other: object) -> bool: + """Tests equality of service name""" + return isinstance(other, ServiceInfo) and other.name == self.name + + def __repr__(self) -> str: + """String representation""" + return '%s(%s)' % ( + type(self).__name__, + ', '.join( + '%s=%r' % (name, getattr(self, name)) + for name in ( + 'type', + 'name', + 'addresses', + 'port', + 'weight', + 'priority', + 'server', + 'properties', + ) + ), + ) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 20584b3a..ebf5abbb 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -24,8 +24,8 @@ from typing import Dict, List, Optional, Union +from .info import ServiceInfo from .._exceptions import ServiceNameAlreadyRegistered -from .._services import ServiceInfo class ServiceRegistry: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index d5414d13..e64c87c3 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,7 +26,8 @@ from ._core import NotifyListener, Zeroconf from ._exceptions import NonUniqueNameException -from ._services import ServiceInfo, _ServiceBrowserBase, instance_name_from_service_info +from ._services import _ServiceBrowserBase +from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.types import ZeroconfServiceTypes from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType From 368163d3c30325d60021203430711e10fd6d97e9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 21:51:36 -1000 Subject: [PATCH 387/608] Relocate ServiceBrowser to zeroconf._services.browser (#744) --- tests/test_services.py | 25 +- zeroconf/__init__.py | 4 +- zeroconf/_core.py | 3 +- zeroconf/_services/__init__.py | 373 +----------------------------- zeroconf/_services/browser.py | 405 +++++++++++++++++++++++++++++++++ zeroconf/_services/types.py | 3 +- zeroconf/aio.py | 2 +- 7 files changed, 429 insertions(+), 386 deletions(-) create mode 100644 zeroconf/_services/browser.py diff --git a/tests/test_services.py b/tests/test_services.py index 2a077329..88b490fb 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -17,9 +17,10 @@ import zeroconf as r from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis -import zeroconf._services as s +import zeroconf._services.browser as _services_browser from zeroconf import Zeroconf -from zeroconf._services import ServiceBrowser, ServiceStateChange +from zeroconf._services import ServiceStateChange +from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo from zeroconf.aio import AsyncZeroconf @@ -984,7 +985,7 @@ def test_backoff(): time_offset = 0.0 start_time = time.time() * 1000 - initial_query_interval = s._BROWSER_TIME / 1000 + initial_query_interval = _services_browser._BROWSER_TIME / 1000 def current_time_millis(): """Current system time in milliseconds""" @@ -999,8 +1000,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the zeroconf current_time_millis # patch the backoff limit to prevent test running forever with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( - s, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(s, "_BROWSER_BACKOFF_LIMIT", 10): + _services_browser, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass @@ -1024,7 +1025,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): if time_offset == expected_query_time: assert got_query.is_set() got_query.clear() - if next_query_interval == s._BROWSER_BACKOFF_LIMIT: + if next_query_interval == _services_browser._BROWSER_BACKOFF_LIMIT: # Only need to test up to the point where we've seen a query # after the backoff limit has been hit break @@ -1032,7 +1033,9 @@ def on_service_state_change(zeroconf, service_type, state_change, name): next_query_interval = initial_query_interval expected_query_time = initial_query_interval else: - next_query_interval = min(2 * next_query_interval, s._BROWSER_BACKOFF_LIMIT) + next_query_interval = min( + 2 * next_query_interval, _services_browser._BROWSER_BACKOFF_LIMIT + ) expected_query_time += next_query_interval else: assert not got_query.is_set() @@ -1090,8 +1093,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the zeroconf current_time_millis # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( - s, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(s, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): + _services_browser, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): service_added = Event() service_removed = Event() @@ -1524,7 +1527,7 @@ def test_serviceinfo_accepts_bytes_or_string_dict(): def test_group_ptr_queries_with_known_answers(): - questions_with_known_answers: s._QuestionWithKnownAnswers = {} + questions_with_known_answers: _services_browser._QuestionWithKnownAnswers = {} now = current_time_millis() for i in range(120): name = f"_hap{i}._tcp._local." @@ -1538,7 +1541,7 @@ def test_group_ptr_queries_with_known_answers(): ) for counter in range(i) ) - outs = s._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) + outs = _services_browser._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) for out in outs: packets = out.packets() # If we generate multiple packets there must diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e61a7119..3715e174 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -49,10 +49,12 @@ Signal, SignalRegistrationInterface, RecordUpdateListener, - ServiceBrowser, ServiceListener, ServiceStateChange, ) +from ._services.browser import ( # noqa # import needed for backwards compat + ServiceBrowser, +) from ._services.info import ( # noqa # import needed for backwards compat instance_name_from_service_info, ServiceInfo, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 675d7169..ff54dc7e 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -37,7 +37,8 @@ from ._handlers import QueryHandler, RecordManager from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing -from ._services import RecordUpdateListener, ServiceBrowser, ServiceListener +from ._services import RecordUpdateListener, ServiceListener +from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._utils.aio import get_running_loop diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 111ea448..776d43a7 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -21,39 +21,15 @@ """ import enum -import threading -import warnings -from collections import OrderedDict -from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast - -from .._cache import _UniqueRecordsType -from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord -from .._protocol import DNSOutgoing -from .._utils.name import service_type_name -from .._utils.time import current_time_millis, millis_to_seconds -from ..const import ( - _BROWSER_BACKOFF_LIMIT, - _BROWSER_TIME, - _CLASS_IN, - _DNS_PACKET_HEADER_LEN, - _EXPIRE_REFRESH_TIME_PERCENT, - _FLAGS_QR_QUERY, - _MAX_MSG_TYPICAL, - _MDNS_ADDR, - _MDNS_ADDR6, - _MDNS_PORT, - _TYPE_PTR, -) +from typing import Any, Callable, List, TYPE_CHECKING +from .._dns import DNSRecord if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 from .._core import Zeroconf # pylint: disable=cyclic-import -_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] - - @enum.unique class ServiceStateChange(enum.Enum): Added = 1 @@ -134,348 +110,3 @@ def async_update_records_complete(self) -> None: This method will be run in the event loop. """ - - -class _DNSPointerOutgoingBucket: - """A DNSOutgoing bucket.""" - - def __init__(self, now: float, multicast: bool) -> None: - """Create a bucke to wrap a DNSOutgoing.""" - self.now = now - self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast) - self.bytes = 0 - - def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None: - """Add a new set of questions and known answers to the outgoing.""" - self.out.add_question(question) - for answer in answers: - self.out.add_answer_at_time(answer, self.now) - self.bytes += max_compressed_size - - -def _group_ptr_queries_with_known_answers( - now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers -) -> List[DNSOutgoing]: - """Aggregate queries so that as many known answers as possible fit in the same packet - without having known answers spill over into the next packet unless the - question and known answers are always going to exceed the packet size. - - Some responders do not implement multi-packet known answer suppression - so we try to keep all the known answers in the same packet as the - questions. - """ - # This is the maximum size the query + known answers can be with name compression. - # The actual size of the query + known answers may be a bit smaller since other - # parts may be shared when the final DNSOutgoing packets are constructed. The - # goal of this algorithm is to quickly bucket the query + known answers without - # the overhead of actually constructing the packets. - query_by_size: Dict[DNSQuestion, int] = { - question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers])) - for question, known_answers in question_with_known_answers.items() - } - max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN - query_buckets: List[_DNSPointerOutgoingBucket] = [] - for question in sorted( - query_by_size, - key=query_by_size.get, # type: ignore - reverse=True, - ): - max_compressed_size = query_by_size[question] - answers = question_with_known_answers[question] - for query_bucket in query_buckets: - if query_bucket.bytes + max_compressed_size <= max_bucket_size: - query_bucket.add(max_compressed_size, question, answers) - break - else: - # If a single question and known answers won't fit in a packet - # we will end up generating multiple packets, but there will never - # be multiple questions - query_bucket = _DNSPointerOutgoingBucket(now, multicast) - query_bucket.add(max_compressed_size, question, answers) - query_buckets.append(query_bucket) - - return [query_bucket.out for query_bucket in query_buckets] - - -def generate_service_query( - zc: 'Zeroconf', now: float, types_: List[str], multicast: bool = True -) -> List[DNSOutgoing]: - """Generate a service query for sending with zeroconf.send.""" - questions_with_known_answers: _QuestionWithKnownAnswers = {} - for type_ in types_: - question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) - questions_with_known_answers[question] = set( - cast(DNSPointer, record) - for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) - if not record.is_stale(now) - ) - return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) - - -def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: - """Generate a service_state_changed handlers from a listener.""" - - def on_change( - zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange - ) -> None: - assert listener is not None - args = (zeroconf, service_type, name) - if state_change is ServiceStateChange.Added: - listener.add_service(*args) - elif state_change is ServiceStateChange.Removed: - listener.remove_service(*args) - elif state_change is ServiceStateChange.Updated: - if hasattr(listener, 'update_service'): - listener.update_service(*args) - else: - warnings.warn( - "%r has no update_service method. Provide one (it can be empty if you " - "don't care about the updates), it'll become mandatory." % (listener,), - FutureWarning, - ) - - return on_change - - -class _ServiceBrowserBase(RecordUpdateListener): - """Base class for ServiceBrowser.""" - - def __init__( - self, - zc: 'Zeroconf', - type_: Union[str, list], - handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, - listener: Optional['ServiceListener'] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - """Creates a browser for a specific type""" - assert handlers or listener, 'You need to specify at least one handler' - self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) - for check_type_ in self.types: - # Will generate BadTypeInNameException on a bad name - service_type_name(check_type_, strict=False) - self.zc = zc - self.addr = addr - self.port = port - self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - current_time = current_time_millis() - self._next_time = {check_type_: current_time for check_type_ in self.types} - self._delay = {check_type_: delay for check_type_ in self.types} - self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() - self._handlers_to_call: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() - self._service_state_changed = Signal() - - self.done = False - - if hasattr(handlers, 'add_service'): - listener = cast('ServiceListener', handlers) - handlers = None - - handlers = cast(List[Callable[..., None]], handlers or []) - - if listener: - handlers.append(_service_state_changed_from_listener(listener)) - - for h in handlers: - self.service_state_changed.register_handler(h) - - self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) - - @property - def service_state_changed(self) -> SignalRegistrationInterface: - return self._service_state_changed.registration_interface - - def _record_matching_type(self, record: DNSRecord) -> Optional[str]: - """Return the type if the record matches one of the types we are browsing.""" - return next((type_ for type_ in self.types if record.name.endswith(type_)), None) - - def _enqueue_callback( - self, - state_change: ServiceStateChange, - type_: str, - name: str, - ) -> None: - # Code to ensure we only do a single update message - # Precedence is; Added, Remove, Update - key = (name, type_) - if ( - state_change is ServiceStateChange.Added - or ( - state_change is ServiceStateChange.Removed - and self._pending_handlers.get(key) != ServiceStateChange.Added - ) - or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) - ): - self._pending_handlers[key] = state_change - - def _async_process_record_update(self, now: float, record: DNSRecord) -> None: - """Process a single record update from a batch of updates.""" - expired = record.is_expired(now) - - if isinstance(record, DNSPointer): - if record.name not in self.types: - return - old_record = self.zc.cache.async_get_unique( - DNSPointer(record.name, _TYPE_PTR, _CLASS_IN, 0, record.alias) - ) - if old_record is None: - self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) - elif expired: - self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) - else: - expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) - if expires < self._next_time[record.name]: - self._next_time[record.name] = expires - return - - # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): - return - - if isinstance(record, DNSAddress): - # Iterate through the DNSCache and callback any services that use this address - for service in self.zc.cache.async_entries_with_server(record.name): - type_ = self._record_matching_type(service) - if type_: - self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) - break - - return - - type_ = self._record_matching_type(record) - if type_: - self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) - - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Callback invoked by Zeroconf when new information arrives. - - Updates information required by browser in the Zeroconf cache. - - Ensures that there is are no unecessary duplicates in the list. - - This method will be run in the event loop. - """ - for record in records: - self._async_process_record_update(now, record) - - def async_update_records_complete(self) -> None: - """Called when a record update has completed for all handlers. - - At this point the cache will have the new records. - - This method will be run in the event loop. - """ - # Cannot use .update here since can fail with - # RuntimeError: dictionary changed size during iteration - # for threaded ServiceBrowsers - while self._pending_handlers: - try: - (name_type, state_change) = self._pending_handlers.popitem(False) - except KeyError: - return - self._handlers_to_call[name_type] = state_change - - def cancel(self) -> None: - """Cancel the browser.""" - self.done = True - self.zc.remove_listener(self) - - def generate_ready_queries(self) -> List[DNSOutgoing]: - """Generate the service browser query for any type that is due.""" - now = current_time_millis() - - if min(self._next_time.values()) > now: - return [] - - ready_types = [] - - for type_, due in self._next_time.items(): - if due > now: - continue - - ready_types.append(type_) - self._next_time[type_] = now + self._delay[type_] - self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - - return generate_service_query(self.zc, now, ready_types, self.multicast) - - def _seconds_to_wait(self) -> Optional[float]: - """Returns the number of seconds to wait for the next event.""" - # If there are handlers to call - # we want to process them right away - if self._handlers_to_call: - return None - - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - now = current_time_millis() - - if next_time <= now: - return None - - return millis_to_seconds(next_time - now) - - -class ServiceBrowser(_ServiceBrowserBase, threading.Thread): - """Used to browse for a service of a specific type. - - The listener object will have its add_service() and - remove_service() methods called when this browser - discovers changes in the services availability.""" - - def __init__( - self, - zc: 'Zeroconf', - type_: Union[str, list], - handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, - listener: Optional['ServiceListener'] = None, - addr: Optional[str] = None, - port: int = _MDNS_PORT, - delay: int = _BROWSER_TIME, - ) -> None: - threading.Thread.__init__(self) - super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) - self.daemon = True - self.start() - self.name = "zeroconf-ServiceBrowser-%s-%s" % ( - '-'.join([type_[:-7] for type_ in self.types]), - getattr(self, 'native_id', self.ident), - ) - - def cancel(self) -> None: - """Cancel the browser.""" - super().cancel() - self.join() - - def run(self) -> None: - """Run the browser thread.""" - while True: - timeout = self._seconds_to_wait() - if timeout: - with self.zc.condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - self.zc.condition.wait(timeout) - - if self.zc.done or self.done: - return - - outs = self.generate_ready_queries() - for out in outs: - self.zc.send(out, addr=self.addr, port=self.port) - - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py new file mode 100644 index 00000000..b633df67 --- /dev/null +++ b/zeroconf/_services/browser.py @@ -0,0 +1,405 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import threading +import warnings +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast + +from .._cache import _UniqueRecordsType +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord +from .._protocol import DNSOutgoing +from .._services import ( + RecordUpdateListener, + ServiceListener, + ServiceStateChange, + Signal, + SignalRegistrationInterface, +) +from .._utils.name import service_type_name +from .._utils.time import current_time_millis, millis_to_seconds +from ..const import ( + _BROWSER_BACKOFF_LIMIT, + _BROWSER_TIME, + _CLASS_IN, + _DNS_PACKET_HEADER_LEN, + _EXPIRE_REFRESH_TIME_PERCENT, + _FLAGS_QR_QUERY, + _MAX_MSG_TYPICAL, + _MDNS_ADDR, + _MDNS_ADDR6, + _MDNS_PORT, + _TYPE_PTR, +) + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from .._core import Zeroconf # pylint: disable=cyclic-import + + +_QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] + + +class _DNSPointerOutgoingBucket: + """A DNSOutgoing bucket.""" + + def __init__(self, now: float, multicast: bool) -> None: + """Create a bucke to wrap a DNSOutgoing.""" + self.now = now + self.out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=multicast) + self.bytes = 0 + + def add(self, max_compressed_size: int, question: DNSQuestion, answers: Set[DNSPointer]) -> None: + """Add a new set of questions and known answers to the outgoing.""" + self.out.add_question(question) + for answer in answers: + self.out.add_answer_at_time(answer, self.now) + self.bytes += max_compressed_size + + +def _group_ptr_queries_with_known_answers( + now: float, multicast: bool, question_with_known_answers: _QuestionWithKnownAnswers +) -> List[DNSOutgoing]: + """Aggregate queries so that as many known answers as possible fit in the same packet + without having known answers spill over into the next packet unless the + question and known answers are always going to exceed the packet size. + + Some responders do not implement multi-packet known answer suppression + so we try to keep all the known answers in the same packet as the + questions. + """ + # This is the maximum size the query + known answers can be with name compression. + # The actual size of the query + known answers may be a bit smaller since other + # parts may be shared when the final DNSOutgoing packets are constructed. The + # goal of this algorithm is to quickly bucket the query + known answers without + # the overhead of actually constructing the packets. + query_by_size: Dict[DNSQuestion, int] = { + question: (question.max_size + sum([answer.max_size_compressed for answer in known_answers])) + for question, known_answers in question_with_known_answers.items() + } + max_bucket_size = _MAX_MSG_TYPICAL - _DNS_PACKET_HEADER_LEN + query_buckets: List[_DNSPointerOutgoingBucket] = [] + for question in sorted( + query_by_size, + key=query_by_size.get, # type: ignore + reverse=True, + ): + max_compressed_size = query_by_size[question] + answers = question_with_known_answers[question] + for query_bucket in query_buckets: + if query_bucket.bytes + max_compressed_size <= max_bucket_size: + query_bucket.add(max_compressed_size, question, answers) + break + else: + # If a single question and known answers won't fit in a packet + # we will end up generating multiple packets, but there will never + # be multiple questions + query_bucket = _DNSPointerOutgoingBucket(now, multicast) + query_bucket.add(max_compressed_size, question, answers) + query_buckets.append(query_bucket) + + return [query_bucket.out for query_bucket in query_buckets] + + +def generate_service_query( + zc: 'Zeroconf', now: float, types_: List[str], multicast: bool = True +) -> List[DNSOutgoing]: + """Generate a service query for sending with zeroconf.send.""" + questions_with_known_answers: _QuestionWithKnownAnswers = {} + for type_ in types_: + question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) + questions_with_known_answers[question] = set( + cast(DNSPointer, record) + for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) + if not record.is_stale(now) + ) + return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) + + +def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: + """Generate a service_state_changed handlers from a listener.""" + + def on_change( + zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange + ) -> None: + assert listener is not None + args = (zeroconf, service_type, name) + if state_change is ServiceStateChange.Added: + listener.add_service(*args) + elif state_change is ServiceStateChange.Removed: + listener.remove_service(*args) + elif state_change is ServiceStateChange.Updated: + if hasattr(listener, 'update_service'): + listener.update_service(*args) + else: + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) + + return on_change + + +class _ServiceBrowserBase(RecordUpdateListener): + """Base class for ServiceBrowser.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, + listener: Optional['ServiceListener'] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + """Creates a browser for a specific type""" + assert handlers or listener, 'You need to specify at least one handler' + self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) + for check_type_ in self.types: + # Will generate BadTypeInNameException on a bad name + service_type_name(check_type_, strict=False) + self.zc = zc + self.addr = addr + self.port = port + self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) + current_time = current_time_millis() + self._next_time = {check_type_: current_time for check_type_ in self.types} + self._delay = {check_type_: delay for check_type_ in self.types} + self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() + self._handlers_to_call: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() + self._service_state_changed = Signal() + + self.done = False + + if hasattr(handlers, 'add_service'): + listener = cast('ServiceListener', handlers) + handlers = None + + handlers = cast(List[Callable[..., None]], handlers or []) + + if listener: + handlers.append(_service_state_changed_from_listener(listener)) + + for h in handlers: + self.service_state_changed.register_handler(h) + + self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + + @property + def service_state_changed(self) -> SignalRegistrationInterface: + return self._service_state_changed.registration_interface + + def _record_matching_type(self, record: DNSRecord) -> Optional[str]: + """Return the type if the record matches one of the types we are browsing.""" + return next((type_ for type_ in self.types if record.name.endswith(type_)), None) + + def _enqueue_callback( + self, + state_change: ServiceStateChange, + type_: str, + name: str, + ) -> None: + # Code to ensure we only do a single update message + # Precedence is; Added, Remove, Update + key = (name, type_) + if ( + state_change is ServiceStateChange.Added + or ( + state_change is ServiceStateChange.Removed + and self._pending_handlers.get(key) != ServiceStateChange.Added + ) + or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers) + ): + self._pending_handlers[key] = state_change + + def _async_process_record_update(self, now: float, record: DNSRecord) -> None: + """Process a single record update from a batch of updates.""" + expired = record.is_expired(now) + + if isinstance(record, DNSPointer): + if record.name not in self.types: + return + old_record = self.zc.cache.async_get_unique( + DNSPointer(record.name, _TYPE_PTR, _CLASS_IN, 0, record.alias) + ) + if old_record is None: + self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) + elif expired: + self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) + else: + expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) + if expires < self._next_time[record.name]: + self._next_time[record.name] = expires + return + + # If its expired or already exists in the cache it cannot be updated. + if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): + return + + if isinstance(record, DNSAddress): + # Iterate through the DNSCache and callback any services that use this address + for service in self.zc.cache.async_entries_with_server(record.name): + type_ = self._record_matching_type(service) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, service.name) + break + + return + + type_ = self._record_matching_type(record) + if type_: + self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + """Callback invoked by Zeroconf when new information arrives. + + Updates information required by browser in the Zeroconf cache. + + Ensures that there is are no unecessary duplicates in the list. + + This method will be run in the event loop. + """ + for record in records: + self._async_process_record_update(now, record) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ + # Cannot use .update here since can fail with + # RuntimeError: dictionary changed size during iteration + # for threaded ServiceBrowsers + while self._pending_handlers: + try: + (name_type, state_change) = self._pending_handlers.popitem(False) + except KeyError: + return + self._handlers_to_call[name_type] = state_change + + def cancel(self) -> None: + """Cancel the browser.""" + self.done = True + self.zc.remove_listener(self) + + def generate_ready_queries(self) -> List[DNSOutgoing]: + """Generate the service browser query for any type that is due.""" + now = current_time_millis() + + if min(self._next_time.values()) > now: + return [] + + ready_types = [] + + for type_, due in self._next_time.items(): + if due > now: + continue + + ready_types.append(type_) + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + + return generate_service_query(self.zc, now, ready_types, self.multicast) + + def _seconds_to_wait(self) -> Optional[float]: + """Returns the number of seconds to wait for the next event.""" + # If there are handlers to call + # we want to process them right away + if self._handlers_to_call: + return None + + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + now = current_time_millis() + + if next_time <= now: + return None + + return millis_to_seconds(next_time - now) + + +class ServiceBrowser(_ServiceBrowserBase, threading.Thread): + """Used to browse for a service of a specific type. + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability.""" + + def __init__( + self, + zc: 'Zeroconf', + type_: Union[str, list], + handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, + listener: Optional['ServiceListener'] = None, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + delay: int = _BROWSER_TIME, + ) -> None: + threading.Thread.__init__(self) + super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + self.daemon = True + self.start() + self.name = "zeroconf-ServiceBrowser-%s-%s" % ( + '-'.join([type_[:-7] for type_ in self.types]), + getattr(self, 'native_id', self.ident), + ) + + def cancel(self) -> None: + """Cancel the browser.""" + super().cancel() + self.join() + + def run(self) -> None: + """Run the browser thread.""" + while True: + timeout = self._seconds_to_wait() + if timeout: + with self.zc.condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + self.zc.condition.wait(timeout) + + if self.zc.done or self.done: + return + + outs = self.generate_ready_queries() + for out in outs: + self.zc.send(out, addr=self.addr, port=self.port) + + if not self._handlers_to_call: + continue + + (name_type, state_change) = self._handlers_to_call.popitem(False) + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) diff --git a/zeroconf/_services/types.py b/zeroconf/_services/types.py index f611fc4c..34b000f1 100644 --- a/zeroconf/_services/types.py +++ b/zeroconf/_services/types.py @@ -23,8 +23,9 @@ import time from typing import Optional, Set, Tuple, Union +from .browser import ServiceBrowser from .._core import Zeroconf -from .._services import ServiceBrowser, ServiceListener +from .._services import ServiceListener from .._utils.net import IPVersion, InterfaceChoice, InterfacesType from ..const import _SERVICE_TYPE_ENUMERATION_NAME diff --git a/zeroconf/aio.py b/zeroconf/aio.py index e64c87c3..00d42823 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,7 +26,7 @@ from ._core import NotifyListener, Zeroconf from ._exceptions import NonUniqueNameException -from ._services import _ServiceBrowserBase +from ._services.browser import _ServiceBrowserBase from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.types import ZeroconfServiceTypes from ._utils.aio import wait_condition_or_timeout From 869c95a51e228131eb7debe1acc47c105b9bf7b5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 22:04:40 -1000 Subject: [PATCH 388/608] Relocate service browser tests to tests/services/test_browser.py (#745) --- tests/services/test_browser.py | 776 +++++++++++++++++++++++++++++++++ tests/test_services.py | 681 +---------------------------- 2 files changed, 777 insertions(+), 680 deletions(-) create mode 100644 tests/services/test_browser.py diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py new file mode 100644 index 00000000..ccdb312f --- /dev/null +++ b/tests/services/test_browser.py @@ -0,0 +1,776 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._services.browser. """ + +import logging +import socket +import time +import os +import unittest +from threading import Event + +import pytest + +import zeroconf as r +from zeroconf import DNSPointer, DNSQuestion, const, current_time_millis +import zeroconf._services.browser as _services_browser +from zeroconf import Zeroconf +from zeroconf._services import ServiceStateChange +from zeroconf._services.browser import ServiceBrowser +from zeroconf._services.info import ServiceInfo + +from .. import has_working_ipv6, _inject_response + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestServiceBrowser(unittest.TestCase): + def test_update_record(self): + enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + service_v6_address = "2001:db8::1" + service_v6_second_address = "6001:db8::1" + + service_added_count = 0 + service_removed_count = 0 + service_updated_count = 0 + service_add_event = Event() + service_removed_event = Event() + service_updated_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + service_removed_event.set() + + def update_service(self, zc, type_, name) -> None: + nonlocal service_updated_count + service_updated_count += 1 + service_info = zc.get_service_info(type_, name) + assert socket.inet_aton(service_address) in service_info.addresses + if enable_ipv6: + assert socket.inet_pton( + socket.AF_INET6, service_v6_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert socket.inet_pton( + socket.AF_INET6, service_v6_second_address + ) in service_info.addresses_by_version(r.IPVersion.V6Only) + assert service_info.text == service_text + assert service_info.server == service_server + service_updated_event.set() + + def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + assert generated.is_response() is True + + if service_state_change == r.ServiceStateChange.Removed: + ttl = 0 + else: + ttl = 120 + + generated.add_answer_at_time( + r.DNSText( + service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + 0, + ) + + # Send the IPv6 address first since we previously + # had a bug where the IPv4 would be missing if the + # IPv6 was seen first + if enable_ipv6: + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_v6_second_address), + ), + 0, + ) + generated.add_answer_at_time( + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_aton(service_address), + ), + 0, + ) + + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + + return r.DNSIncoming(generated.packets()[0]) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) + + try: + wait_time = 3 + + # service added + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) + service_add_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 0 + assert service_removed_count == 0 + + # service SRV updated + service_updated_event.clear() + service_server = 'ash-2.local.' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 1 + assert service_removed_count == 0 + + # service TXT updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service TXT updated - duplicate update should not trigger another service_updated + service_updated_event.clear() + service_text = b'path=/~matt2/' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 2 + assert service_removed_count == 0 + + # service A updated + service_updated_event.clear() + service_address = '10.0.1.3' + # Verify we match on uppercase + service_server = service_server.upper() + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 3 + assert service_removed_count == 0 + + # service all updated + service_updated_event.clear() + service_server = 'ash-3.local.' + service_text = b'path=/~matt3/' + service_address = '10.0.1.3' + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) + service_updated_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 0 + + # service removed + _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) + service_removed_event.wait(wait_time) + assert service_added_count == 1 + assert service_updated_count == 4 + assert service_removed_count == 1 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + +class TestServiceBrowserMultipleTypes(unittest.TestCase): + def test_update_record(self): + + service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] + service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] + + service_added_count = 0 + service_removed_count = 0 + service_add_event = Event() + service_removed_event = Event() + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal service_added_count + service_added_count += 1 + if service_added_count == 3: + service_add_event.set() + + def remove_service(self, zc, type_, name) -> None: + nonlocal service_removed_count + service_removed_count += 1 + if service_removed_count == 3: + service_removed_event.set() + + def mock_incoming_msg( + service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int + ) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 + ) + return r.DNSIncoming(generated.packets()[0]) + + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) + + try: + wait_time = 3 + + # all three services added + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), + ) + zeroconf.wait(100) + + called_with_refresh_time_check = False + + def _mock_get_expiration_time(self, percent): + nonlocal called_with_refresh_time_check + if percent == const._EXPIRE_REFRESH_TIME_PERCENT: + called_with_refresh_time_check = True + return 0 + return self.created + (percent * self.ttl * 10) + + # Set an expire time that will force a refresh + with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), + ) + # Add the last record after updating the first one + # to ensure the service_add_event only gets set + # after the update + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), + ) + service_add_event.wait(wait_time) + assert called_with_refresh_time_check is True + assert service_added_count == 3 + assert service_removed_count == 0 + + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0), + ) + + # all three services removed + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), + ) + _inject_response( + zeroconf, + mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), + ) + service_removed_event.wait(wait_time) + assert service_added_count == 3 + assert service_removed_count == 3 + + finally: + assert len(zeroconf.listeners) == 1 + service_browser.cancel() + assert len(zeroconf.listeners) == 0 + zeroconf.remove_all_service_listeners() + zeroconf.close() + + +def test_backoff(): + got_query = Event() + + type_ = "_http._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.send + + time_offset = 0.0 + start_time = time.time() * 1000 + initial_query_interval = _services_browser._BROWSER_TIME / 1000 + + def current_time_millis(): + """Current system time in milliseconds""" + return start_time + time_offset * 1000 + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + got_query.set() + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to prevent test running forever + with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + _services_browser, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + try: + # Test that queries are sent at increasing intervals + sleep_count = 0 + next_query_interval = 0.0 + expected_query_time = 0.0 + while True: + sleep_count += 1 + for _ in range(2): + # If the browser thread is starting up + # its possible we notify before the initial sleep + # which means the test will fail so we need to d + # this twice to eliminate the race condition + zeroconf_browser.notify_all() + got_query.wait(0.05) + if time_offset == expected_query_time: + assert got_query.is_set() + got_query.clear() + if next_query_interval == _services_browser._BROWSER_BACKOFF_LIMIT: + # Only need to test up to the point where we've seen a query + # after the backoff limit has been hit + break + elif next_query_interval == 0: + next_query_interval = initial_query_interval + expected_query_time = initial_query_interval + else: + next_query_interval = min( + 2 * next_query_interval, _services_browser._BROWSER_BACKOFF_LIMIT + ) + expected_query_time += next_query_interval + else: + assert not got_query.is_set() + time_offset += initial_query_interval + + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_integration(): + service_added = Event() + service_removed = Event() + unexpected_ttl = Event() + got_query = Event() + + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + def on_service_state_change(zeroconf, service_type, state_change, name): + if name == registration_name: + if state_change is ServiceStateChange.Added: + service_added.set() + elif state_change is ServiceStateChange.Removed: + service_removed.set() + + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check packet sizes + old_send = zeroconf_browser.send + + time_offset = 0.0 + + def current_time_millis(): + """Current system time in milliseconds""" + return time.time() * 1000 + time_offset * 1000 + + expected_ttl = const._DNS_HOST_TTL + + nbr_answers = 0 + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + pout = r.DNSIncoming(out.packets()[0]) + nonlocal nbr_answers + for answer in pout.answers: + nbr_answers += 1 + if not answer.ttl > expected_ttl / 2: + unexpected_ttl.set() + + got_query.set() + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL + with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + _services_browser, "current_time_millis", current_time_millis + ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): + service_added = Event() + service_removed = Event() + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zeroconf_registrar.register_service(info) + + try: + service_added.wait(1) + assert service_added.is_set() + + # Test that we receive queries containing answers only if the remaining TTL + # is greater than half the original TTL + sleep_count = 0 + test_iterations = 50 + while nbr_answers < test_iterations: + # Increase simulated time shift by 1/4 of the TTL in seconds + time_offset += expected_ttl / 4 + zeroconf_browser.notify_all() + sleep_count += 1 + got_query.wait(0.1) + got_query.clear() + # Prevent the test running indefinitely in an error condition + assert sleep_count < test_iterations * 4 + assert not unexpected_ttl.is_set() + + # Don't remove service, allow close() to cleanup + + finally: + zeroconf_registrar.close() + service_removed.wait(1) + assert service_removed.is_set() + browser.cancel() + zeroconf_browser.close() + + +def test_legacy_record_update_listener(): + """Test a RecordUpdateListener that does not implement update_records.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(RuntimeError): + r.RecordUpdateListener().update_record( + zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + ) + + updates = [] + + class LegacyRecordUpdateListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: + nonlocal updates + updates.append(record) + + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + zc.wait(1) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + + zc.close() + + +def test_service_browser_is_aware_of_port_changes(): + """Test that the ServiceBrowser is aware of port changes.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + callbacks = [] + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + nonlocal callbacks + if name == registration_name: + callbacks.append((service_type, state_change, name)) + + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.1) + + assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')] + assert zc.get_service_info(type_, registration_name).port == 80 + + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.1) + + assert callbacks == [ + ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'), + ('_hap._tcp.local.', ServiceStateChange.Updated, 'xxxyyy._hap._tcp.local.'), + ] + assert zc.get_service_info(type_, registration_name).port == 400 + browser.cancel() + + zc.close() + + +def test_service_browser_listeners_update_service(): + """Test that the ServiceBrowser ServiceListener that implements update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(r.ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ('update', type_, registration_name), + ] + browser.cancel() + + zc.close() + + +def test_service_browser_listeners_no_update_service(): + """Test that the ServiceBrowser ServiceListener that does not implement update_service.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener: + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + + def mock_incoming_msg(records) -> r.DNSIncoming: + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + for record in records: + generated.add_answer_at_time(record, 0) + return r.DNSIncoming(generated.packets()[0]) + + _inject_response( + zc, + mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), + ) + time.sleep(0.2) + info.port = 400 + _inject_response( + zc, + mock_incoming_msg([info.dns_service()]), + ) + time.sleep(0.2) + + assert callbacks == [ + ('add', type_, registration_name), + ] + browser.cancel() + + zc.close() + + +def test_servicebrowser_uses_non_strict_names(): + """Verify we can look for technically invalid names as we cannot change what others do.""" + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + browser = ServiceBrowser(zc, ["_tivo-videostream._tcp.local."], [on_service_state_change]) + browser.cancel() + + # Still fail on completely invalid + with pytest.raises(r.BadTypeInNameException): + browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change]) + zc.close() + + +def test_group_ptr_queries_with_known_answers(): + questions_with_known_answers: _services_browser._QuestionWithKnownAnswers = {} + now = current_time_millis() + for i in range(120): + name = f"_hap{i}._tcp._local." + questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set( + DNSPointer( + name, + const._TYPE_PTR, + const._CLASS_IN, + 4500, + f"zoo{counter}.{name}", + ) + for counter in range(i) + ) + outs = _services_browser._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) + for out in outs: + packets = out.packets() + # If we generate multiple packets there must + # only be one question + assert len(packets) == 1 or len(out.questions) == 1 diff --git a/tests/test_services.py b/tests/test_services.py index 88b490fb..a22d6f6b 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -16,10 +16,8 @@ import pytest import zeroconf as r -from zeroconf import DNSAddress, DNSPointer, DNSQuestion, const, current_time_millis -import zeroconf._services.browser as _services_browser +from zeroconf import DNSAddress, const from zeroconf import Zeroconf -from zeroconf._services import ServiceStateChange from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo from zeroconf.aio import AsyncZeroconf @@ -436,113 +434,6 @@ def get_service_info_helper(zc, type, name): zc.close() -class TestServiceBrowserMultipleTypes(unittest.TestCase): - def test_update_record(self): - - service_names = ['name2._type2._tcp.local.', 'name._type._tcp.local.', 'name._type._udp.local'] - service_types = ['_type2._tcp.local.', '_type._tcp.local.', '_type._udp.local.'] - - service_added_count = 0 - service_removed_count = 0 - service_add_event = Event() - service_removed_event = Event() - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal service_added_count - service_added_count += 1 - if service_added_count == 3: - service_add_event.set() - - def remove_service(self, zc, type_, name) -> None: - nonlocal service_removed_count - service_removed_count += 1 - if service_removed_count == 3: - service_removed_event.set() - - def mock_incoming_msg( - service_state_change: r.ServiceStateChange, service_type: str, service_name: str, ttl: int - ) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - generated.add_answer_at_time( - r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 - ) - return r.DNSIncoming(generated.packets()[0]) - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener()) - - try: - wait_time = 3 - - # all three services added - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), - ) - zeroconf.wait(100) - - called_with_refresh_time_check = False - - def _mock_get_expiration_time(self, percent): - nonlocal called_with_refresh_time_check - if percent == const._EXPIRE_REFRESH_TIME_PERCENT: - called_with_refresh_time_check = True - return 0 - return self.created + (percent * self.ttl * 10) - - # Set an expire time that will force a refresh - with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), - ) - # Add the last record after updating the first one - # to ensure the service_add_event only gets set - # after the update - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Added, service_types[2], service_names[2], 120), - ) - service_add_event.wait(wait_time) - assert called_with_refresh_time_check is True - assert service_added_count == 3 - assert service_removed_count == 0 - - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Updated, service_types[0], service_names[0], 0), - ) - - # all three services removed - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0], 0), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1], 0), - ) - _inject_response( - zeroconf, - mock_incoming_msg(r.ServiceStateChange.Removed, service_types[2], service_names[2], 0), - ) - service_removed_event.wait(wait_time) - assert service_added_count == 3 - assert service_removed_count == 3 - - finally: - assert len(zeroconf.listeners) == 1 - service_browser.cancel() - assert len(zeroconf.listeners) == 0 - zeroconf.remove_all_service_listeners() - zeroconf.close() - - class ListenerTest(unittest.TestCase): def test_integration_with_listener_class(self): @@ -699,201 +590,6 @@ def update_service(self, zeroconf, type, name): zeroconf_browser.close() -class TestServiceBrowser(unittest.TestCase): - def test_update_record(self): - enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - service_v6_address = "2001:db8::1" - service_v6_second_address = "6001:db8::1" - - service_added_count = 0 - service_removed_count = 0 - service_updated_count = 0 - service_add_event = Event() - service_removed_event = Event() - service_updated_event = Event() - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal service_added_count - service_added_count += 1 - service_add_event.set() - - def remove_service(self, zc, type_, name) -> None: - nonlocal service_removed_count - service_removed_count += 1 - service_removed_event.set() - - def update_service(self, zc, type_, name) -> None: - nonlocal service_updated_count - service_updated_count += 1 - service_info = zc.get_service_info(type_, name) - assert socket.inet_aton(service_address) in service_info.addresses - if enable_ipv6: - assert socket.inet_pton( - socket.AF_INET6, service_v6_address - ) in service_info.addresses_by_version(r.IPVersion.V6Only) - assert socket.inet_pton( - socket.AF_INET6, service_v6_second_address - ) in service_info.addresses_by_version(r.IPVersion.V6Only) - assert service_info.text == service_text - assert service_info.server == service_server - service_updated_event.set() - - def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncoming: - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - assert generated.is_response() is True - - if service_state_change == r.ServiceStateChange.Removed: - ttl = 0 - else: - ttl = 120 - - generated.add_answer_at_time( - r.DNSText( - service_name, const._TYPE_TXT, const._CLASS_IN | const._CLASS_UNIQUE, ttl, service_text - ), - 0, - ) - - generated.add_answer_at_time( - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ), - 0, - ) - - # Send the IPv6 address first since we previously - # had a bug where the IPv4 would be missing if the - # IPv6 was seen first - if enable_ipv6: - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_AAAA, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET6, service_v6_address), - ), - 0, - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_AAAA, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET6, service_v6_second_address), - ), - 0, - ) - generated.add_answer_at_time( - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_aton(service_address), - ), - 0, - ) - - generated.add_answer_at_time( - r.DNSPointer(service_type, const._TYPE_PTR, const._CLASS_IN, ttl, service_name), 0 - ) - - return r.DNSIncoming(generated.packets()[0]) - - zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) - service_browser = r.ServiceBrowser(zeroconf, service_type, listener=MyServiceListener()) - - try: - wait_time = 3 - - # service added - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Added)) - service_add_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 0 - assert service_removed_count == 0 - - # service SRV updated - service_updated_event.clear() - service_server = 'ash-2.local.' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 1 - assert service_removed_count == 0 - - # service TXT updated - service_updated_event.clear() - service_text = b'path=/~matt2/' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 2 - assert service_removed_count == 0 - - # service TXT updated - duplicate update should not trigger another service_updated - service_updated_event.clear() - service_text = b'path=/~matt2/' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 2 - assert service_removed_count == 0 - - # service A updated - service_updated_event.clear() - service_address = '10.0.1.3' - # Verify we match on uppercase - service_server = service_server.upper() - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 3 - assert service_removed_count == 0 - - # service all updated - service_updated_event.clear() - service_server = 'ash-3.local.' - service_text = b'path=/~matt3/' - service_address = '10.0.1.3' - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Updated)) - service_updated_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 4 - assert service_removed_count == 0 - - # service removed - _inject_response(zeroconf, mock_incoming_msg(r.ServiceStateChange.Removed)) - service_removed_event.wait(wait_time) - assert service_added_count == 1 - assert service_updated_count == 4 - assert service_removed_count == 1 - - finally: - assert len(zeroconf.listeners) == 1 - service_browser.cancel() - assert len(zeroconf.listeners) == 0 - zeroconf.remove_all_service_listeners() - zeroconf.close() - - def test_multiple_addresses(): type_ = "_http._tcp.local." registration_name = "xxxyyy.%s" % type_ @@ -974,168 +670,6 @@ async def test_multiple_a_addresses(): await aiozc.async_close() -def test_backoff(): - got_query = Event() - - type_ = "_http._tcp.local." - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to patch the zeroconf send to check query transmission - old_send = zeroconf_browser.send - - time_offset = 0.0 - start_time = time.time() * 1000 - initial_query_interval = _services_browser._BROWSER_TIME / 1000 - - def current_time_millis(): - """Current system time in milliseconds""" - return start_time + time_offset * 1000 - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): - """Sends an outgoing packet.""" - got_query.set() - old_send(out, addr=addr, port=port) - - # patch the zeroconf send - # patch the zeroconf current_time_millis - # patch the backoff limit to prevent test running forever - with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( - _services_browser, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10): - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - try: - # Test that queries are sent at increasing intervals - sleep_count = 0 - next_query_interval = 0.0 - expected_query_time = 0.0 - while True: - sleep_count += 1 - for _ in range(2): - # If the browser thread is starting up - # its possible we notify before the initial sleep - # which means the test will fail so we need to d - # this twice to eliminate the race condition - zeroconf_browser.notify_all() - got_query.wait(0.05) - if time_offset == expected_query_time: - assert got_query.is_set() - got_query.clear() - if next_query_interval == _services_browser._BROWSER_BACKOFF_LIMIT: - # Only need to test up to the point where we've seen a query - # after the backoff limit has been hit - break - elif next_query_interval == 0: - next_query_interval = initial_query_interval - expected_query_time = initial_query_interval - else: - next_query_interval = min( - 2 * next_query_interval, _services_browser._BROWSER_BACKOFF_LIMIT - ) - expected_query_time += next_query_interval - else: - assert not got_query.is_set() - time_offset += initial_query_interval - - finally: - browser.cancel() - zeroconf_browser.close() - - -def test_integration(): - service_added = Event() - service_removed = Event() - unexpected_ttl = Event() - got_query = Event() - - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - - def on_service_state_change(zeroconf, service_type, state_change, name): - if name == registration_name: - if state_change is ServiceStateChange.Added: - service_added.set() - elif state_change is ServiceStateChange.Removed: - service_removed.set() - - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - - # we are going to patch the zeroconf send to check packet sizes - old_send = zeroconf_browser.send - - time_offset = 0.0 - - def current_time_millis(): - """Current system time in milliseconds""" - return time.time() * 1000 + time_offset * 1000 - - expected_ttl = const._DNS_HOST_TTL - - nbr_answers = 0 - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): - """Sends an outgoing packet.""" - pout = r.DNSIncoming(out.packets()[0]) - nonlocal nbr_answers - for answer in pout.answers: - nbr_answers += 1 - if not answer.ttl > expected_ttl / 2: - unexpected_ttl.set() - - got_query.set() - old_send(out, addr=addr, port=port) - - # patch the zeroconf send - # patch the zeroconf current_time_millis - # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( - _services_browser, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): - service_added = Event() - service_removed = Event() - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - zeroconf_registrar.register_service(info) - - try: - service_added.wait(1) - assert service_added.is_set() - - # Test that we receive queries containing answers only if the remaining TTL - # is greater than half the original TTL - sleep_count = 0 - test_iterations = 50 - while nbr_answers < test_iterations: - # Increase simulated time shift by 1/4 of the TTL in seconds - time_offset += expected_ttl / 4 - zeroconf_browser.notify_all() - sleep_count += 1 - got_query.wait(0.1) - got_query.clear() - # Prevent the test running indefinitely in an error condition - assert sleep_count < test_iterations * 4 - assert not unexpected_ttl.is_set() - - # Don't remove service, allow close() to cleanup - - finally: - zeroconf_registrar.close() - service_removed.wait(1) - assert service_removed.is_set() - browser.cancel() - zeroconf_browser.close() - - def test_legacy_record_update_listener(): """Test a RecordUpdateListener that does not implement update_records.""" @@ -1215,179 +749,6 @@ def dns_addresses_to_addresses(dns_address: List[DNSAddress]): assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] -def test_service_browser_is_aware_of_port_changes(): - """Test that the ServiceBrowser is aware of port changes.""" - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - # start a browser - type_ = "_hap._tcp.local." - registration_name = "xxxyyy.%s" % type_ - - callbacks = [] - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - nonlocal callbacks - if name == registration_name: - callbacks.append((service_type, state_change, name)) - - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) - - def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - for record in records: - generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packets()[0]) - - _inject_response( - zc, - mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), - ) - time.sleep(0.1) - - assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')] - assert zc.get_service_info(type_, registration_name).port == 80 - - info.port = 400 - _inject_response( - zc, - mock_incoming_msg([info.dns_service()]), - ) - time.sleep(0.1) - - assert callbacks == [ - ('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'), - ('_hap._tcp.local.', ServiceStateChange.Updated, 'xxxyyy._hap._tcp.local.'), - ] - assert zc.get_service_info(type_, registration_name).port == 400 - browser.cancel() - - zc.close() - - -def test_service_browser_listeners_update_service(): - """Test that the ServiceBrowser ServiceListener that implements update_service.""" - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - # start a browser - type_ = "_hap._tcp.local." - registration_name = "xxxyyy.%s" % type_ - callbacks = [] - - class MyServiceListener(r.ServiceListener): - def add_service(self, zc, type_, name) -> None: - nonlocal callbacks - if name == registration_name: - callbacks.append(("add", type_, name)) - - def remove_service(self, zc, type_, name) -> None: - nonlocal callbacks - if name == registration_name: - callbacks.append(("remove", type_, name)) - - def update_service(self, zc, type_, name) -> None: - nonlocal callbacks - if name == registration_name: - callbacks.append(("update", type_, name)) - - listener = MyServiceListener() - - browser = r.ServiceBrowser(zc, type_, None, listener) - - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) - - def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - for record in records: - generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packets()[0]) - - _inject_response( - zc, - mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), - ) - time.sleep(0.2) - info.port = 400 - _inject_response( - zc, - mock_incoming_msg([info.dns_service()]), - ) - time.sleep(0.2) - - assert callbacks == [ - ('add', type_, registration_name), - ('update', type_, registration_name), - ] - browser.cancel() - - zc.close() - - -def test_service_browser_listeners_no_update_service(): - """Test that the ServiceBrowser ServiceListener that does not implement update_service.""" - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - # start a browser - type_ = "_hap._tcp.local." - registration_name = "xxxyyy.%s" % type_ - callbacks = [] - - class MyServiceListener: - def add_service(self, zc, type_, name) -> None: - nonlocal callbacks - if name == registration_name: - callbacks.append(("add", type_, name)) - - def remove_service(self, zc, type_, name) -> None: - nonlocal callbacks - if name == registration_name: - callbacks.append(("remove", type_, name)) - - listener = MyServiceListener() - - browser = r.ServiceBrowser(zc, type_, None, listener) - - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) - - def mock_incoming_msg(records) -> r.DNSIncoming: - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - for record in records: - generated.add_answer_at_time(record, 0) - return r.DNSIncoming(generated.packets()[0]) - - _inject_response( - zc, - mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]), - ) - time.sleep(0.2) - info.port = 400 - _inject_response( - zc, - mock_incoming_msg([info.dns_service()]), - ) - time.sleep(0.2) - - assert callbacks == [ - ('add', type_, registration_name), - ] - browser.cancel() - - zc.close() - - def test_changing_name_updates_serviceinfo_key(): """Verify a name change will adjust the underlying key value.""" type_ = "_homeassistant._tcp.local." @@ -1407,23 +768,6 @@ def test_changing_name_updates_serviceinfo_key(): assert info_service.key == "yourtesthome._homeassistant._tcp.local." -def test_servicebrowser_uses_non_strict_names(): - """Verify we can look for technically invalid names as we cannot change what others do.""" - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - browser = ServiceBrowser(zc, ["_tivo-videostream._tcp.local."], [on_service_state_change]) - browser.cancel() - - # Still fail on completely invalid - with pytest.raises(r.BadTypeInNameException): - browser = ServiceBrowser(zc, ["tivo-videostream._tcp.local."], [on_service_state_change]) - zc.close() - - def test_servicelisteners_raise_not_implemented(): """Verify service listeners raise when one of the methods is not implemented.""" @@ -1524,26 +868,3 @@ def test_serviceinfo_accepts_bytes_or_string_dict(): addresses=addresses, ) assert info_service.dns_text().text == b'\x0epath=/~paulsm/' - - -def test_group_ptr_queries_with_known_answers(): - questions_with_known_answers: _services_browser._QuestionWithKnownAnswers = {} - now = current_time_millis() - for i in range(120): - name = f"_hap{i}._tcp._local." - questions_with_known_answers[DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN)] = set( - DNSPointer( - name, - const._TYPE_PTR, - const._CLASS_IN, - 4500, - f"zoo{counter}.{name}", - ) - for counter in range(i) - ) - outs = _services_browser._group_ptr_queries_with_known_answers(now, True, questions_with_known_answers) - for out in outs: - packets = out.packets() - # If we generate multiple packets there must - # only be one question - assert len(packets) == 1 or len(out.questions) == 1 From 541292e55fee8bbafe687afcb8d152f6fe0efb5f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 22:14:54 -1000 Subject: [PATCH 389/608] Relocate service info tests to tests/services/test_info.py (#746) --- tests/services/test_info.py | 627 ++++++++++++++++++++++++++++++++++++ tests/test_services.py | 597 +--------------------------------- 2 files changed, 629 insertions(+), 595 deletions(-) create mode 100644 tests/services/test_info.py diff --git a/tests/services/test_info.py b/tests/services/test_info.py new file mode 100644 index 00000000..8f654a70 --- /dev/null +++ b/tests/services/test_info.py @@ -0,0 +1,627 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._services.info. """ + +import logging +import socket +import threading +import os +import unittest +from threading import Event +from typing import List + +import pytest + +import zeroconf as r +from zeroconf import DNSAddress, const +from zeroconf._services.info import ServiceInfo +from zeroconf.aio import AsyncZeroconf + +from .. import has_working_ipv6, _inject_response + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +class TestServiceInfo(unittest.TestCase): + def test_get_name(self): + """Verify the name accessor can strip the type.""" + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + assert info.get_name() == "name" + + def test_service_info_rejects_non_matching_updates(self): + """Verify records with the wrong name are rejected.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Verify backwards compatiblity with calling with None + info.update_record(zc, now, None) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.3") + info.update_record( + zc, + now, + r.DNSAddress( + 'ASH-2.local.', + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address in info.addresses + # Non-matching updates + info.update_record( + zc, + now, + r.DNSText( + "incorrect.name.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + info.update_record( + zc, + now, + r.DNSService( + "incorrect.name.", + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + 'ASH-2.local.', + ), + ) + assert info.server_key == 'ash-2.local.' + assert info.server == 'ASH-2.local.' + new_address = socket.inet_aton("10.0.1.4") + info.update_record( + zc, + now, + r.DNSAddress( + "incorrect.name.", + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + new_address, + ), + ) + assert new_address not in info.addresses + zc.close() + + def test_service_info_rejects_expired_records(self): + """Verify records that are expired are rejected.""" + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + # Matching updates + info.update_record( + zc, + now, + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', + ), + ) + assert info.properties[b"ci"] == b"2" + # Expired record + expired_record = r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', + ) + expired_record.created = 1000 + expired_record._expiration_time = 1000 + info.update_record(zc, now, expired_record) + assert info.properties[b"ci"] == b"2" + zc.close() + + def test_get_info_partial(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # patch the zeroconf send + with unittest.mock.patch.object(zc, "send", send): + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packets()[0]) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for SRV, A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 3 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext query for A, AAAA + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 2 + assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + last_sent = None + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ) + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + def test_get_info_single(self): + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_text = b'path=/~matt1/' + service_address = '10.0.1.2' + + service_info = None + send_event = Event() + service_info_event = Event() + + last_sent = None # type: Optional[r.DNSOutgoing] + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal last_sent + + last_sent = out + send_event.set() + + # patch the zeroconf send + with unittest.mock.patch.object(zc, "send", send): + + def mock_incoming_msg(records) -> r.DNSIncoming: + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for record in records: + generated.add_answer_at_time(record, 0) + + return r.DNSIncoming(generated.packets()[0]) + + def get_service_info_helper(zc, type, name): + nonlocal service_info + service_info = zc.get_service_info(type, name) + service_info_event.set() + + try: + ttl = 120 + helper_thread = threading.Thread( + target=get_service_info_helper, args=(zc, service_type, service_name) + ) + helper_thread.start() + wait_time = 1 + + # Expext query for SRV, TXT, A, AAAA + send_event.wait(wait_time) + assert last_sent is not None + assert len(last_sent.questions) == 4 + assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions + assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions + assert service_info is None + + # Expext no further queries + last_sent = None + send_event.clear() + _inject_response( + zc, + mock_incoming_msg( + [ + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + service_text, + ), + r.DNSService( + service_name, + const._TYPE_SRV, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + 0, + 0, + 80, + service_server, + ), + r.DNSAddress( + service_server, + const._TYPE_A, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET, service_address), + ), + ] + ), + ) + send_event.wait(wait_time) + assert last_sent is None + assert service_info is not None + + finally: + helper_thread.join() + zc.remove_all_service_listeners() + zc.close() + + +def test_multiple_addresses(): + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) + + assert info.addresses == [address, address] + + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_parsed], + ) + assert info.addresses == [address, address] + + if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): + address_v6_parsed = "2001:db8::1" + address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) + infos = [ + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[address, address_v6], + ), + ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + parsed_addresses=[address_parsed, address_v6_parsed], + ), + ] + for info in infos: + assert info.addresses == [address] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.V4Only) == [address] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed] + assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_multiple_a_addresses(): + type_ = "_http._tcp.local." + registration_name = "multiarec.%s" % type_ + desc = {'path': '/~paulsm/'} + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + cache = aiozc.zeroconf.cache + host = "multahost.local." + record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a') + record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b') + cache.async_add_records([record1, record2]) + + # New kwarg way + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) + info.load_from_cache(aiozc.zeroconf) + assert set(info.addresses) == set([b'a', b'b']) + await aiozc.async_close() + + +def test_filter_address_by_type_from_service_info(): + """Verify dns_addresses can filter by ipversion.""" + desc = {'path': '/~paulsm/'} + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + registration_name = "%s.%s" % (name, type_) + ipv4 = socket.inet_aton("10.0.1.2") + ipv6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[ipv4, ipv6]) + + def dns_addresses_to_addresses(dns_address: List[DNSAddress]): + return [address.address for address in dns_address] + + assert dns_addresses_to_addresses(info.dns_addresses()) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4] + assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] + + +def test_changing_name_updates_serviceinfo_key(): + """Verify a name change will adjust the underlying key value.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + assert info_service.key == "mytesthome._homeassistant._tcp.local." + info_service.name = "YourTestHome._homeassistant._tcp.local." + assert info_service.key == "yourtesthome._homeassistant._tcp.local." + + +def test_serviceinfo_address_updates(): + """Verify adding/removing/setting addresses on ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + + # Verify addresses and parsed_addresses are mutually exclusive + with pytest.raises(TypeError): + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + parsed_addresses=["10.0.1.2"], + ) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info_service.addresses = [socket.inet_aton("10.0.1.3")] + assert info_service.addresses == [socket.inet_aton("10.0.1.3")] + + +def test_serviceinfo_accepts_bytes_or_string_dict(): + """Verify a bytes or string dict can be passed to ServiceInfo.""" + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + addresses = [socket.inet_aton("10.0.1.2")] + server_name = "ash-2.local." + info_service = ServiceInfo( + type_, '%s.%s' % (name, type_), 80, 0, 0, {b'path': b'/~paulsm/'}, server_name, addresses=addresses + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {b'path': '/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': b'/~paulsm/'}, + server_name, + addresses=addresses, + ) + assert info_service.dns_text().text == b'\x0epath=/~paulsm/' diff --git a/tests/test_services.py b/tests/test_services.py index a22d6f6b..1a3ada23 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -6,23 +6,20 @@ import logging import socket -import threading import time import os import unittest from threading import Event -from typing import List import pytest import zeroconf as r -from zeroconf import DNSAddress, const +from zeroconf import const from zeroconf import Zeroconf from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo -from zeroconf.aio import AsyncZeroconf -from . import has_working_ipv6, _clear_cache, _inject_response +from . import has_working_ipv6, _clear_cache log = logging.getLogger('zeroconf') @@ -40,400 +37,6 @@ def teardown_module(): log.setLevel(original_logging_level) -class TestServiceInfo(unittest.TestCase): - def test_get_name(self): - """Verify the name accessor can strip the type.""" - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - assert info.get_name() == "name" - - def test_service_info_rejects_non_matching_updates(self): - """Verify records with the wrong name are rejected.""" - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - ttl = 120 - now = r.current_time_millis() - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - # Verify backwards compatiblity with calling with None - info.update_record(zc, now, None) - # Matching updates - info.update_record( - zc, - now, - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - info.update_record( - zc, - now, - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - 'ASH-2.local.', - ), - ) - assert info.server_key == 'ash-2.local.' - assert info.server == 'ASH-2.local.' - new_address = socket.inet_aton("10.0.1.3") - info.update_record( - zc, - now, - r.DNSAddress( - 'ASH-2.local.', - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - new_address, - ), - ) - assert new_address in info.addresses - # Non-matching updates - info.update_record( - zc, - now, - r.DNSText( - "incorrect.name.", - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - info.update_record( - zc, - now, - r.DNSService( - "incorrect.name.", - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - 'ASH-2.local.', - ), - ) - assert info.server_key == 'ash-2.local.' - assert info.server == 'ASH-2.local.' - new_address = socket.inet_aton("10.0.1.4") - info.update_record( - zc, - now, - r.DNSAddress( - "incorrect.name.", - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - new_address, - ), - ) - assert new_address not in info.addresses - zc.close() - - def test_service_info_rejects_expired_records(self): - """Verify records that are expired are rejected.""" - zc = r.Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_address = socket.inet_aton("10.0.1.2") - ttl = 120 - now = r.current_time_millis() - info = ServiceInfo( - service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] - ) - # Matching updates - info.update_record( - zc, - now, - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==', - ), - ) - assert info.properties[b"ci"] == b"2" - # Expired record - expired_record = r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', - ) - expired_record.created = 1000 - expired_record._expiration_time = 1000 - info.update_record(zc, now, expired_record) - assert info.properties[b"ci"] == b"2" - zc.close() - - def test_get_info_partial(self): - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - - service_info = None - send_event = Event() - service_info_event = Event() - - last_sent = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal last_sent - - last_sent = out - send_event.set() - - # patch the zeroconf send - with unittest.mock.patch.object(zc, "send", send): - - def mock_incoming_msg(records) -> r.DNSIncoming: - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - - for record in records: - generated.add_answer_at_time(record, 0) - - return r.DNSIncoming(generated.packets()[0]) - - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() - - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for SRV, A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - service_text, - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 3 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext query for A, AAAA - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 2 - assert r.DNSQuestion(service_server, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_server, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - last_sent = None - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ) - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None - - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() - - def test_get_info_single(self): - - zc = r.Zeroconf(interfaces=['127.0.0.1']) - - service_name = 'name._type._tcp.local.' - service_type = '_type._tcp.local.' - service_server = 'ash-1.local.' - service_text = b'path=/~matt1/' - service_address = '10.0.1.2' - - service_info = None - send_event = Event() - service_info_event = Event() - - last_sent = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): - """Sends an outgoing packet.""" - nonlocal last_sent - - last_sent = out - send_event.set() - - # patch the zeroconf send - with unittest.mock.patch.object(zc, "send", send): - - def mock_incoming_msg(records) -> r.DNSIncoming: - - generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) - - for record in records: - generated.add_answer_at_time(record, 0) - - return r.DNSIncoming(generated.packets()[0]) - - def get_service_info_helper(zc, type, name): - nonlocal service_info - service_info = zc.get_service_info(type, name) - service_info_event.set() - - try: - ttl = 120 - helper_thread = threading.Thread( - target=get_service_info_helper, args=(zc, service_type, service_name) - ) - helper_thread.start() - wait_time = 1 - - # Expext query for SRV, TXT, A, AAAA - send_event.wait(wait_time) - assert last_sent is not None - assert len(last_sent.questions) == 4 - assert r.DNSQuestion(service_name, const._TYPE_SRV, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_TXT, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_A, const._CLASS_IN) in last_sent.questions - assert r.DNSQuestion(service_name, const._TYPE_AAAA, const._CLASS_IN) in last_sent.questions - assert service_info is None - - # Expext no further queries - last_sent = None - send_event.clear() - _inject_response( - zc, - mock_incoming_msg( - [ - r.DNSText( - service_name, - const._TYPE_TXT, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - service_text, - ), - r.DNSService( - service_name, - const._TYPE_SRV, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - 0, - 0, - 80, - service_server, - ), - r.DNSAddress( - service_server, - const._TYPE_A, - const._CLASS_IN | const._CLASS_UNIQUE, - ttl, - socket.inet_pton(socket.AF_INET, service_address), - ), - ] - ), - ) - send_event.wait(wait_time) - assert last_sent is None - assert service_info is not None - - finally: - helper_thread.join() - zc.remove_all_service_listeners() - zc.close() - - class ListenerTest(unittest.TestCase): def test_integration_with_listener_class(self): @@ -590,86 +193,6 @@ def update_service(self, zeroconf, type, name): zeroconf_browser.close() -def test_multiple_addresses(): - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - desc = {'path': '/~paulsm/'} - address_parsed = "10.0.1.2" - address = socket.inet_aton(address_parsed) - - # New kwarg way - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) - - assert info.addresses == [address, address] - - info = ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - parsed_addresses=[address_parsed, address_parsed], - ) - assert info.addresses == [address, address] - - if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): - address_v6_parsed = "2001:db8::1" - address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) - infos = [ - ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[address, address_v6], - ), - ServiceInfo( - type_, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - parsed_addresses=[address_parsed, address_v6_parsed], - ), - ] - for info in infos: - assert info.addresses == [address] - assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] - assert info.addresses_by_version(r.IPVersion.V4Only) == [address] - assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] - assert info.parsed_addresses() == [address_parsed, address_v6_parsed] - assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] - - -# This test uses asyncio because it needs to access the cache directly -# which is not threadsafe -@pytest.mark.asyncio -async def test_multiple_a_addresses(): - type_ = "_http._tcp.local." - registration_name = "multiarec.%s" % type_ - desc = {'path': '/~paulsm/'} - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - cache = aiozc.zeroconf.cache - host = "multahost.local." - record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a') - record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b') - cache.async_add_records([record1, record2]) - - # New kwarg way - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host) - info.load_from_cache(aiozc.zeroconf) - assert set(info.addresses) == set([b'a', b'b']) - await aiozc.async_close() - - def test_legacy_record_update_listener(): """Test a RecordUpdateListener that does not implement update_records.""" @@ -730,44 +253,6 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.close() -def test_filter_address_by_type_from_service_info(): - """Verify dns_addresses can filter by ipversion.""" - desc = {'path': '/~paulsm/'} - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - registration_name = "%s.%s" % (name, type_) - ipv4 = socket.inet_aton("10.0.1.2") - ipv6 = socket.inet_pton(socket.AF_INET6, "2001:db8::1") - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[ipv4, ipv6]) - - def dns_addresses_to_addresses(dns_address: List[DNSAddress]): - return [address.address for address in dns_address] - - assert dns_addresses_to_addresses(info.dns_addresses()) == [ipv4, ipv6] - assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.All)) == [ipv4, ipv6] - assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V4Only)) == [ipv4] - assert dns_addresses_to_addresses(info.dns_addresses(version=r.IPVersion.V6Only)) == [ipv6] - - -def test_changing_name_updates_serviceinfo_key(): - """Verify a name change will adjust the underlying key value.""" - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - assert info_service.key == "mytesthome._homeassistant._tcp.local." - info_service.name = "YourTestHome._homeassistant._tcp.local." - assert info_service.key == "yourtesthome._homeassistant._tcp.local." - - def test_servicelisteners_raise_not_implemented(): """Verify service listeners raise when one of the methods is not implemented.""" @@ -790,81 +275,3 @@ class MyPartialListener(r.ServiceListener): ) zc.close() - - -def test_serviceinfo_address_updates(): - """Verify adding/removing/setting addresses on ServiceInfo.""" - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - - # Verify addresses and parsed_addresses are mutually exclusive - with pytest.raises(TypeError): - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - parsed_addresses=["10.0.1.2"], - ) - - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - info_service.addresses = [socket.inet_aton("10.0.1.3")] - assert info_service.addresses == [socket.inet_aton("10.0.1.3")] - - -def test_serviceinfo_accepts_bytes_or_string_dict(): - """Verify a bytes or string dict can be passed to ServiceInfo.""" - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - addresses = [socket.inet_aton("10.0.1.2")] - server_name = "ash-2.local." - info_service = ServiceInfo( - type_, '%s.%s' % (name, type_), 80, 0, 0, {b'path': b'/~paulsm/'}, server_name, addresses=addresses - ) - assert info_service.dns_text().text == b'\x0epath=/~paulsm/' - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - server_name, - addresses=addresses, - ) - assert info_service.dns_text().text == b'\x0epath=/~paulsm/' - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {b'path': '/~paulsm/'}, - server_name, - addresses=addresses, - ) - assert info_service.dns_text().text == b'\x0epath=/~paulsm/' - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': b'/~paulsm/'}, - server_name, - addresses=addresses, - ) - assert info_service.dns_text().text == b'\x0epath=/~paulsm/' From 0909c80c67287ba92ed334ab6896136aec0f3f24 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 18 Jun 2021 22:23:56 -1000 Subject: [PATCH 390/608] Update changelog (#747) --- README.rst | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/README.rst b/README.rst index 3e3d0e64..3b7dbf22 100644 --- a/README.rst +++ b/README.rst @@ -254,12 +254,71 @@ you can likely not be concerned with the breaking changes below: * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Relocate service browser tests to tests/services/test_browser.py (#745) @bdraco + +* Relocate ServiceInfo to zeroconf._services.info (#741) @bdraco + +* Run question answer callbacks from add_listener in the event loop (#740) @bdraco + + Calling async_update_records and async_update_records_complete should always + happen in the event loop to ensure implementers do not need to worry about + thread safety + +* Remove second level caching from ServiceBrowsers (#737) @bdraco + + The ServiceBrowser had its own cache of the last time it + saw a service which was reimplementing the DNSCache and + presenting a source of truth problem that lead to unexpected + queries when the two disagreed. + +* Breakout ServiceBrowser handler from listener creation (#736) @bdraco + + Add coverage for the handler from listener + +* Add fast cache lookup functions (#732) @bdraco + + The majority of our lookups happen in the event loop so there is no need + for them to be threadsafe. Now that the codebase is more clear about what + needs to be threadsafe and what does not need to be threadsafe we can use + the much faster non-threadsafe versions in the places where we are calling + from the event loop. + +* Switch to using DNSRRSet in RecordManager (#735) @bdraco + + DNSRRSet is able to do O(1) lookups of records assuming + there are no collisions. + +* Fix server cache to be case-insensitive (#731) @bdraco + + If the server name had uppercase chars and any of the + matching records were lowercase, the server would not be + found + * Fix cache handling of records with different TTLs (#729) @bdraco + There should only be one unique record in the cache at + a time as having multiple unique records will different + TTLs in the cache can result in unexpected behavior since + some functions returned all matching records and some + fetched from the right side of the list to return the + newest record. Intead we now store the records in a dict + to ensure that the newest record always replaces the same + unique record and we never have a source of truth problem + determining the TTL of a record from the cache. + * Rename handlers and internals to make it clear what is threadsafe (#726) @bdraco + It was too easy to get confused about what was threadsafe and + what was not threadsafe which lead to unexpected failures. + Rename functions to make it clear what will be run in the event + loop and what is expected to be threadsafe + * Fix ServiceInfo with multiple A records (#725) @bdraco + If there were multiple A records for the host, ServiceInfo + would always return the last one that was in the incoming + packet which was usually not the one that was wanted. + * Synchronize time for fate sharing (#718) @bdraco * Cleanup typing in zero._core and document ignores (#714) @bdraco From 7b3b4b5b8303a684165fcd53c0d9c36a1b8dda3d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 08:30:32 -1000 Subject: [PATCH 391/608] Remove support for notify listeners (#733) --- tests/test_aio.py | 6 +++--- tests/test_core.py | 44 -------------------------------------- zeroconf/__init__.py | 2 +- zeroconf/_core.py | 48 ++++++++++++++++++++++-------------------- zeroconf/_handlers.py | 2 +- zeroconf/_utils/aio.py | 24 ++++++++++++++++++++- zeroconf/aio.py | 36 +++++-------------------------- 7 files changed, 58 insertions(+), 104 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index e4144250..bd4f7d2d 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -312,12 +312,12 @@ async def test_async_wait_unblocks_on_update() -> None: # Should unblock due to update from the # registration now = current_time_millis() - await aiozc.async_wait(50000) + await aiozc.zeroconf.async_wait(50000) assert current_time_millis() - now < 3000 await task now = current_time_millis() - await aiozc.async_wait(50) + await aiozc.zeroconf.async_wait(50) assert current_time_millis() - now < 1000 await aiozc.async_close() @@ -481,7 +481,7 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: await task task = await aiozc.async_unregister_service(new_info) await task - await aiozc.async_wait(1) + await aiozc.zeroconf.async_wait(1) await aiozc.async_close() assert calls == [ diff --git a/tests/test_core.py b/tests/test_core.py index 1f0884f0..13c6ec70 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -249,50 +249,6 @@ def mock_split_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNS zeroconf.close() -# This test uses asyncio because it needs to verify the listeners -# run in the event loop -@pytest.mark.asyncio -async def test_notify_listeners(): - """Test adding and removing notify listeners.""" - # instantiate a zeroconf instance - aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) - zc = aiozc.zeroconf - notify_called = 0 - - class TestNotifyListener(r.NotifyListener): - def notify_all(self): - nonlocal notify_called - notify_called += 1 - - with pytest.raises(NotImplementedError): - r.NotifyListener().notify_all() - - notify_listener = TestNotifyListener() - - zc.add_notify_listener(notify_listener) - - def on_service_state_change(zeroconf, service_type, state_change, name): - """Dummy service callback.""" - - # start a browser - browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) - browser.cancel() - - await asyncio.sleep(0) # flush out any call_soon_threadsafe - assert notify_called - zc.remove_notify_listener(notify_listener) - - notify_called = 0 - # start a browser - browser = ServiceBrowser(zc, "_http._tcp.local.", [on_service_state_change]) - browser.cancel() - await asyncio.sleep(0) # flush out any call_soon_threadsafe - - assert not notify_called - - await aiozc.async_close() - - def test_generate_service_query_set_qu_bit(): """Test generate_service_query sets the QU bit.""" diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3715e174..e3d7ddfb 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -23,7 +23,7 @@ import sys from ._cache import DNSCache # noqa # import needed for backwards compat -from ._core import NotifyListener, Zeroconf # noqa # import needed for backwards compat +from ._core import Zeroconf # noqa # import needed for backwards compat from ._dns import ( # noqa # import needed for backwards compat DNSAddress, DNSEntry, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index ff54dc7e..1ef3a843 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -41,7 +41,7 @@ from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry -from ._utils.aio import get_running_loop +from ._utils.aio import get_running_loop, shutdown_loop, wait_condition_or_timeout from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -72,14 +72,6 @@ _TC_DELAY_RANDOM_INTERVAL = (400, 500) -class NotifyListener: - """Receive notifications Zeroconf.notify_all is called.""" - - def notify_all(self) -> None: - """Called when Zeroconf.notify_all is called.""" - raise NotImplementedError() - - class AsyncEngine: """An engine wraps sockets in the event loop.""" @@ -293,7 +285,6 @@ def __init__( self.engine = AsyncEngine(self, listen_socket, respond_sockets) - self._notify_listeners: List[NotifyListener] = [] self.browsers: Dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() self.cache = DNSCache() @@ -301,6 +292,7 @@ def __init__( self.record_manager = RecordManager(self) self.condition = threading.Condition() + self.async_condition: Optional[asyncio.Condition] = None self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None @@ -313,6 +305,7 @@ def start(self) -> None: """Start Zeroconf.""" self.loop = get_running_loop() if self.loop: + self.async_condition = asyncio.Condition() self.engine.setup(self.loop, None) return self._start_thread() @@ -324,6 +317,7 @@ def _start_thread(self) -> None: def _run_loop() -> None: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) + self.async_condition = asyncio.Condition() self.engine.setup(self.loop, loop_thread_ready) self.loop.run_forever() @@ -349,12 +343,28 @@ def wait(self, timeout: float) -> None: with self.condition: self.condition.wait(millis_to_seconds(timeout)) + async def async_wait(self, timeout: float) -> None: + """Calling task waits for a given number of milliseconds or until notified.""" + assert self.async_condition is not None + async with self.async_condition: + await wait_condition_or_timeout(self.async_condition, millis_to_seconds(timeout)) + def notify_all(self) -> None: - """Notifies all waiting threads""" + """Notifies all waiting threads and notify listeners.""" + assert self.loop is not None + self.loop.call_soon_threadsafe(self.async_notify_all) + + def async_notify_all(self) -> None: + """Schedule an async_notify_all.""" + asyncio.ensure_future(self._async_notify_all()) + + async def _async_notify_all(self) -> None: + """Notify all async listeners.""" + assert self.async_condition is not None with self.condition: self.condition.notify_all() - for listener in self._notify_listeners: - listener.notify_all() + async with self.async_condition: + self.async_condition.notify_all() def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular @@ -365,15 +375,6 @@ def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Option return info return None - def add_notify_listener(self, listener: NotifyListener) -> None: - """Adds a listener to receive notify_all events.""" - self._notify_listeners.append(listener) - - def remove_notify_listener(self, listener: NotifyListener) -> None: - """Removes a listener from the set that is currently listening.""" - with contextlib.suppress(ValueError): - self._notify_listeners.remove(listener) - def add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object will then have its add_service and remove_service methods called when @@ -652,8 +653,9 @@ def _shutdown_threads(self) -> None: if not self._loop_thread: return assert self.loop is not None - self.loop.call_soon_threadsafe(self.loop.stop) + shutdown_loop(self.loop) self._loop_thread.join() + self._loop_thread = None def close(self) -> None: """Ends the background threads, and prevent this instance from diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 476217f6..db5948c6 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -297,7 +297,7 @@ def async_updates_complete(self) -> None: """ for listener in self.listeners: listener.async_update_records_complete() - self.zc.notify_all() + self.zc.async_notify_all() def async_updates_from_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index 87a79f26..e1bd12d0 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -22,7 +22,7 @@ import asyncio import contextlib -from typing import Optional, cast +from typing import Optional, Set, cast # Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed @@ -54,6 +54,28 @@ def _handle_wait_complete(_: asyncio.Task) -> None: await condition_wait +async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: + """Return all tasks running.""" + if hasattr(asyncio, 'all_tasks'): + return cast(Set[asyncio.Task], asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member + return cast(Set[asyncio.Task], asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member + + +async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: + """Wait for the event loop thread we started to shutdown.""" + await asyncio.wait(wait_tasks, timeout=1) + + +def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: + """Wait for pending tasks and stop an event loop.""" + pending_tasks = asyncio.run_coroutine_threadsafe(_get_all_tasks(loop), loop).result() + done_tasks = set(task for task in pending_tasks if not task.done()) + pending_tasks -= done_tasks + if pending_tasks: + asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result() + loop.call_soon_threadsafe(loop.stop) + + # Remove the call to _get_running_loop once we drop python 3.6 support def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: """Check if an event loop is already running.""" diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 00d42823..a5cf0605 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -24,7 +24,7 @@ from types import TracebackType # noqa # used in type hints from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union -from ._core import NotifyListener, Zeroconf +from ._core import Zeroconf from ._exceptions import NonUniqueNameException from ._services.browser import _ServiceBrowserBase from ._services.info import ServiceInfo, instance_name_from_service_info @@ -52,24 +52,6 @@ ] -class AsyncNotifyListener(NotifyListener): - """A NotifyListener that async code can use to wait for events.""" - - def __init__(self, aiozc: 'AsyncZeroconf') -> None: - """Create an event for async listeners to wait for.""" - self.aiozc = aiozc - self.loop = asyncio.get_event_loop() - - def notify_all(self) -> None: - """Schedule an async_notify_all.""" - self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all()) - - async def _async_notify_all(self) -> None: - """Notify all async listeners.""" - async with self.aiozc.condition: - self.aiozc.condition.notify_all() - - class AsyncServiceListener: def add_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: raise NotImplementedError() @@ -109,7 +91,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: next_ = now + delay delay *= 2 - await aiozc.async_wait(min(next_, last) - now) + await aiozc.zeroconf.async_wait(min(next_, last) - now) now = current_time_millis() finally: aiozc.zeroconf.remove_listener(self) @@ -148,16 +130,17 @@ async def async_cancel(self) -> None: async def async_run(self) -> None: """Run the browser task.""" await self.aiozc.zeroconf.async_wait_for_start() + assert self.aiozc.zeroconf.async_condition is not None while True: timeout = self._seconds_to_wait() if timeout: - async with self.aiozc.condition: + async with self.aiozc.zeroconf.async_condition: # We must check again while holding the condition # in case the other thread has added to _handlers_to_call # between when we checked above when we were not # holding the condition if not self._handlers_to_call: - await wait_condition_or_timeout(self.aiozc.condition, timeout) + await wait_condition_or_timeout(self.aiozc.zeroconf.async_condition, timeout) outs = self.generate_ready_queries() for out in outs: @@ -253,10 +236,7 @@ def __init__( apple_p2p=apple_p2p, ) self.loop = asyncio.get_event_loop() - self.async_notify = AsyncNotifyListener(self) - self.zeroconf.add_notify_listener(self.async_notify) self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} - self.condition = asyncio.Condition() async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -343,7 +323,6 @@ async def async_close(self) -> None: with contextlib.suppress(asyncio.TimeoutError): await asyncio.wait_for(self.zeroconf.async_wait_for_start(), timeout=1) await self.async_remove_all_service_listeners() - self.zeroconf.remove_notify_listener(self.async_notify) await self.async_unregister_all_services() await self.zeroconf._async_close() # pylint: disable=protected-access @@ -358,11 +337,6 @@ async def async_get_service_info( return info return None - async def async_wait(self, timeout: float) -> None: - """Calling task waits for a given number of milliseconds or until notified.""" - async with self.condition: - await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout)) - async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: """Adds a listener for a particular service type. This object will then have its add_service and remove_service methods called when From 0dbcabfade41057a055ebefffd410d1afc3eb0ea Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 09:01:32 -1000 Subject: [PATCH 392/608] Run ServiceInfo requests in the event loop (#748) --- examples/async_service_info_request.py | 2 +- tests/services/test_info.py | 4 +-- tests/test_aio.py | 2 +- zeroconf/_services/info.py | 17 ++++++++---- zeroconf/aio.py | 37 ++------------------------ 5 files changed, 18 insertions(+), 44 deletions(-) diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py index 838545ce..b73f27dc 100644 --- a/examples/async_service_info_request.py +++ b/examples/async_service_info_request.py @@ -28,7 +28,7 @@ async def async_watch_services(aiozc: AsyncZeroconf) -> None: if not name.endswith(HAP_TYPE): continue infos.append(AsyncServiceInfo(HAP_TYPE, name)) - tasks = [info.async_request(aiozc, 3000) for info in infos] + tasks = [info.async_request(aiozc.zeroconf, 3000) for info in infos] await asyncio.gather(*tasks) for info in infos: print("Info for %s" % (info.name)) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 8f654a70..8fb45f22 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -216,7 +216,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): send_event.set() # patch the zeroconf send - with unittest.mock.patch.object(zc, "send", send): + with unittest.mock.patch.object(zc, "async_send", send): def mock_incoming_msg(records) -> r.DNSIncoming: @@ -353,7 +353,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): send_event.set() # patch the zeroconf send - with unittest.mock.patch.object(zc, "send", send): + with unittest.mock.patch.object(zc, "async_send", send): def mock_incoming_msg(records) -> r.DNSIncoming: diff --git a/tests/test_aio.py b/tests/test_aio.py index bd4f7d2d..9b3711fe 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -409,7 +409,7 @@ async def test_service_info_async_request() -> None: # Generating the race condition is almost impossible # without patching since its a TOCTOU race with unittest.mock.patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): - await aiosinfo.async_request(aiozc, 3000) + await aiosinfo.async_request(aiozc.zeroconf, 3000) assert aiosinfo is not None assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index a3536ed1..5a331927 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -20,6 +20,7 @@ USA """ +import asyncio import socket from typing import Dict, List, Optional, TYPE_CHECKING, Union, cast @@ -393,6 +394,13 @@ def _is_complete(self) -> bool: return not (self.text is None or not self._addresses) def request(self, zc: 'Zeroconf', timeout: float) -> bool: + """Returns true if the service could be discovered on the + network, and updates this object with details discovered. + """ + assert zc.loop is not None + return asyncio.run_coroutine_threadsafe(self.async_request(zc, timeout), zc.loop).result() + + async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: """Returns true if the service could be discovered on the network, and updates this object with details discovered. """ @@ -403,9 +411,8 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: delay = _LISTENER_TIME next_ = now last = now + timeout + await zc.async_wait_for_start() try: - # Do not set a question on the listener to preload from cache - # since we just checked it above in load_from_cache zc.add_listener(self, None) while not self._is_complete: if last <= now: @@ -413,12 +420,12 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: if next_ <= now: out = self.generate_request_query(zc, now) if not out.questions: - return True - zc.send(out) + return self.load_from_cache(zc) + zc.async_send(out) next_ = now + delay delay *= 2 - zc.wait(min(next_, last) - now) + await zc.async_wait(min(next_, last) - now) now = current_time_millis() finally: zc.remove_listener(self) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index a5cf0605..a0c908b6 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -31,11 +31,10 @@ from ._services.types import ZeroconfServiceTypes from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType -from ._utils.time import current_time_millis, millis_to_seconds +from ._utils.time import millis_to_seconds from .const import ( _BROWSER_TIME, _CHECK_TIME, - _LISTENER_TIME, _MDNS_PORT, _REGISTER_TIME, _SERVICE_TYPE_ENUMERATION_NAME, @@ -66,38 +65,6 @@ def update_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: class AsyncServiceInfo(ServiceInfo): """An async version of ServiceInfo.""" - async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool: - """Returns true if the service could be discovered on the - network, and updates this object with details discovered. - """ - if self.load_from_cache(aiozc.zeroconf): - return True - - now = current_time_millis() - delay = _LISTENER_TIME - next_ = now - last = now + timeout - await aiozc.zeroconf.async_wait_for_start() - try: - aiozc.zeroconf.add_listener(self, None) - while not self._is_complete: - if last <= now: - return False - if next_ <= now: - out = self.generate_request_query(aiozc.zeroconf, now) - if not out.questions: - return self.load_from_cache(aiozc.zeroconf) - aiozc.zeroconf.async_send(out) - next_ = now + delay - delay *= 2 - - await aiozc.zeroconf.async_wait(min(next_, last) - now) - now = current_time_millis() - finally: - aiozc.zeroconf.remove_listener(self) - - return True - class AsyncServiceBrowser(_ServiceBrowserBase): """Used to browse for a service of a specific type. @@ -333,7 +300,7 @@ async def async_get_service_info( name and type, or None if no service matches by the timeout, which defaults to 3 seconds.""" info = AsyncServiceInfo(type_, name) - if await info.async_request(self, timeout): + if await info.async_request(self.zeroconf, timeout): return info return None From 0f702c6a41bb33ed63872249b82d1111bdac4fa6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 09:11:38 -1000 Subject: [PATCH 393/608] Update async_service_info_request example to ensure it runs in the right event loop (#749) --- examples/async_service_info_request.py | 48 +++++++++++++++++--------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py index b73f27dc..8ea961eb 100644 --- a/examples/async_service_info_request.py +++ b/examples/async_service_info_request.py @@ -9,7 +9,7 @@ import argparse import asyncio import logging -from typing import cast +from typing import Any, Optional, cast from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf @@ -48,6 +48,33 @@ async def async_watch_services(aiozc: AsyncZeroconf) -> None: print('\n') +class AsyncRunner: + def __init__(self, args: Any) -> None: + self.args = args + self.threaded_browser: Optional[ServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + assert self.aiozc is not None + + def on_service_state_change( + zeroconf: Zeroconf, service_type: str, state_change: ServiceStateChange, name: str + ) -> None: + """Dummy handler.""" + + self.threaded_browser = ServiceBrowser( + self.aiozc.zeroconf, [HAP_TYPE], handlers=[on_service_state_change] + ) + await async_watch_services(self.aiozc) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.threaded_browser is not None + self.threaded_browser.cancel() + await self.aiozc.async_close() + + if __name__ == '__main__': logging.basicConfig(level=logging.DEBUG) @@ -67,23 +94,10 @@ async def async_watch_services(aiozc: AsyncZeroconf) -> None: else: ip_version = IPVersion.V4Only - aiozc = AsyncZeroconf(ip_version=ip_version) - - def on_service_state_change( - zeroconf: Zeroconf, service_type: str, state_change: ServiceStateChange, name: str - ) -> None: - """Dummy handler.""" - print(f"Services with {HAP_TYPE} will be shown every 5s, press Ctrl-C to exit...") - # ServiceBrowser currently is only offered in sync context. - # ServiceInfo has an AsyncServiceInfo counterpart that can be used - # to fetch service info in parallel - browser = ServiceBrowser(aiozc.zeroconf, [HAP_TYPE], handlers=[on_service_state_change]) loop = asyncio.get_event_loop() + runner = AsyncRunner(args) try: - loop.run_until_complete(async_watch_services(aiozc)) + loop.run_until_complete(runner.async_run()) except KeyboardInterrupt: - pass - finally: - browser.cancel() - loop.run_until_complete(aiozc.async_close()) + loop.run_until_complete(runner.async_close()) From 3b9baf07278290b2b4eb8ac5850bccfbd8b107d8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 09:11:47 -1000 Subject: [PATCH 394/608] Fix warning about Zeroconf._async_notify_all not being awaited in sync shutdown (#750) --- zeroconf/_utils/aio.py | 1 + 1 file changed, 1 insertion(+) diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index e1bd12d0..320ff7c3 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -56,6 +56,7 @@ def _handle_wait_complete(_: asyncio.Task) -> None: async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: """Return all tasks running.""" + await asyncio.sleep(0) # flush out any call_soon_threadsafe if hasattr(asyncio, 'all_tasks'): return cast(Set[asyncio.Task], asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member return cast(Set[asyncio.Task], asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member From e7adce2bf6ea0b4af1709369a36421acd9757b4a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 09:28:00 -1000 Subject: [PATCH 395/608] Remove unused argument from AsyncZeroconf (#751) --- zeroconf/aio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index a0c908b6..ffb7b060 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -202,7 +202,6 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) - self.loop = asyncio.get_event_loop() self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: From 4d0a8f3c643a0fc5c3a40420bab96ef18dddaecb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 10:55:57 -1000 Subject: [PATCH 396/608] Run ServiceBrowser queries in the event loop (#752) --- tests/services/test_browser.py | 8 +-- tests/test_aio.py | 1 + tests/test_init.py | 4 +- zeroconf/_core.py | 14 ++---- zeroconf/_handlers.py | 2 +- zeroconf/_services/browser.py | 91 ++++++++++++++++++++++++++-------- zeroconf/_utils/aio.py | 8 +++ zeroconf/aio.py | 42 ++-------------- 8 files changed, 95 insertions(+), 75 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index ccdb312f..342ced69 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -348,7 +348,7 @@ def test_backoff(): zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) # we are going to patch the zeroconf send to check query transmission - old_send = zeroconf_browser.send + old_send = zeroconf_browser.async_send time_offset = 0.0 start_time = time.time() * 1000 @@ -366,7 +366,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the zeroconf send # patch the zeroconf current_time_millis # patch the backoff limit to prevent test running forever - with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch.object( _services_browser, "current_time_millis", current_time_millis ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10): # dummy service callback @@ -432,7 +432,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) # we are going to patch the zeroconf send to check packet sizes - old_send = zeroconf_browser.send + old_send = zeroconf_browser.async_send time_offset = 0.0 @@ -459,7 +459,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the zeroconf send # patch the zeroconf current_time_millis # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - with unittest.mock.patch.object(zeroconf_browser, "send", send), unittest.mock.patch.object( + with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch.object( _services_browser, "current_time_millis", current_time_millis ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): service_added = Event() diff --git a/tests/test_aio.py b/tests/test_aio.py index 9b3711fe..6cfbcfb2 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -109,6 +109,7 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: calls.append(("update", type, name)) listener = MyListener() + aiozc.zeroconf.add_service_listener(type_, listener) desc = {'path': '/~paulsm/'} diff --git a/tests/test_init.py b/tests/test_init.py index 3cc16b22..30149b84 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -82,7 +82,7 @@ def test_lots_of_names(self): self.verify_name_change(zc, type_, name, server_count) # we are going to patch the zeroconf send to check packet sizes - old_send = zc.send + old_send = zc.async_send longest_packet_len = 0 longest_packet = None # type: Optional[r.DNSOutgoing] @@ -97,7 +97,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zc, "send", send): + with unittest.mock.patch.object(zc, "async_send", send): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1ef3a843..9dab6a27 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -291,7 +291,6 @@ def __init__( self.query_handler = QueryHandler(self.registry, self.cache) self.record_manager = RecordManager(self) - self.condition = threading.Condition() self.async_condition: Optional[asyncio.Condition] = None self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None @@ -338,10 +337,9 @@ def listeners(self) -> List[RecordUpdateListener]: return self.record_manager.listeners def wait(self, timeout: float) -> None: - """Calling thread waits for a given number of milliseconds or - until notified.""" - with self.condition: - self.condition.wait(millis_to_seconds(timeout)) + """Calling task waits for a given number of milliseconds or until notified.""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result() async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" @@ -361,10 +359,8 @@ def async_notify_all(self) -> None: async def _async_notify_all(self) -> None: """Notify all async listeners.""" assert self.async_condition is not None - with self.condition: - self.condition.notify_all() - async with self.async_condition: - self.async_condition.notify_all() + async with self.async_condition: + self.async_condition.notify_all() def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index db5948c6..2a58850d 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -410,7 +410,7 @@ def _async_update_matching_records( return listener.async_update_records(self.zc, now, records) listener.async_update_records_complete() - self.zc.notify_all() + self.zc.async_notify_all() def remove_listener(self, listener: RecordUpdateListener) -> None: """Removes a listener.""" diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index b633df67..1e44679f 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -20,6 +20,9 @@ USA """ +import asyncio +import contextlib +import queue import threading import warnings from collections import OrderedDict @@ -35,6 +38,7 @@ Signal, SignalRegistrationInterface, ) +from .._utils.aio import get_best_available_queue, get_running_loop, wait_condition_or_timeout from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( @@ -180,6 +184,7 @@ def __init__( for check_type_ in self.types: # Will generate BadTypeInNameException on a bad name service_type_name(check_type_, strict=False) + self._browser_task: Optional[asyncio.Task] = None self.zc = zc self.addr = addr self.port = port @@ -190,7 +195,7 @@ def __init__( self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._handlers_to_call: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._service_state_changed = Signal() - + self.queue: Optional[queue.Queue] = None self.done = False if hasattr(handlers, 'add_service'): @@ -341,6 +346,47 @@ def _seconds_to_wait(self) -> Optional[float]: return millis_to_seconds(next_time - now) + async def async_browser_task(self) -> None: + """Run the browser task.""" + await self.zc.async_wait_for_start() + assert self.zc.async_condition is not None + while True: + timeout = self._seconds_to_wait() + if timeout: + async with self.zc.async_condition: + # We must check again while holding the condition + # in case the other thread has added to _handlers_to_call + # between when we checked above when we were not + # holding the condition + if not self._handlers_to_call: + await wait_condition_or_timeout(self.zc.async_condition, timeout) + + outs = self.generate_ready_queries() + for out in outs: + self.zc.async_send(out, addr=self.addr, port=self.port) + + if not self._handlers_to_call: + continue + + (name_type, state_change) = self._handlers_to_call.popitem(False) + if self.queue: + self.queue.put((name_type, state_change)) + continue + + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) + + async def _async_cancel_browser(self) -> None: + """Cancel the browser.""" + assert self._browser_task is not None + self._browser_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._browser_task + class ServiceBrowser(_ServiceBrowserBase, threading.Thread): """Used to browse for a service of a specific type. @@ -361,42 +407,45 @@ def __init__( ) -> None: threading.Thread.__init__(self) super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + self.queue = get_best_available_queue() self.daemon = True self.start() self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), ) + assert self.zc.loop is not None + if get_running_loop() == self.zc.loop: + self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + return + self._browser_task = cast( + asyncio.Task, + asyncio.run_coroutine_threadsafe(self._async_browser_task(), self.zc.loop).result(), + ) + + async def _async_browser_task(self) -> asyncio.Task: + return cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) def cancel(self) -> None: """Cancel the browser.""" + assert self.zc.loop is not None + assert self.queue is not None + self.queue.put(None) + if get_running_loop() == self.zc.loop: + asyncio.ensure_future(self._async_cancel_browser()) + else: + asyncio.run_coroutine_threadsafe(self._async_cancel_browser(), self.zc.loop).result() super().cancel() self.join() def run(self) -> None: """Run the browser thread.""" + assert self.queue is not None while True: - timeout = self._seconds_to_wait() - if timeout: - with self.zc.condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - self.zc.condition.wait(timeout) - - if self.zc.done or self.done: + event = self.queue.get() + if event is None: return - - outs = self.generate_ready_queries() - for out in outs: - self.zc.send(out, addr=self.addr, port=self.port) - - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) + name_type, state_change = event self._service_state_changed.fire( zeroconf=self.zc, service_type=name_type[1], diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index 320ff7c3..d4adffa8 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -22,9 +22,17 @@ import asyncio import contextlib +import queue from typing import Optional, Set, cast +def get_best_available_queue() -> queue.Queue: + """Create the best available queue type.""" + if hasattr(queue, "SimpleQueue"): + return queue.SimpleQueue() # type: ignore # pylint: disable=all + return queue.Queue() + + # Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: """Wait for a condition or timeout.""" diff --git a/zeroconf/aio.py b/zeroconf/aio.py index ffb7b060..e5b6a96a 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -22,14 +22,13 @@ import asyncio import contextlib from types import TracebackType # noqa # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast from ._core import Zeroconf from ._exceptions import NonUniqueNameException from ._services.browser import _ServiceBrowserBase from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.types import ZeroconfServiceTypes -from ._utils.aio import wait_condition_or_timeout from ._utils.net import IPVersion, InterfaceChoice, InterfacesType from ._utils.time import millis_to_seconds from .const import ( @@ -83,46 +82,13 @@ def __init__( port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, ) -> None: - self.aiozc = aiozc super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore - self._browser_task = asyncio.ensure_future(self.async_run()) + self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) async def async_cancel(self) -> None: """Cancel the browser.""" - self.cancel() - self._browser_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._browser_task - - async def async_run(self) -> None: - """Run the browser task.""" - await self.aiozc.zeroconf.async_wait_for_start() - assert self.aiozc.zeroconf.async_condition is not None - while True: - timeout = self._seconds_to_wait() - if timeout: - async with self.aiozc.zeroconf.async_condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - await wait_condition_or_timeout(self.aiozc.zeroconf.async_condition, timeout) - - outs = self.generate_ready_queries() - for out in outs: - self.aiozc.zeroconf.async_send(out, addr=self.addr, port=self.port) - - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) - self._service_state_changed.fire( - zeroconf=self.aiozc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) + await self._async_cancel_browser() + super().cancel() class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): From 04cd2688022ebd07c1f875fefc73f8d15c4ed56c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 11:15:36 -1000 Subject: [PATCH 397/608] Drop AsyncServiceListener (#754) --- examples/async_browser.py | 15 +++++++++------ tests/test_aio.py | 13 ++----------- zeroconf/aio.py | 31 ++++++++++--------------------- 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/examples/async_browser.py b/examples/async_browser.py index cba30223..b835307c 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -10,12 +10,12 @@ import logging from typing import Any, Optional, cast -from zeroconf import IPVersion, ServiceStateChange -from zeroconf.aio import AsyncServiceBrowser, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf import IPVersion, ServiceStateChange, Zeroconf +from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes def async_on_service_state_change( - zeroconf: AsyncZeroconf, service_type: str, name: str, state_change: ServiceStateChange + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange ) -> None: print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) if state_change is not ServiceStateChange.Added: @@ -23,8 +23,9 @@ def async_on_service_state_change( asyncio.ensure_future(async_display_service_info(zeroconf, service_type, name)) -async def async_display_service_info(zeroconf: AsyncZeroconf, service_type: str, name: str) -> None: - info = await zeroconf.async_get_service_info(service_type, name) +async def async_display_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None: + info = AsyncServiceInfo(service_type, name) + await info.async_request(zeroconf, 3000) print("Info from zeroconf.get_service_info: %r" % (info)) if info: addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] @@ -59,7 +60,9 @@ async def async_run(self) -> None: ) print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % services) - self.aiobrowser = AsyncServiceBrowser(self.aiozc, services, handlers=[async_on_service_state_change]) + self.aiobrowser = AsyncServiceBrowser( + self.aiozc.zeroconf, services, handlers=[async_on_service_state_change] + ) while True: await asyncio.sleep(1) diff --git a/tests/test_aio.py b/tests/test_aio.py index 6cfbcfb2..327ccc66 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -12,7 +12,7 @@ import pytest -from zeroconf.aio import AsyncServiceInfo, AsyncServiceListener, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered @@ -433,16 +433,7 @@ async def test_async_service_browser() -> None: calls = [] - with pytest.raises(NotImplementedError): - AsyncServiceListener().add_service(aiozc, "_type", "name._type") - - with pytest.raises(NotImplementedError): - AsyncServiceListener().remove_service(aiozc, "_type", "name._type") - - with pytest.raises(NotImplementedError): - AsyncServiceListener().update_service(aiozc, "_type", "name._type") - - class MyListener(AsyncServiceListener): + class MyListener(ServiceListener): def add_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None: calls.append(("add", type, name)) diff --git a/zeroconf/aio.py b/zeroconf/aio.py index e5b6a96a..1f3e4352 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -26,6 +26,7 @@ from ._core import Zeroconf from ._exceptions import NonUniqueNameException +from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.types import ZeroconfServiceTypes @@ -45,22 +46,10 @@ "AsyncZeroconf", "AsyncServiceInfo", "AsyncServiceBrowser", - "AsyncServiceListener", "AsyncZeroconfServiceTypes", ] -class AsyncServiceListener: - def add_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def remove_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - def update_service(self, aiozc: 'AsyncZeroconf', type_: str, name: str) -> None: - raise NotImplementedError() - - class AsyncServiceInfo(ServiceInfo): """An async version of ServiceInfo.""" @@ -74,15 +63,15 @@ class AsyncServiceBrowser(_ServiceBrowserBase): def __init__( self, - aiozc: 'AsyncZeroconf', + zeroconf: 'Zeroconf', type_: Union[str, list], - handlers: Optional[Union[AsyncServiceListener, List[Callable[..., None]]]] = None, - listener: Optional[AsyncServiceListener] = None, + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, ) -> None: - super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore + super().__init__(zeroconf, type_, handlers, listener, addr, port, delay) self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) async def async_cancel(self) -> None: @@ -115,7 +104,7 @@ async def async_find( local_zc = aiozc or AsyncZeroconf(interfaces=interfaces, ip_version=ip_version) listener = cls() async_browser = AsyncServiceBrowser( - local_zc, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener # type: ignore + local_zc.zeroconf, _SERVICE_TYPE_ENUMERATION_NAME, listener=listener ) # wait for responses @@ -168,7 +157,7 @@ def __init__( ip_version=ip_version, apple_p2p=apple_p2p, ) - self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {} + self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" @@ -269,14 +258,14 @@ async def async_get_service_info( return info return None - async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None: + async def async_add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object will then have its add_service and remove_service methods called when services of that type become available and unavailable.""" await self.async_remove_service_listener(listener) - self.async_browsers[listener] = AsyncServiceBrowser(self, type_, listener) + self.async_browsers[listener] = AsyncServiceBrowser(self.zeroconf, type_, listener) - async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None: + async def async_remove_service_listener(self, listener: ServiceListener) -> None: """Removes a listener from the set that is currently listening.""" if listener in self.async_browsers: await self.async_browsers[listener].async_cancel() From f53c88b52ed080c80e2e98d3da91a830f0c7ebca Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 11:19:04 -1000 Subject: [PATCH 398/608] Revert: Fix thread safety in _ServiceBrowser.update_records_complete (#708) (#755) - This guarding is no longer needed as the ServiceBrowser loop now runs in the event loop and the thread safety guard is no longer needed --- zeroconf/_services/browser.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 1e44679f..853c4395 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -296,15 +296,8 @@ def async_update_records_complete(self) -> None: This method will be run in the event loop. """ - # Cannot use .update here since can fail with - # RuntimeError: dictionary changed size during iteration - # for threaded ServiceBrowsers - while self._pending_handlers: - try: - (name_type, state_change) = self._pending_handlers.popitem(False) - except KeyError: - return - self._handlers_to_call[name_type] = state_change + self._handlers_to_call.update(self._pending_handlers) + self._pending_handlers.clear() def cancel(self) -> None: """Cancel the browser.""" From f24ebba9ecc4d1626d570956a7cc735206d7ff6e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 14:16:08 -1000 Subject: [PATCH 399/608] Simplify ServiceBrowser callsbacks (#756) --- zeroconf/_services/browser.py | 91 ++++++++++++++--------------------- 1 file changed, 37 insertions(+), 54 deletions(-) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 853c4395..b03003b1 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -38,9 +38,9 @@ Signal, SignalRegistrationInterface, ) -from .._utils.aio import get_best_available_queue, get_running_loop, wait_condition_or_timeout +from .._utils.aio import get_best_available_queue, get_running_loop from .._utils.name import service_type_name -from .._utils.time import current_time_millis, millis_to_seconds +from .._utils.time import current_time_millis from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -172,8 +172,8 @@ def __init__( self, zc: 'Zeroconf', type_: Union[str, list], - handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, - listener: Optional['ServiceListener'] = None, + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, @@ -193,7 +193,6 @@ def __init__( self._next_time = {check_type_: current_time for check_type_ in self.types} self._delay = {check_type_: delay for check_type_ in self.types} self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() - self._handlers_to_call: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._service_state_changed = Signal() self.queue: Optional[queue.Queue] = None self.done = False @@ -296,8 +295,30 @@ def async_update_records_complete(self) -> None: This method will be run in the event loop. """ - self._handlers_to_call.update(self._pending_handlers) - self._pending_handlers.clear() + while self._pending_handlers: + event = self._pending_handlers.popitem(False) + # If there is a queue running (ServiceBrowser) + # get fired in dedicated thread + if self.queue: + self.queue.put(event) + else: + self._fire_service_state_changed_event(event) + + def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], ServiceStateChange]) -> None: + """Fire a service state changed event. + + When running with ServiceBrowser, this will happen in the dedicated + thread. + + When running with AsyncServiceBrowser, this will happen in the event loop. + """ + name_type, state_change = event + self._service_state_changed.fire( + zeroconf=self.zc, + service_type=name_type[1], + name=name_type[0], + state_change=state_change, + ) def cancel(self) -> None: """Cancel the browser.""" @@ -307,8 +328,7 @@ def cancel(self) -> None: def generate_ready_queries(self) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" now = current_time_millis() - - if min(self._next_time.values()) > now: + if self._millis_to_wait(current_time_millis()): return [] ready_types = [] @@ -323,56 +343,25 @@ def generate_ready_queries(self) -> List[DNSOutgoing]: return generate_service_query(self.zc, now, ready_types, self.multicast) - def _seconds_to_wait(self) -> Optional[float]: - """Returns the number of seconds to wait for the next event.""" - # If there are handlers to call - # we want to process them right away - if self._handlers_to_call: - return None - + def _millis_to_wait(self, now: float) -> Optional[float]: + """Returns the number of milliseconds to wait for the next event.""" # Wait for the type has the smallest next time next_time = min(self._next_time.values()) - now = current_time_millis() - - if next_time <= now: - return None - - return millis_to_seconds(next_time - now) + return None if next_time <= now else next_time - now async def async_browser_task(self) -> None: """Run the browser task.""" await self.zc.async_wait_for_start() assert self.zc.async_condition is not None while True: - timeout = self._seconds_to_wait() + timeout = self._millis_to_wait(current_time_millis()) if timeout: - async with self.zc.async_condition: - # We must check again while holding the condition - # in case the other thread has added to _handlers_to_call - # between when we checked above when we were not - # holding the condition - if not self._handlers_to_call: - await wait_condition_or_timeout(self.zc.async_condition, timeout) + await self.zc.async_wait(timeout) outs = self.generate_ready_queries() for out in outs: self.zc.async_send(out, addr=self.addr, port=self.port) - if not self._handlers_to_call: - continue - - (name_type, state_change) = self._handlers_to_call.popitem(False) - if self.queue: - self.queue.put((name_type, state_change)) - continue - - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) - async def _async_cancel_browser(self) -> None: """Cancel the browser.""" assert self._browser_task is not None @@ -392,8 +381,8 @@ def __init__( self, zc: 'Zeroconf', type_: Union[str, list], - handlers: Optional[Union['ServiceListener', List[Callable[..., None]]]] = None, - listener: Optional['ServiceListener'] = None, + handlers: Optional[Union[ServiceListener, List[Callable[..., None]]]] = None, + listener: Optional[ServiceListener] = None, addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, @@ -438,10 +427,4 @@ def run(self) -> None: event = self.queue.get() if event is None: return - name_type, state_change = event - self._service_state_changed.fire( - zeroconf=self.zc, - service_type=name_type[1], - name=name_type[0], - state_change=state_change, - ) + self._fire_service_state_changed_event(event) From 1c93baa486b1b0f44487891766e0a0c1de3eb252 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 14:19:18 -1000 Subject: [PATCH 400/608] Update changelog (#757) --- README.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.rst b/README.rst index 3b7dbf22..a26981b4 100644 --- a/README.rst +++ b/README.rst @@ -254,6 +254,28 @@ you can likely not be concerned with the breaking changes below: * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Simplify ServiceBrowser callsbacks (#756) @bdraco + +* Revert: Fix thread safety in _ServiceBrowser.update_records_complete (#708) (#755) @bdraco + +- This guarding is no longer needed as the ServiceBrowser loop + now runs in the event loop and the thread safety guard is no + longer needed + +* Drop AsyncServiceListener (#754) @bdraco (Never shipped) + +* Run ServiceBrowser queries in the event loop (#752) @bdraco + +* Remove unused argument from AsyncZeroconf (#751) @bdraco + +* Fix warning about Zeroconf._async_notify_all not being awaited in sync shutdown (#750) @bdraco + +* Update async_service_info_request example to ensure it runs in the right event loop (#749) @bdraco + +* Run ServiceInfo requests in the event loop (#748) @bdraco + +* Remove support for notify listeners (#733) @bdraco (Never shipped) + * Relocate service browser tests to tests/services/test_browser.py (#745) @bdraco * Relocate ServiceInfo to zeroconf._services.info (#741) @bdraco From 9f68fc8b1b834d0194e8ba1069d052aa853a8d38 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 14:38:41 -1000 Subject: [PATCH 401/608] Add missing coverage for SignalRegistrationInterface (#758) --- tests/test_services.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_services.py b/tests/test_services.py index 1a3ada23..684266f2 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -275,3 +275,18 @@ class MyPartialListener(r.ServiceListener): ) zc.close() + + +def test_signal_registration_interface(): + """Test adding and removing from the SignalRegistrationInterface.""" + + interface = r.SignalRegistrationInterface([]) + + def dummy(): + pass + + interface.register_handler(dummy) + interface.unregister_handler(dummy) + + with pytest.raises(ValueError): + interface.unregister_handler(dummy) From 936500a47cc33d9daa86f9012b1791986361ff63 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 16:40:08 -1000 Subject: [PATCH 402/608] Add 60s timeout for each test (#761) --- Makefile | 4 ++-- requirements-dev.txt | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index de816c1c..6602d808 100644 --- a/Makefile +++ b/Makefile @@ -41,10 +41,10 @@ mypy: mypy --no-warn-redundant-casts --no-warn-unused-ignores examples/*.py zeroconf test: - pytest -v tests + pytest --timeout=60 -v tests test_coverage: - pytest -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests + pytest --timeout=60 -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf diff --git a/requirements-dev.txt b/requirements-dev.txt index 325ce32a..eef93254 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,3 +13,4 @@ pylint pytest pytest-asyncio pytest-cov +pytest-timeout From fc0e599eec77477dd8f21ecd68b238e6a27f1bcf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 18:21:38 -1000 Subject: [PATCH 403/608] Fix race condition in ServiceBrowser test_integration (#762) - The event was being cleared in the wrong thread which meant if the test was fast enough it would not be seen the second time and give a spurious failure --- tests/services/test_browser.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 342ced69..2d47e368 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -441,7 +441,6 @@ def current_time_millis(): return time.time() * 1000 + time_offset * 1000 expected_ttl = const._DNS_HOST_TTL - nbr_answers = 0 def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): @@ -454,6 +453,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): unexpected_ttl.set() got_query.set() + got_query.clear() + old_send(out, addr=addr, port=port) # patch the zeroconf send @@ -482,13 +483,13 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # is greater than half the original TTL sleep_count = 0 test_iterations = 50 + while nbr_answers < test_iterations: # Increase simulated time shift by 1/4 of the TTL in seconds time_offset += expected_ttl / 4 zeroconf_browser.notify_all() sleep_count += 1 got_query.wait(0.1) - got_query.clear() # Prevent the test running indefinitely in an error condition assert sleep_count < test_iterations * 4 assert not unexpected_ttl.is_set() From 38b59a64592f41b2bb547b35c72a010a925a2941 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 18:28:24 -1000 Subject: [PATCH 404/608] Fix test_lots_of_names overflowing the incoming buffer (#763) --- tests/test_init.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index 30149b84..dd330d39 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -14,6 +14,8 @@ import zeroconf as r from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, const +from . import _inject_response + log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -153,7 +155,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): # force receive on oversized packet zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) - time.sleep(2.0) + time.sleep(0.3) zeroconf.log.debug( 'warn %d debug %d was %s', mocked_log_warn.call_count, @@ -162,8 +164,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): ) assert mocked_log_debug.call_count > call_counts[0] - # close our zeroconf which will close the sockets - zc.close() + # close our zeroconf which will close the sockets + zc.close() def verify_name_change(self, zc, type_, name, number_hosts): desc = {'path': '/~paulsm/'} @@ -201,17 +203,11 @@ def verify_name_change(self, zc, type_, name, number_hosts): assert info_service2.name.split('.')[0] == '%s-%d' % (name, number_hosts + 1) def generate_many_hosts(self, zc, type_, name, number_hosts): - records_per_server = 2 block_size = 25 number_hosts = int(((number_hosts - 1) / block_size + 1)) * block_size for i in range(1, number_hosts + 1): next_name = name if i == 1 else '%s-%d' % (name, i) self.generate_host(zc, next_name, type_) - if i % block_size == 0: - sleep_count = 0 - while sleep_count < 40 and i * records_per_server > len(zc.cache.entries_with_name(type_)): - sleep_count += 1 - time.sleep(0.05) @staticmethod def generate_host(zc, host_name, type_): @@ -233,4 +229,4 @@ def generate_host(zc, host_name, type_): ), 0, ) - zc.send(out) + _inject_response(zc, r.DNSIncoming(out.packets()[0])) From 85532e13e42447fcd6d4d4b0060f04d33c3ab780 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 19:00:19 -1000 Subject: [PATCH 405/608] Break test_lots_of_names into two tests (#764) --- tests/test_init.py | 70 +++++++++------------------------------------- 1 file changed, 13 insertions(+), 57 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index dd330d39..0cc3baf8 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -12,7 +12,7 @@ from typing import Optional # noqa # used in type hints import zeroconf as r -from zeroconf import ServiceBrowser, ServiceInfo, Zeroconf, const +from zeroconf import DNSOutgoing, ServiceBrowser, ServiceInfo, Zeroconf, const from . import _inject_response @@ -69,8 +69,7 @@ def test_same_name(self): generated.add_question(question) r.DNSIncoming(generated.packets()[0]) - def test_lots_of_names(self): - + def test_verify_name_change_with_lots_of_names(self): # instantiate a zeroconf instance zc = Zeroconf(interfaces=['127.0.0.1']) @@ -83,64 +82,21 @@ def test_lots_of_names(self): # verify that name changing works self.verify_name_change(zc, type_, name, server_count) - # we are going to patch the zeroconf send to check packet sizes - old_send = zc.async_send - - longest_packet_len = 0 - longest_packet = None # type: Optional[r.DNSOutgoing] - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): - """Sends an outgoing packet.""" - for packet in out.packets(): - nonlocal longest_packet_len, longest_packet - if longest_packet_len < len(packet): - longest_packet_len = len(packet) - longest_packet = out - old_send(out, addr=addr, port=port) - - # patch the zeroconf send - with unittest.mock.patch.object(zc, "async_send", send): - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - # start a browser - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - # wait until the browse request packet has maxed out in size - sleep_count = 0 - # we will never get to this large of a packet given the application-layer - # splitting of packets, but we still want to track the longest_packet_len - # for the debug message below - while sleep_count < 100 and longest_packet_len < const._MAX_MSG_ABSOLUTE - 100: - sleep_count += 1 - time.sleep(0.1) - - browser.cancel() - time.sleep(0.5) - - import zeroconf - - zeroconf.log.debug('sleep_count %d, sized %d', sleep_count, longest_packet_len) - - # now the browser has sent at least one request, verify the size - assert longest_packet_len <= const._MAX_MSG_TYPICAL - assert longest_packet_len >= const._MAX_MSG_TYPICAL - 100 + zc.close() - # mock zeroconf's logger warning() and debug() - from unittest.mock import patch + def test_large_packet_exception_log_handling(self): + """Verify we downgrade debug after warning.""" - patch_warn = patch('zeroconf._logger.log.warning') - patch_debug = patch('zeroconf._logger.log.debug') - mocked_log_warn = patch_warn.start() - mocked_log_debug = patch_debug.start() + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + with unittest.mock.patch('zeroconf._logger.log.warning') as mocked_log_warn, unittest.mock.patch( + 'zeroconf._logger.log.debug' + ) as mocked_log_debug: # now that we have a long packet in our possession, let's verify the # exception handling. - out = longest_packet - assert out is not None - out.data.append(b'\0' * 1000) + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) + out.data.append(b'\0' * 10000) # mock the zeroconf logger and check for the correct logging backoff call_counts = mocked_log_warn.call_count, mocked_log_debug.call_count @@ -156,7 +112,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) zc.send(out, const._MDNS_ADDR, const._MDNS_PORT) time.sleep(0.3) - zeroconf.log.debug( + r.log.debug( 'warn %d debug %d was %s', mocked_log_warn.call_count, mocked_log_debug.call_count, From 6c82fa9efd0f434f0f7c83e3bd98bd7851ede4cf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 19:10:58 -1000 Subject: [PATCH 406/608] Switch to using an asyncio.Event for async_wait (#759) - We no longer need to check for thread safety under a asyncio.Condition as the ServiceBrowser and ServiceInfo internals schedule coroutines in the eventloop. --- tests/services/test_browser.py | 87 +++++++++++++++++++++++++++++++++- tests/utils/test_aio.py | 18 +++---- zeroconf/_core.py | 23 ++++----- zeroconf/_services/browser.py | 35 ++++++++------ zeroconf/_services/info.py | 2 +- zeroconf/_utils/aio.py | 14 +++--- zeroconf/aio.py | 3 +- 7 files changed, 131 insertions(+), 51 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 2d47e368..5f2cea21 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -39,6 +39,91 @@ def teardown_module(): log.setLevel(original_logging_level) +def test_service_browser_cancel_multiple_times(): + """Test we can cancel a ServiceBrowser multiple times before close.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + browser.cancel() + browser.cancel() + browser.cancel() + + zc.close() + + +def test_service_browser_cancel_multiple_times_after_close(): + """Test we can cancel a ServiceBrowser multiple times after close.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + browser = r.ServiceBrowser(zc, type_, None, listener) + + zc.close() + + browser.cancel() + browser.cancel() + browser.cancel() + + +def test_service_browser_started_after_zeroconf_closed(): + """Test starting a ServiceBrowser after close raises RuntimeError.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + # start a browser + type_ = "_hap._tcp.local." + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + zc.close() + + with pytest.raises(RuntimeError): + browser = r.ServiceBrowser(zc, type_, None, listener) + + +def test_multiple_instances_running_close(): + """Test we can shutdown multiple instances.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + zc2 = Zeroconf(interfaces=['127.0.0.1']) + zc3 = Zeroconf(interfaces=['127.0.0.1']) + + assert zc.loop != zc2.loop + assert zc.loop != zc3.loop + + class MyServiceListener(r.ServiceListener): + pass + + listener = MyServiceListener() + + zc2.add_service_listener("zca._hap._tcp.local.", listener) + + zc.close() + zc2.remove_service_listener(listener) + zc2.close() + zc3.close() + + class TestServiceBrowser(unittest.TestCase): def test_update_record(self): enable_ipv6 = has_working_ipv6() and not os.environ.get('SKIP_IPV6') @@ -489,7 +574,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): time_offset += expected_ttl / 4 zeroconf_browser.notify_all() sleep_count += 1 - got_query.wait(0.1) + got_query.wait(0.5) # Prevent the test running indefinitely in an error condition assert sleep_count < test_iterations * 4 assert not unexpected_ttl.is_set() diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index 1f0a1d7e..bd402d4b 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -24,22 +24,16 @@ def test_get_running_loop_no_loop() -> None: @pytest.mark.asyncio -async def test_wait_condition_or_timeout_times_out() -> None: - """Test wait_condition_or_timeout will timeout.""" - test_cond = asyncio.Condition() - async with test_cond: - await aioutils.wait_condition_or_timeout(test_cond, 0.1) +async def test_wait_event_or_timeout_times_out() -> None: + """Test wait_event_or_timeout will timeout.""" + test_event = asyncio.Event() + await aioutils.wait_event_or_timeout(test_event, 0.1) - async def _hold_condition(): - async with test_cond: - await test_cond.wait() - - task = asyncio.ensure_future(_hold_condition()) + task = asyncio.ensure_future(test_event.wait()) await asyncio.sleep(0.1) async def _async_wait_or_timeout(): - async with test_cond: - await aioutils.wait_condition_or_timeout(test_cond, 0.1) + await aioutils.wait_event_or_timeout(test_event, 0.1) # Test high lock contention await asyncio.gather(*[_async_wait_or_timeout() for _ in range(100)]) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 9dab6a27..1053103c 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -41,7 +41,7 @@ from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry -from ._utils.aio import get_running_loop, shutdown_loop, wait_condition_or_timeout +from ._utils.aio import get_running_loop, shutdown_loop, wait_event_or_timeout from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -291,7 +291,7 @@ def __init__( self.query_handler = QueryHandler(self.registry, self.cache) self.record_manager = RecordManager(self) - self.async_condition: Optional[asyncio.Condition] = None + self.notify_event: Optional[asyncio.Event] = None self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None @@ -304,7 +304,7 @@ def start(self) -> None: """Start Zeroconf.""" self.loop = get_running_loop() if self.loop: - self.async_condition = asyncio.Condition() + self.notify_event = asyncio.Event() self.engine.setup(self.loop, None) return self._start_thread() @@ -316,7 +316,7 @@ def _start_thread(self) -> None: def _run_loop() -> None: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.async_condition = asyncio.Condition() + self.notify_event = asyncio.Event() self.engine.setup(self.loop, loop_thread_ready) self.loop.run_forever() @@ -343,9 +343,8 @@ def wait(self, timeout: float) -> None: async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" - assert self.async_condition is not None - async with self.async_condition: - await wait_condition_or_timeout(self.async_condition, millis_to_seconds(timeout)) + assert self.notify_event is not None + await wait_event_or_timeout(self.notify_event, timeout=millis_to_seconds(timeout)) def notify_all(self) -> None: """Notifies all waiting threads and notify listeners.""" @@ -354,13 +353,9 @@ def notify_all(self) -> None: def async_notify_all(self) -> None: """Schedule an async_notify_all.""" - asyncio.ensure_future(self._async_notify_all()) - - async def _async_notify_all(self) -> None: - """Notify all async listeners.""" - assert self.async_condition is not None - async with self.async_condition: - self.async_condition.notify_all() + assert self.notify_event is not None + self.notify_event.set() + self.notify_event.clear() def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: """Returns network's service information for a particular diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index b03003b1..296662d6 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -21,6 +21,7 @@ """ import asyncio +import concurrent.futures import contextlib import queue import threading @@ -352,7 +353,6 @@ def _millis_to_wait(self, now: float) -> Optional[float]: async def async_browser_task(self) -> None: """Run the browser task.""" await self.zc.async_wait_for_start() - assert self.zc.async_condition is not None while True: timeout = self._millis_to_wait(current_time_millis()) if timeout: @@ -366,8 +366,9 @@ async def _async_cancel_browser(self) -> None: """Cancel the browser.""" assert self._browser_task is not None self._browser_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._browser_task + browser_task = self._browser_task + self._browser_task = None + await browser_task class ServiceBrowser(_ServiceBrowserBase, threading.Thread): @@ -391,19 +392,21 @@ def __init__( super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) self.queue = get_best_available_queue() self.daemon = True + assert self.zc.loop is not None + if get_running_loop() == self.zc.loop: + self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + else: + if not self.zc.loop.is_running(): + raise RuntimeError("The event loop is not running") + self._browser_task = cast( + asyncio.Task, + asyncio.run_coroutine_threadsafe(self._async_browser_task(), self.zc.loop).result(), + ) self.start() self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), ) - assert self.zc.loop is not None - if get_running_loop() == self.zc.loop: - self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) - return - self._browser_task = cast( - asyncio.Task, - asyncio.run_coroutine_threadsafe(self._async_browser_task(), self.zc.loop).result(), - ) async def _async_browser_task(self) -> asyncio.Task: return cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) @@ -413,10 +416,12 @@ def cancel(self) -> None: assert self.zc.loop is not None assert self.queue is not None self.queue.put(None) - if get_running_loop() == self.zc.loop: - asyncio.ensure_future(self._async_cancel_browser()) - else: - asyncio.run_coroutine_threadsafe(self._async_cancel_browser(), self.zc.loop).result() + if self._browser_task: + if get_running_loop() == self.zc.loop: + asyncio.ensure_future(self._async_cancel_browser()) + elif self.zc.loop.is_running(): + with contextlib.suppress(asyncio.CancelledError, concurrent.futures.CancelledError): + asyncio.run_coroutine_threadsafe(self._async_cancel_browser(), self.zc.loop).result() super().cancel() self.join() diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 5a331927..aa457ffd 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -397,7 +397,7 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: """Returns true if the service could be discovered on the network, and updates this object with details discovered. """ - assert zc.loop is not None + assert zc.loop is not None and zc.loop.is_running() return asyncio.run_coroutine_threadsafe(self.async_request(zc, timeout), zc.loop).result() async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index d4adffa8..726acea4 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -34,8 +34,8 @@ def get_best_available_queue() -> queue.Queue: # Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed -async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None: - """Wait for a condition or timeout.""" +async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: + """Wait for an event or timeout.""" loop = asyncio.get_event_loop() future = loop.create_future() @@ -44,22 +44,22 @@ def _handle_timeout() -> None: future.set_result(None) timer_handle = loop.call_later(timeout, _handle_timeout) - condition_wait = loop.create_task(condition.wait()) + event_wait = loop.create_task(event.wait()) def _handle_wait_complete(_: asyncio.Task) -> None: if not future.done(): future.set_result(None) - condition_wait.add_done_callback(_handle_wait_complete) + event_wait.add_done_callback(_handle_wait_complete) try: await future finally: timer_handle.cancel() - if not condition_wait.done(): - condition_wait.cancel() + if not event_wait.done(): + event_wait.cancel() with contextlib.suppress(asyncio.CancelledError): - await condition_wait + await event_wait async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 1f3e4352..7a107fcc 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -76,7 +76,8 @@ def __init__( async def async_cancel(self) -> None: """Cancel the browser.""" - await self._async_cancel_browser() + with contextlib.suppress(asyncio.CancelledError): + await self._async_cancel_browser() super().cancel() From e70431e1fdc92c155309a1d40c89fed48737970c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 19 Jun 2021 22:10:28 -1000 Subject: [PATCH 407/608] Add test coverage to ensure RecordManager.add_listener callsback known question answers (#767) --- tests/test_handlers.py | 45 ++++++++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 2 +- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index b86d253d..36f25296 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -11,6 +11,7 @@ import time import unittest import unittest.mock +from typing import List import zeroconf as r from zeroconf import ServiceInfo, Zeroconf, current_time_millis @@ -915,3 +916,47 @@ async def test_cache_flush_bit(): assert loaded_info.addresses == info.addresses await aiozc.async_close() + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_record_update_manager_add_listener_callsback_existing_records(): + """Test that the RecordUpdateManager will callback existing records.""" + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc: Zeroconf = aiozc.zeroconf + updated = [] + + class MyListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.DNSRecord]) -> None: + """Update multiple records in one shot.""" + updated.extend(records) + + type_ = "_cacheflush._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "server-uu1.local." + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] + ) + a_record = info.dns_addresses()[0] + ptr_record = info.dns_pointer() + zc.cache.async_add_records([ptr_record, a_record, info.dns_text(), info.dns_service()]) + + listener = MyListener() + + zc.add_listener( + listener, + [ + r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN), + r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN), + ], + ) + await asyncio.sleep(0) # flush out the call_soon_threadsafe + + assert set(updated) == set([ptr_record, a_record]) + await aiozc.async_close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 2a58850d..08c25e17 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -386,7 +386,6 @@ def add_listener( self.listeners.append(listener) if question is None: - self.zc.notify_all() return questions = [question] if isinstance(question, DNSQuestion) else question @@ -406,6 +405,7 @@ def _async_update_matching_records( for record in self.cache.async_entries_with_name(question.name): if not record.is_expired(now) and question.answered_by(record): records.append(record) + if not records: return listener.async_update_records(self.zc, now, records) From 5d44a36a59c21ef7869ba9e6dde9f658d3502793 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 08:26:55 -1000 Subject: [PATCH 408/608] Improve performance of parsing DNSIncoming by caching read_utf (#769) --- zeroconf/_protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 80ca7b88..b0b87a06 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -22,6 +22,7 @@ import enum import struct +from functools import lru_cache from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast @@ -207,6 +208,7 @@ def read_others(self) -> None: if rec is not None: self.answers.append(rec) + @lru_cache(maxsize=None) def read_utf(self, offset: int, length: int) -> str: """Reads a UTF-8 string of a given length from the packet""" return str(self.data[offset : offset + length], 'utf-8', 'replace') From b600547a47878775e1c6fb8df46682a670beccba Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 11:13:33 -1000 Subject: [PATCH 409/608] Implement accidental synchronization protection (RFC2762 section 5.2) (#773) --- tests/services/test_browser.py | 46 ++++++++++++++++++++++++++++++++-- zeroconf/_services/browser.py | 26 ++++++++++++++++--- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 5f2cea21..df405ac2 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -14,7 +14,7 @@ import pytest import zeroconf as r -from zeroconf import DNSPointer, DNSQuestion, const, current_time_millis +from zeroconf import DNSPointer, DNSQuestion, const, current_time_millis, millis_to_seconds import zeroconf._services.browser as _services_browser from zeroconf import Zeroconf from zeroconf._services import ServiceStateChange @@ -453,7 +453,11 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the backoff limit to prevent test running forever with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch.object( _services_browser, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", 10): + ), unittest.mock.patch.object( + _services_browser, "_BROWSER_BACKOFF_LIMIT", 10 + ), unittest.mock.patch.object( + _services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0) + ): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass @@ -498,6 +502,44 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser.close() +def test_first_query_delay(): + """Verify the first query is delayed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + """ + type_ = "_http._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_query_time = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_query_time + if first_query_time is None: + first_query_time = current_time_millis() + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + start_time = current_time_millis() + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert ( + current_time_millis() - start_time > _services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[0] + ) + finally: + browser.cancel() + zeroconf_browser.close() + + def test_integration(): service_added = Event() service_removed = Event() diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 296662d6..964998a2 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -24,6 +24,7 @@ import concurrent.futures import contextlib import queue +import random import threading import warnings from collections import OrderedDict @@ -41,7 +42,7 @@ ) from .._utils.aio import get_best_available_queue, get_running_loop from .._utils.name import service_type_name -from .._utils.time import current_time_millis +from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _BROWSER_BACKOFF_LIMIT, _BROWSER_TIME, @@ -56,6 +57,8 @@ _TYPE_PTR, ) +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +_FIRST_QUERY_DELAY_RANDOM_INTERVAL = (20, 120) # ms if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 @@ -190,14 +193,15 @@ def __init__( self.addr = addr self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) - current_time = current_time_millis() - self._next_time = {check_type_: current_time for check_type_ in self.types} - self._delay = {check_type_: delay for check_type_ in self.types} + self._next_time: Dict[str, float] = {} + self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self.types} self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._service_state_changed = Signal() self.queue: Optional[queue.Queue] = None self.done = False + self._generate_first_next_time() + if hasattr(handlers, 'add_service'): listener = cast('ServiceListener', handlers) handlers = None @@ -212,6 +216,20 @@ def __init__( self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + def _generate_first_next_time(self) -> None: + """Generate the initial next query times. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + To avoid accidental synchronization when, for some reason, multiple + clients begin querying at exactly the same moment (e.g., because of + some common external trigger event), a Multicast DNS querier SHOULD + also delay the first query of the series by a randomly chosen amount + in the range 20-120 ms. + """ + delay = millis_to_seconds(random.randint(*_FIRST_QUERY_DELAY_RANDOM_INTERVAL)) + next_time = current_time_millis() + delay + self._next_time = {check_type_: next_time for check_type_ in self.types} + @property def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface From f23df4f5f05e3911cbf96234b198ea88691aadad Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 12:13:45 -1000 Subject: [PATCH 410/608] Verify async callers can still use Zeroconf without migrating to AsyncZeroconf (#775) --- tests/test_core.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_core.py b/tests/test_core.py index 13c6ec70..87979884 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -569,3 +569,22 @@ async def make_query(): # unregister zc.registry.remove(info) zc.close() + + +@pytest.mark.asyncio +async def test_open_close_twice_from_async() -> None: + """Test we can close twice from a coroutine when using Zeroconf. + + Ideally callers switch to using AsyncZeroconf, however there will + be a peroid where they still call the sync wrapper that we want + to ensure will not deadlock on shutdown. + + This test is expected to throw warnings about tasks being destroyed + since we force shutdown right away since we don't want to block + callers event loops and since they aren't using the AsyncZeroconf + version they won't yield with an await like async_close we don't + have much choice but to force things down. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + zc.close() + zc.close() From e8836b134c47080edaf47532d7cb844b307dfb08 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 12:24:26 -1000 Subject: [PATCH 411/608] Add a guard against the task list changing when shutting down (#776) --- zeroconf/_utils/aio.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index 726acea4..fcf71c34 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -23,7 +23,7 @@ import asyncio import contextlib import queue -from typing import Optional, Set, cast +from typing import List, Optional, Set, cast def get_best_available_queue() -> queue.Queue: @@ -62,12 +62,13 @@ def _handle_wait_complete(_: asyncio.Task) -> None: await event_wait -async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> Set[asyncio.Task]: +async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]: """Return all tasks running.""" await asyncio.sleep(0) # flush out any call_soon_threadsafe + # Make a copy of the tasks in case they change during iteration if hasattr(asyncio, 'all_tasks'): - return cast(Set[asyncio.Task], asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member - return cast(Set[asyncio.Task], asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member + return list(asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member + return list(asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: @@ -77,7 +78,7 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" - pending_tasks = asyncio.run_coroutine_threadsafe(_get_all_tasks(loop), loop).result() + pending_tasks = set(asyncio.run_coroutine_threadsafe(_get_all_tasks(loop), loop).result()) done_tasks = set(task for task in pending_tasks if not task.done()) pending_tasks -= done_tasks if pending_tasks: From b5d54e485d9dbcde1b7b472760a0b307198b8ec8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 12:35:53 -1000 Subject: [PATCH 412/608] Fix deadlock on ServiceBrowser shutdown with PyPy (#774) --- zeroconf/_services/browser.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 964998a2..0bc92686 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -21,7 +21,6 @@ """ import asyncio -import concurrent.futures import contextlib import queue import random @@ -40,7 +39,7 @@ Signal, SignalRegistrationInterface, ) -from .._utils.aio import get_best_available_queue, get_running_loop +from .._utils.aio import get_best_available_queue from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( @@ -386,7 +385,8 @@ async def _async_cancel_browser(self) -> None: self._browser_task.cancel() browser_task = self._browser_task self._browser_task = None - await browser_task + with contextlib.suppress(asyncio.CancelledError): + await browser_task class ServiceBrowser(_ServiceBrowserBase, threading.Thread): @@ -411,35 +411,30 @@ def __init__( self.queue = get_best_available_queue() self.daemon = True assert self.zc.loop is not None - if get_running_loop() == self.zc.loop: - self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) - else: - if not self.zc.loop.is_running(): - raise RuntimeError("The event loop is not running") - self._browser_task = cast( - asyncio.Task, - asyncio.run_coroutine_threadsafe(self._async_browser_task(), self.zc.loop).result(), - ) + if not self.zc.loop.is_running(): + raise RuntimeError("The event loop is not running") + self.zc.loop.call_soon_threadsafe(self._async_start_browser) self.start() self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), ) - async def _async_browser_task(self) -> asyncio.Task: - return cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + def _async_start_browser(self) -> None: + """Start the browser from the event loop.""" + self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + + def _async_cancel_browser_soon(self) -> None: + """Cancel the browser from the event loop.""" + if self._browser_task: + asyncio.ensure_future(self._async_cancel_browser()) def cancel(self) -> None: """Cancel the browser.""" assert self.zc.loop is not None assert self.queue is not None self.queue.put(None) - if self._browser_task: - if get_running_loop() == self.zc.loop: - asyncio.ensure_future(self._async_cancel_browser()) - elif self.zc.loop.is_running(): - with contextlib.suppress(asyncio.CancelledError, concurrent.futures.CancelledError): - asyncio.run_coroutine_threadsafe(self._async_cancel_browser(), self.zc.loop).result() + self.zc.loop.call_soon_threadsafe(self._async_cancel_browser_soon) super().cancel() self.join() From c0f4f48e2bb996ce18cb569aa5369356cbc919ff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 12:41:00 -1000 Subject: [PATCH 413/608] Implement duplicate question supression (#770) https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 --- tests/__init__.py | 1 + tests/services/test_browser.py | 46 ++++++++++++++++++++- tests/test_core.py | 16 ++++++++ tests/test_handlers.py | 50 +++++++++++++++++++++++ tests/test_history.py | 75 ++++++++++++++++++++++++++++++++++ zeroconf/_core.py | 5 ++- zeroconf/_dns.py | 17 ++++---- zeroconf/_handlers.py | 10 +++-- zeroconf/_history.py | 70 +++++++++++++++++++++++++++++++ zeroconf/_services/browser.py | 8 +++- zeroconf/const.py | 1 + 11 files changed, 285 insertions(+), 14 deletions(-) create mode 100644 tests/test_history.py create mode 100644 zeroconf/_history.py diff --git a/tests/__init__.py b/tests/__init__.py index 86d7e199..0e7aa930 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -65,3 +65,4 @@ def has_working_ipv6(): def _clear_cache(zc): zc.cache.cache.clear() + zc.question_history._history.clear() diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index df405ac2..69c23cb0 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -20,6 +20,7 @@ from zeroconf._services import ServiceStateChange from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo +from zeroconf.aio import AsyncZeroconf from .. import has_working_ipv6, _inject_response @@ -426,7 +427,8 @@ def _mock_get_expiration_time(self, percent): zeroconf.close() -def test_backoff(): +@unittest.mock.patch("zeroconf._core.QuestionHistory.suppresses", return_value=False) +def test_backoff(suppresses_mock): got_query = Event() type_ = "_http._tcp.local." @@ -902,3 +904,45 @@ def test_group_ptr_queries_with_known_answers(): # If we generate multiple packets there must # only be one question assert len(packets) == 1 or len(out.questions) == 1 + + +# This test uses asyncio because it needs to access the cache directly +# which is not threadsafe +@pytest.mark.asyncio +async def test_generate_service_query_suppress_duplicate_questions(): + """Generate a service query for sending with zeroconf.send.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + now = current_time_millis() + name = "_hap._tcp.local." + question = r.DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN) + answer = r.DNSPointer( + name, + const._TYPE_PTR, + const._CLASS_IN, + 10000, + f'known-to-other.{name}', + ) + other_known_answers = set([answer]) + zc.question_history.add_question_at_time(question, now, other_known_answers) + assert zc.question_history.suppresses(question, now, other_known_answers) + + # The known answer list is different, do not suppress + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert outs + + zc.cache.async_add_records([answer]) + # The known answer list contains all the asked questions in the history + # we should suppress + + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert not outs + + # We do not suppress once the question history expires + outs = _services_browser.generate_service_query(zc, now + 1000, [name], multicast=True) + assert outs + + # We do not suppress QU queries ever + outs = _services_browser.generate_service_query(zc, now, [name], multicast=False) + assert outs + await aiozc.async_close() diff --git a/tests/test_core.py b/tests/test_core.py index 87979884..6447bbf3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -50,10 +50,26 @@ async def test_reaper(): record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = set( + [ + r.DNSPointer( + "_hap._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + 10000, + 'known-to-other._hap._tcp.local.', + ) + ] + ) + zeroconf.question_history.add_question_at_time(question, now, other_known_answers) + assert zeroconf.question_history.suppresses(question, now, other_known_answers) entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) await asyncio.sleep(1.2) entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) await aiozc.async_close() + assert not zeroconf.question_history.suppresses(question, now, other_known_answers) assert entries != original_entries assert entries_with_cache != original_entries assert record_with_10s_ttl in entries diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 36f25296..d77db3f0 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -960,3 +960,53 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.DNSRe assert set(updated) == set([ptr_record, a_record]) await aiozc.async_close() + + +def test_questions_query_handler_populates_the_question_history_from_qm_questions(): + zc = Zeroconf(interfaces=['127.0.0.1']) + now = current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question.unicast = False + known_answer = r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + generated.add_question(question) + generated.add_answer_at_time(known_answer, 0) + now = r.current_time_millis() + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is None + assert zc.question_history.suppresses(question, now, set([known_answer])) + + zc.close() + + +def test_questions_query_handler_does_not_put_qu_questions_in_history(): + zc = Zeroconf(interfaces=['127.0.0.1']) + now = current_time_millis() + _clear_cache(zc) + + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + known_answer = r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + generated.add_question(question) + generated.add_answer_at_time(known_answer, 0) + now = r.current_time_millis() + packets = generated.packets() + unicast_out, multicast_out = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) + assert unicast_out is None + assert multicast_out is None + assert not zc.question_history.suppresses(question, now, set([known_answer])) + + zc.close() diff --git a/tests/test_history.py b/tests/test_history.py new file mode 100644 index 00000000..89159dff --- /dev/null +++ b/tests/test_history.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +"""Unit tests for _history.py.""" + +from zeroconf._history import QuestionHistory +import zeroconf as r +import zeroconf.const as const + + +def test_question_suppression(): + history = QuestionHistory() + + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = set( + [ + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + ] + ) + our_known_answers = set( + [ + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-us._hap._tcp.local.' + ) + ] + ) + + history.add_question_at_time(question, now, other_known_answers) + + # Verify the question is suppressed if the known answers are the same + assert history.suppresses(question, now, other_known_answers) + + # Verify the question is suppressed if we know the answer to all the known answers + assert history.suppresses(question, now, other_known_answers | our_known_answers) + + # Verify the question is not suppressed if our known answers do no include the ones in the last question + assert not history.suppresses(question, now, set()) + + # Verify the question is not suppressed if our known answers do no include the ones in the last question + assert not history.suppresses(question, now, our_known_answers) + + # Verify the question is no longer suppressed after 1s + assert not history.suppresses(question, now + 1000, other_known_answers) + + +def test_question_expire(): + history = QuestionHistory() + + question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) + now = r.current_time_millis() + other_known_answers = set( + [ + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + ] + ) + history.add_question_at_time(question, now, other_known_answers) + + # Verify the question is suppressed if the known answers are the same + assert history.suppresses(question, now, other_known_answers) + + history.async_expire(now) + + # Verify the question is suppressed if the known answers are the same since the cache hasn't expired + assert history.suppresses(question, now, other_known_answers) + + history.async_expire(now + 1000) + + # Verify the question not longer suppressed since the cache has expired + assert not history.suppresses(question, now, other_known_answers) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1053103c..83395f05 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -35,6 +35,7 @@ from ._dns import DNSQuestion from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager +from ._history import QuestionHistory from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing from ._services import RecordUpdateListener, ServiceListener @@ -134,6 +135,7 @@ async def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" while not self.zc.done: now = current_time_millis() + self.zc.question_history.async_expire(now) self.zc.record_manager.async_updates(now, self.zc.cache.async_expire(now)) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) @@ -288,7 +290,8 @@ def __init__( self.browsers: Dict[ServiceListener, ServiceBrowser] = {} self.registry = ServiceRegistry() self.cache = DNSCache() - self.query_handler = QueryHandler(self.registry, self.cache) + self.question_history = QuestionHistory() + self.query_handler = QueryHandler(self.registry, self.cache, self.question_history) self.record_manager = RecordManager(self) self.notify_event: Optional[asyncio.Event] = None diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index e656bc51..0fad150d 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -414,18 +414,19 @@ def __init__(self, records: Iterable[DNSRecord]) -> None: self._records = records self._lookup: Optional[Dict[DNSRecord, DNSRecord]] = None - def suppresses(self, record: DNSRecord) -> bool: - """Returns true if any answer in the rrset can suffice for the - information held in this record.""" + @property + def lookup(self) -> Dict[DNSRecord, DNSRecord]: if self._lookup is None: # Build the hash table so we can lookup the record independent of the ttl self._lookup = {record: record for record in self._records} - other = self._lookup.get(record) + return self._lookup + + def suppresses(self, record: DNSRecord) -> bool: + """Returns true if any answer in the rrset can suffice for the + information held in this record.""" + other = self.lookup.get(record) return bool(other and other.ttl > (record.ttl / 2)) def __contains__(self, record: DNSRecord) -> bool: """Returns true if the rrset contains the record.""" - if self._lookup is None: - # Build the hash table so we can lookup the record independent of the ttl - self._lookup = {record: record for record in self._records} - return record in self._lookup + return record in self.lookup diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 08c25e17..79bb6f90 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,10 +21,11 @@ """ import itertools -from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from typing import Dict, Iterable, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from ._cache import DNSCache, _UniqueRecordsType from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord +from ._history import QuestionHistory from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing from ._services import RecordUpdateListener @@ -156,10 +157,11 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: class QueryHandler: """Query the ServiceRegistry.""" - def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None: + def __init__(self, registry: ServiceRegistry, cache: DNSCache, question_history: QuestionHistory) -> None: """Init the query handler.""" self.registry = registry self.cache = cache + self.question_history = question_history def _add_service_type_enumeration_query_answers( self, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float @@ -253,6 +255,8 @@ def async_response( # pylint: disable=unused-argument for msg in msgs: for question in msg.questions: + if not question.unicast: + self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup)) answer_set: _AnswerWithAdditionalsType = {} self._answer_question(question, answer_set, known_answers, msg.now) if not ucast_source and question.unicast: @@ -364,7 +368,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: self.async_updates_complete() def _async_mark_unique_cached_records_older_than_1s_to_expire( - self, unique_types: Set[Tuple[str, int, int]], answers: List[DNSRecord], now: float + self, unique_types: Set[Tuple[str, int, int]], answers: Iterable[DNSRecord], now: float ) -> None: # rfc6762#section-10.2 para 2 # Since unique is set, all old records with that name, rrtype, diff --git a/zeroconf/_history.py b/zeroconf/_history.py new file mode 100644 index 00000000..cbb36144 --- /dev/null +++ b/zeroconf/_history.py @@ -0,0 +1,70 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import Dict, Set, Tuple + +from ._dns import DNSQuestion, DNSRecord +from .const import _DUPLICATE_QUESTION_INTERVAL + +# The QuestionHistory is used to implement Duplicate Question Suppression +# https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + + +class QuestionHistory: + def __init__(self) -> None: + self._history: Dict[DNSQuestion, Tuple[float, Set[DNSRecord]]] = {} + + def add_question_at_time(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> None: + """Remember a question with known answers.""" + self._history[question] = (now, known_answers) + + def suppresses(self, question: DNSQuestion, now: float, known_answers: Set[DNSRecord]) -> bool: + """Check to see if a question should be suppressed. + + https://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + When multiple queriers on the network are querying + for the same resource records, there is no need for them to all be + repeatedly asking the same question. + """ + previous_question = self._history.get(question) + # There was not previous question in the history + if not previous_question: + return False + than, previous_known_answers = previous_question + # The last question was older than 999ms + if now - than > _DUPLICATE_QUESTION_INTERVAL: + return False + # The last question has more known answers than + # we knew so we have to ask + if previous_known_answers - known_answers: + return False + return True + + def async_expire(self, now: float) -> None: + """Expire the history of old questions.""" + removes = [ + question + for question, now_known_answers in self._history.items() + if now - now_known_answers[0] > _DUPLICATE_QUESTION_INTERVAL + ] + for question in removes: + del self._history[question] diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 0bc92686..567ef2b6 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -31,6 +31,7 @@ from .._cache import _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord +from .._logger import log from .._protocol import DNSOutgoing from .._services import ( RecordUpdateListener, @@ -135,11 +136,16 @@ def generate_service_query( questions_with_known_answers: _QuestionWithKnownAnswers = {} for type_ in types_: question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) - questions_with_known_answers[question] = set( + known_answers = set( cast(DNSPointer, record) for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) if not record.is_stale(now) ) + if multicast and zc.question_history.suppresses(question, now, cast(Set[DNSRecord], known_answers)): + log.debug("Asking %s was suppressed by the question history", question) + continue + questions_with_known_answers[question] = known_answers + return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) diff --git a/zeroconf/const.py b/zeroconf/const.py index ba9d5309..df1ba8be 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -30,6 +30,7 @@ _REGISTER_TIME = 225 # ms _LISTENER_TIME = 200 # ms _BROWSER_TIME = 1000 # ms +_DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms _BROWSER_BACKOFF_LIMIT = 3600 # s _CACHE_CLEANUP_INTERVAL = 10000 # ms From ac9f72a986ae314af0043cae6fb6219baabea7e6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 14:25:29 -1000 Subject: [PATCH 414/608] Fix Responding to Address Queries (RFC6762 section 6.2) (#777) --- tests/test_handlers.py | 53 ++++++++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 12 +++++++--- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d77db3f0..ccb2fcf5 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -16,8 +16,10 @@ import zeroconf as r from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const +from zeroconf._dns import DNSRRSet from zeroconf.aio import AsyncZeroconf + from . import _clear_cache, _inject_response log = logging.getLogger('zeroconf') @@ -325,6 +327,57 @@ def test_aaaa_query(): zc.close() +def test_a_and_aaaa_record_fate_sharing(): + """Test that queries for AAAA always return A records in the additionals.""" + zc = Zeroconf(interfaces=['127.0.0.1']) + type_ = "_a-and-aaaa-service._tcp.local." + name = "knownname" + registration_name = "%s.%s" % (name, type_) + desc = {'path': '/~paulsm/'} + server_name = "ash-2.local." + ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") + ipv4_address = socket.inet_aton("10.0.1.2") + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address, ipv4_address]) + aaaa_record = info.dns_addresses(version=r.IPVersion.V6Only)[0] + a_record = info.dns_addresses(version=r.IPVersion.V4Only)[0] + + zc.registry.add(info) + + # Test AAAA query + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + _, multicast_out = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) + answers = DNSRRSet([answer[0] for answer in multicast_out.answers]) + additionals = DNSRRSet(multicast_out.additionals) + assert aaaa_record in answers + assert a_record in additionals + assert len(multicast_out.answers) == 1 + assert len(multicast_out.additionals) == 1 + + # Test A query + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) + generated.add_question(question) + packets = generated.packets() + _, multicast_out = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT + ) + answers = DNSRRSet([answer[0] for answer in multicast_out.answers]) + additionals = DNSRRSet(multicast_out.additionals) + + assert a_record in answers + assert aaaa_record in additionals + assert len(multicast_out.answers) == 1 + assert len(multicast_out.additionals) == 1 + # unregister + zc.registry.remove(info) + zc.close() + + def test_unicast_response(): """Ensure we send a unicast response when the source port is not the MDNS port.""" # instantiate a zeroconf instance diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 79bb6f90..9caae3a2 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -204,9 +204,15 @@ def _add_address_answers( ) -> None: """Answer A/AAAA/ANY question.""" for service in self.registry.get_infos_server(name): - for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_], created=now): - if not known_answers.suppresses(dns_address): - answer_set[dns_address] = set() + answers: List[DNSAddress] = [] + additionals: Set[DNSRecord] = set() + for dns_address in service.dns_addresses(created=now): + if dns_address.type != type_: + additionals.add(dns_address) + elif not known_answers.suppresses(dns_address): + answers.append(dns_address) + for answer in answers: + answer_set[answer] = additionals def _answer_question( self, From 767ae8f6cd92493f8f43d66edc70c8fd856ed11e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 14:43:07 -1000 Subject: [PATCH 415/608] Reformat test_handlers (#780) --- tests/test_handlers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ccb2fcf5..522c71d6 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -337,7 +337,9 @@ def test_a_and_aaaa_record_fate_sharing(): server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") ipv4_address = socket.inet_aton("10.0.1.2") - info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address, ipv4_address]) + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address, ipv4_address] + ) aaaa_record = info.dns_addresses(version=r.IPVersion.V6Only)[0] a_record = info.dns_addresses(version=r.IPVersion.V4Only)[0] @@ -373,7 +375,7 @@ def test_a_and_aaaa_record_fate_sharing(): assert aaaa_record in additionals assert len(multicast_out.answers) == 1 assert len(multicast_out.additionals) == 1 - # unregister + # unregister zc.registry.remove(info) zc.close() From 7aeafbf3b990ab671ff691b6c20cd410f69808bf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 14:43:16 -1000 Subject: [PATCH 416/608] Switch to using a simple cache instead of lru_cache (#779) --- zeroconf/_protocol.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index b0b87a06..f3175004 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -22,7 +22,6 @@ import enum import struct -from functools import lru_cache from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast @@ -94,6 +93,7 @@ def __init__(self, data: bytes) -> None: self.num_additionals = 0 self.valid = False self.now = current_time_millis() + self._utf_cache: Dict[Tuple[int, int], str] = {} try: self.read_header() @@ -208,10 +208,15 @@ def read_others(self) -> None: if rec is not None: self.answers.append(rec) - @lru_cache(maxsize=None) def read_utf(self, offset: int, length: int) -> str: """Reads a UTF-8 string of a given length from the packet""" - return str(self.data[offset : offset + length], 'utf-8', 'replace') + key = (offset, length) + cached = self._utf_cache.get(key) + if cached is not None: + return cached + decoded = str(self.data[offset : offset + length], 'utf-8', 'replace') + self._utf_cache[key] = decoded + return decoded def read_name(self) -> str: """Reads a domain name from the packet""" From 1b873436e2d9ff36876a71c48fa697d277fd3ffa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 14:51:19 -1000 Subject: [PATCH 417/608] Drop utf cache from _dns (#781) - The cache did not make enough difference to justify the additional complexity after additional testing was done --- zeroconf/_protocol.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index f3175004..80ca7b88 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -93,7 +93,6 @@ def __init__(self, data: bytes) -> None: self.num_additionals = 0 self.valid = False self.now = current_time_millis() - self._utf_cache: Dict[Tuple[int, int], str] = {} try: self.read_header() @@ -210,13 +209,7 @@ def read_others(self) -> None: def read_utf(self, offset: int, length: int) -> str: """Reads a UTF-8 string of a given length from the packet""" - key = (offset, length) - cached = self._utf_cache.get(key) - if cached is not None: - return cached - decoded = str(self.data[offset : offset + length], 'utf-8', 'replace') - self._utf_cache[key] = decoded - return decoded + return str(self.data[offset : offset + length], 'utf-8', 'replace') def read_name(self) -> str: """Reads a domain name from the packet""" From 3be1bc84bff5ee2840040ddff41185b257a1055c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 15:26:25 -1000 Subject: [PATCH 418/608] Inline utf8 decoding when processing incoming packets (#782) --- zeroconf/_protocol.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 80ca7b88..53a7b710 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -207,10 +207,6 @@ def read_others(self) -> None: if rec is not None: self.answers.append(rec) - def read_utf(self, offset: int, length: int) -> str: - """Reads a UTF-8 string of a given length from the packet""" - return str(self.data[offset : offset + length], 'utf-8', 'replace') - def read_name(self) -> str: """Reads a domain name from the packet""" result = '' @@ -218,6 +214,7 @@ def read_name(self) -> str: next_ = -1 first = off + # This is a tight loop that is called frequently, small optimizations can make a difference. while True: length = self.data[off] off += 1 @@ -225,7 +222,8 @@ def read_name(self) -> str: break t = length & 0xC0 if t == 0x00: - result += self.read_utf(off, length) + '.' + # Convert to utf-8 + result += str(self.data[off : off + length], 'utf-8', 'replace') + '.' off += length elif t == 0xC0: if next_ < 0: From dd85ae7defd3f195ed0511a2fdb6512326ca0562 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 15:28:30 -1000 Subject: [PATCH 419/608] Add a guard to prevent running ServiceInfo.request in async context (#784) * Add a guard to prevent running ServiceInfo.request in async context * test --- tests/test_aio.py | 10 ++++++++++ zeroconf/_services/info.py | 3 +++ 2 files changed, 13 insertions(+) diff --git a/tests/test_aio.py b/tests/test_aio.py index 327ccc66..4cb7af2e 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -597,3 +597,13 @@ async def test_async_zeroconf_service_types(): finally: await zeroconf_registrar.async_close() + + +@pytest.mark.asyncio +async def test_guard_against_running_serviceinfo_request_event_loop() -> None: + """Test that running ServiceInfo.request from the event loop throws.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + + service_info = AsyncServiceInfo("_hap._tcp.local.", "doesnotmatter._hap._tcp.local.") + with pytest.raises(RuntimeError): + service_info.request(aiozc.zeroconf, 3000) diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index aa457ffd..6fccb6b7 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -28,6 +28,7 @@ from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing from .._services import RecordUpdateListener +from .._utils.aio import get_running_loop from .._utils.name import service_type_name from .._utils.net import ( IPVersion, @@ -398,6 +399,8 @@ def request(self, zc: 'Zeroconf', timeout: float) -> bool: network, and updates this object with details discovered. """ assert zc.loop is not None and zc.loop.is_running() + if zc.loop == get_running_loop(): + raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") return asyncio.run_coroutine_threadsafe(self.async_request(zc, timeout), zc.loop).result() async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: From 97f5b502815075f2ff29bee3ace7cde6ad725dfb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 15:57:55 -1000 Subject: [PATCH 420/608] Ensure the queue is created before adding listeners to ServiceBrowser (#785) * Ensure the queue is created before adding listeners to ServiceBrowser - The callback from the listener could generate an event that would fire in async context that should have gone to the queue which could result in the consumer running a sync call in the event loop and blocking it. * add comments * add comments * add comments * add comments * black --- zeroconf/_services/browser.py | 22 ++++++++++++++++------ zeroconf/aio.py | 2 ++ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 567ef2b6..655d046a 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -205,8 +205,6 @@ def __init__( self.queue: Optional[queue.Queue] = None self.done = False - self._generate_first_next_time() - if hasattr(handlers, 'add_service'): listener = cast('ServiceListener', handlers) handlers = None @@ -219,6 +217,13 @@ def __init__( for h in handlers: self.service_state_changed.register_handler(h) + def _setup(self) -> None: + """Generate the next time and setup listeners. + + Must be called by uses of this base class after they + have finished setting their properties. + """ + self._generate_first_next_time() self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) def _generate_first_next_time(self) -> None: @@ -412,15 +417,20 @@ def __init__( port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, ) -> None: + assert zc.loop is not None + if not zc.loop.is_running(): + raise RuntimeError("The event loop is not running") threading.Thread.__init__(self) super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + # Add the queue before the listener is installed in _setup + # to ensure that events run in the dedicated thread and do + # not block the event loop self.queue = get_best_available_queue() self.daemon = True - assert self.zc.loop is not None - if not self.zc.loop.is_running(): - raise RuntimeError("The event loop is not running") - self.zc.loop.call_soon_threadsafe(self._async_start_browser) self.start() + self._setup() + # Start queries after the listener is installed in _setup + zc.loop.call_soon_threadsafe(self._async_start_browser) self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 7a107fcc..94fa82bb 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -72,6 +72,8 @@ def __init__( delay: int = _BROWSER_TIME, ) -> None: super().__init__(zeroconf, type_, handlers, listener, addr, port, delay) + self._setup() + # Start queries after the listener is installed in _setup self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) async def async_cancel(self) -> None: From 3b3ecf09d2f30ee39c6c29b4d85e000577b2c4b9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 16:03:24 -1000 Subject: [PATCH 421/608] Update changelog (#786) --- README.rst | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/README.rst b/README.rst index a26981b4..c9e23afb 100644 --- a/README.rst +++ b/README.rst @@ -236,6 +236,10 @@ you can likely not be concerned with the breaking changes below: * TRAFFIC REDUCTION: Efficiently bucket queries with known answers (#698) @bdraco +* TRAFFIC REDUCTION: Implement duplicate question supression (#770) @bdraco + + http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + * MAJOR BUG: Ensure matching PTR queries are returned with the ANY query (#618) @bdraco * MAJOR BUG: Fix lookup of uppercase names in registry (#597) @bdraco @@ -254,6 +258,35 @@ you can likely not be concerned with the breaking changes below: * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco +* Switch to using an asyncio.Event for async_wait (#759) @bdraco + + We no longer need to check for thread safety under a asyncio.Condition + as the ServiceBrowser and ServiceInfo internals schedule coroutines + in the eventloop. + +* Ensure the queue is created before adding listeners to ServiceBrowser (#785) @bdraco + + The callback from the listener could generate an event that would + fire in async context that should have gone to the queue which + could result in the consumer running a sync call in the event loop + and blocking it. + +* Add a guard to prevent running ServiceInfo.request in async context (#784) @bdraco + +* Inline utf8 decoding when processing incoming packets (#782) @bdraco + +* Drop utf cache from _dns (#781) (later reverted) @bdraco + +* Switch to using a simple cache instead of lru_cache (#779) (later reverted) @bdraco + +* Fix Responding to Address Queries (RFC6762 section 6.2) (#777) @bdraco + +* Fix deadlock on ServiceBrowser shutdown with PyPy (#774) @bdraco + +* Add a guard against the task list changing when shutting down (#776) @bdraco + +* Improve performance of parsing DNSIncoming by caching read_utf (#769) (later reverted) @bdraco + * Simplify ServiceBrowser callsbacks (#756) @bdraco * Revert: Fix thread safety in _ServiceBrowser.update_records_complete (#708) (#755) @bdraco From 135983cb96a27e3ad3750234286d1d9bfa6ff44f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 18:26:13 -1000 Subject: [PATCH 422/608] Add support for requesting QU questions to ServiceBrowser and ServiceInfo (#787) --- tests/services/test_browser.py | 100 +++++++++++++++++++++++++++++++++ tests/services/test_info.py | 48 ++++++++++++++++ zeroconf/__init__.py | 3 +- zeroconf/_core.py | 8 ++- zeroconf/_dns.py | 14 +++++ zeroconf/_services/browser.py | 35 ++++++++++-- zeroconf/_services/info.py | 26 ++++++--- zeroconf/aio.py | 24 ++++++-- 8 files changed, 236 insertions(+), 22 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 69c23cb0..2198a33b 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -542,6 +542,106 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser.close() +def test_asking_default_is_asking_qm_questions(): + """Verify the service browser can ask QU questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert first_outgoing.questions[0].unicast == False + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_asking_qm_questions(): + """Verify explictly asking QM questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser( + zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QM + ) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert first_outgoing.questions[0].unicast == False + finally: + browser.cancel() + zeroconf_browser.close() + + +def test_asking_qu_questions(): + """Verify the service browser can ask QU questions.""" + type_ = "_quservice._tcp.local." + zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_browser.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + browser = ServiceBrowser( + zeroconf_browser, type_, [on_service_state_change], question_type=r.DNSQuestionType.QU + ) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + try: + assert first_outgoing.questions[0].unicast == True + finally: + browser.cancel() + zeroconf_browser.close() + + def test_integration(): service_added = Event() service_removed = Event() diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 8fb45f22..438ad819 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -625,3 +625,51 @@ def test_serviceinfo_accepts_bytes_or_string_dict(): addresses=addresses, ) assert info_service.dns_text().text == b'\x0epath=/~paulsm/' + + +def test_asking_qu_questions(): + """Verify explictly asking QU questions.""" + type_ = "_quservice._tcp.local." + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf, "async_send", send): + zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QU) + assert first_outgoing.questions[0].unicast == True + zeroconf.close() + + +def test_asking_qm_questions_are_default(): + """Verify default is QM questions.""" + type_ = "_quservice._tcp.local." + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf.async_send + + first_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with unittest.mock.patch.object(zeroconf, "async_send", send): + zeroconf.get_service_info(f"name.{type_}", type_, 500) + assert first_outgoing.questions[0].unicast == False + zeroconf.close() diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e3d7ddfb..e836195e 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -33,6 +33,7 @@ DNSRecord, DNSService, DNSText, + DNSQuestionType, ) from ._logger import QuietLogger, log # noqa # import needed for backwards compat from ._exceptions import ( # noqa # import needed for backwards compat @@ -84,7 +85,7 @@ __all__ = [ "__version__", - "DNSOutgoing", + "DNSQuestionType", "Zeroconf", "ServiceInfo", "ServiceBrowser", diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 83395f05..a22ede3d 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -32,7 +32,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache -from ._dns import DNSQuestion +from ._dns import DNSQuestion, DNSQuestionType from ._exceptions import NonUniqueNameException from ._handlers import QueryHandler, RecordManager from ._history import QuestionHistory @@ -360,12 +360,14 @@ def async_notify_all(self) -> None: self.notify_event.set() self.notify_event.clear() - def get_service_info(self, type_: str, name: str, timeout: int = 3000) -> Optional[ServiceInfo]: + def get_service_info( + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None + ) -> Optional[ServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds.""" info = ServiceInfo(type_, name) - if info.request(self, timeout): + if info.request(self, timeout, question_type): return info return None diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 0fad150d..2c0e3338 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -20,6 +20,7 @@ USA """ +import enum import socket from typing import Any, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union, cast @@ -49,6 +50,19 @@ from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import +@enum.unique +class DNSQuestionType(enum.Enum): + """An MDNS question type. + + "QU" - questions requesting unicast responses + "QM" - questions requesting multicast responses + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + """ + + QU = 1 + QM = 2 + + def dns_entry_matches(record: 'DNSEntry', key: str, type_: int, class_: int) -> bool: return key == record.key and type_ == record.type and class_ == record.class_ diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 655d046a..30f4cf90 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -30,7 +30,7 @@ from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from .._cache import _UniqueRecordsType -from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord +from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord from .._logger import log from .._protocol import DNSOutgoing from .._services import ( @@ -130,12 +130,18 @@ def _group_ptr_queries_with_known_answers( def generate_service_query( - zc: 'Zeroconf', now: float, types_: List[str], multicast: bool = True + zc: 'Zeroconf', + now: float, + types_: List[str], + multicast: bool = True, + question_type: Optional[DNSQuestionType] = None, ) -> List[DNSOutgoing]: """Generate a service query for sending with zeroconf.send.""" questions_with_known_answers: _QuestionWithKnownAnswers = {} + qu_question = not multicast if question_type is None else question_type == DNSQuestionType.QU for type_ in types_: question = DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) + question.unicast = qu_question known_answers = set( cast(DNSPointer, record) for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) @@ -186,8 +192,25 @@ def __init__( addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, ) -> None: - """Creates a browser for a specific type""" + """Used to browse for a service for specific type(s). + + Constructor parameters are as follows: + + * `zc`: A Zeroconf instance + * `type_`: fully qualified service type name + * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events + * `listener`: ServiceListener + * `addr`: address to send queries (will default to multicast) + * `port`: port to send queries (will default to mdns 5353) + * `delay`: The initial delay between answering questions + * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + + The listener object will have its add_service() and + remove_service() methods called when this browser + discovers changes in the services availability. + """ assert handlers or listener, 'You need to specify at least one handler' self.types: Set[str] = set(type_ if isinstance(type_, list) else [type_]) for check_type_ in self.types: @@ -198,6 +221,7 @@ def __init__( self.addr = addr self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) + self.question_type = question_type self._next_time: Dict[str, float] = {} self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self.types} self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() @@ -370,7 +394,7 @@ def generate_ready_queries(self) -> List[DNSOutgoing]: self._next_time[type_] = now + self._delay[type_] self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - return generate_service_query(self.zc, now, ready_types, self.multicast) + return generate_service_query(self.zc, now, ready_types, self.multicast, self.question_type) def _millis_to_wait(self, now: float) -> Optional[float]: """Returns the number of milliseconds to wait for the next event.""" @@ -416,12 +440,13 @@ def __init__( addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, ) -> None: assert zc.loop is not None if not zc.loop.is_running(): raise RuntimeError("The event loop is not running") threading.Thread.__init__(self) - super().__init__(zc, type_, handlers=handlers, listener=listener, addr=addr, port=port, delay=delay) + super().__init__(zc, type_, handlers, listener, addr, port, delay, question_type) # Add the queue before the listener is installed in _setup # to ensure that events run in the dedicated thread and do # not block the event loop diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 6fccb6b7..82e8b983 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -24,7 +24,7 @@ import socket from typing import Dict, List, Optional, TYPE_CHECKING, Union, cast -from .._dns import DNSAddress, DNSPointer, DNSRecord, DNSService, DNSText +from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing from .._services import RecordUpdateListener @@ -85,7 +85,6 @@ class ServiceInfo(RecordUpdateListener): * `other_ttl`: ttl used for PTR/TXT records * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, or in parsed form as text; at most one of those parameters can be provided) - """ text = b'' @@ -103,7 +102,7 @@ def __init__( other_ttl: int = _DNS_OTHER_TTL, *, addresses: Optional[List[bytes]] = None, - parsed_addresses: Optional[List[str]] = None + parsed_addresses: Optional[List[str]] = None, ) -> None: # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: @@ -394,16 +393,22 @@ def _is_complete(self) -> bool: """The ServiceInfo has all expected properties.""" return not (self.text is None or not self._addresses) - def request(self, zc: 'Zeroconf', timeout: float) -> bool: + def request( + self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None + ) -> bool: """Returns true if the service could be discovered on the network, and updates this object with details discovered. """ assert zc.loop is not None and zc.loop.is_running() if zc.loop == get_running_loop(): raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") - return asyncio.run_coroutine_threadsafe(self.async_request(zc, timeout), zc.loop).result() + return asyncio.run_coroutine_threadsafe( + self.async_request(zc, timeout, question_type), zc.loop + ).result() - async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: + async def async_request( + self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None + ) -> bool: """Returns true if the service could be discovered on the network, and updates this object with details discovered. """ @@ -421,7 +426,7 @@ async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: if last <= now: return False if next_ <= now: - out = self.generate_request_query(zc, now) + out = self.generate_request_query(zc, now, question_type) if not out.questions: return self.load_from_cache(zc) zc.async_send(out) @@ -435,13 +440,18 @@ async def async_request(self, zc: 'Zeroconf', timeout: float) -> bool: return True - def generate_request_query(self, zc: 'Zeroconf', now: float) -> DNSOutgoing: + def generate_request_query( + self, zc: 'Zeroconf', now: float, question_type: Optional[DNSQuestionType] = None + ) -> DNSOutgoing: """Generate the request query.""" out = DNSOutgoing(_FLAGS_QR_QUERY) out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_SRV, _CLASS_IN) out.add_question_or_one_cache(zc.cache, now, self.name, _TYPE_TXT, _CLASS_IN) out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_A, _CLASS_IN) out.add_question_or_all_cache(zc.cache, now, self.server, _TYPE_AAAA, _CLASS_IN) + if question_type == DNSQuestionType.QU: + for question in out.questions: + question.unicast = True return out def __eq__(self, other: object) -> bool: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 94fa82bb..dc4f90cd 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -25,6 +25,7 @@ from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast from ._core import Zeroconf +from ._dns import DNSQuestionType from ._exceptions import NonUniqueNameException from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase @@ -55,11 +56,23 @@ class AsyncServiceInfo(ServiceInfo): class AsyncServiceBrowser(_ServiceBrowserBase): - """Used to browse for a service of a specific type. + """Used to browse for a service for specific type(s). + + Constructor parameters are as follows: + + * `zc`: A Zeroconf instance + * `type_`: fully qualified service type name + * `handler`: ServiceListener or Callable that knows how to process ServiceStateChange events + * `listener`: ServiceListener + * `addr`: address to send queries (will default to multicast) + * `port`: port to send queries (will default to mdns 5353) + * `delay`: The initial delay between answering questions + * `question_type`: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) The listener object will have its add_service() and remove_service() methods called when this browser - discovers changes in the services availability.""" + discovers changes in the services availability. + """ def __init__( self, @@ -70,8 +83,9 @@ def __init__( addr: Optional[str] = None, port: int = _MDNS_PORT, delay: int = _BROWSER_TIME, + question_type: Optional[DNSQuestionType] = None, ) -> None: - super().__init__(zeroconf, type_, handlers, listener, addr, port, delay) + super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type) self._setup() # Start queries after the listener is installed in _setup self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) @@ -251,13 +265,13 @@ async def async_close(self) -> None: await self.zeroconf._async_close() # pylint: disable=protected-access async def async_get_service_info( - self, type_: str, name: str, timeout: int = 3000 + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None ) -> Optional[AsyncServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, which defaults to 3 seconds.""" info = AsyncServiceInfo(type_, name) - if await info.async_request(self.zeroconf, timeout): + if await info.async_request(self.zeroconf, timeout, question_type): return info return None From 62dc9c91c277bc4755f81597adca030a43d0ce5f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 19:40:33 -1000 Subject: [PATCH 423/608] Add async_apple_scanner example (#719) --- examples/async_apple_scanner.py | 119 ++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 examples/async_apple_scanner.py diff --git a/examples/async_apple_scanner.py b/examples/async_apple_scanner.py new file mode 100644 index 00000000..573640b0 --- /dev/null +++ b/examples/async_apple_scanner.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 + +""" Scan for apple devices. """ + +import argparse +import asyncio +import logging +from typing import Any, Optional, cast + +from zeroconf import DNSQuestionType, IPVersion, ServiceStateChange, Zeroconf +from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf + +HOMESHARING_SERVICE: str = "_appletv-v2._tcp.local." +DEVICE_SERVICE: str = "_touch-able._tcp.local." +MEDIAREMOTE_SERVICE: str = "_mediaremotetv._tcp.local." +AIRPLAY_SERVICE: str = "_airplay._tcp.local." +COMPANION_SERVICE: str = "_companion-link._tcp.local." +RAOP_SERVICE: str = "_raop._tcp.local." +AIRPORT_ADMIN_SERVICE: str = "_airport._tcp.local." +DEVICE_INFO_SERVICE: str = "_device-info._tcp.local." + +ALL_SERVICES = [ + HOMESHARING_SERVICE, + DEVICE_SERVICE, + MEDIAREMOTE_SERVICE, + AIRPLAY_SERVICE, + COMPANION_SERVICE, + RAOP_SERVICE, + AIRPORT_ADMIN_SERVICE, + DEVICE_INFO_SERVICE, +] + +log = logging.getLogger(__name__) + + +def async_on_service_state_change( + zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange +) -> None: + print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + if state_change is not ServiceStateChange.Added: + return + base_name = name[: -len(service_type) - 1] + device_name = f"{base_name}.{DEVICE_INFO_SERVICE}" + asyncio.ensure_future(_async_show_service_info(zeroconf, service_type, name)) + # Also probe for device info + asyncio.ensure_future(_async_show_service_info(zeroconf, DEVICE_INFO_SERVICE, device_name)) + + +async def _async_show_service_info(zeroconf: Zeroconf, service_type: str, name: str) -> None: + info = AsyncServiceInfo(service_type, name) + await info.async_request(zeroconf, 3000, question_type=DNSQuestionType.QU) + print("Info from zeroconf.get_service_info: %r" % (info)) + if info: + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + print(" Name: %s" % name) + print(" Addresses: %s" % ", ".join(addresses)) + print(" Weight: %d, priority: %d" % (info.weight, info.priority)) + print(" Server: %s" % (info.server,)) + if info.properties: + print(" Properties are:") + for key, value in info.properties.items(): + print(" %s: %s" % (key, value)) + else: + print(" No properties") + else: + print(" No info") + print('\n') + + +class AsyncAppleScanner: + def __init__(self, args: Any) -> None: + self.args = args + self.aiobrowser: Optional[AsyncServiceBrowser] = None + self.aiozc: Optional[AsyncZeroconf] = None + + async def async_run(self) -> None: + self.aiozc = AsyncZeroconf(ip_version=ip_version) + await self.aiozc.zeroconf.async_wait_for_start() + print("\nBrowsing %s service(s), press Ctrl-C to exit...\n" % ALL_SERVICES) + kwargs = {'handlers': [async_on_service_state_change], 'question_type': DNSQuestionType.QU} + if self.args.target: + kwargs["addr"] = self.args.target + self.aiobrowser = AsyncServiceBrowser(self.aiozc.zeroconf, ALL_SERVICES, **kwargs) # type: ignore + while True: + await asyncio.sleep(1) + + async def async_close(self) -> None: + assert self.aiozc is not None + assert self.aiobrowser is not None + await self.aiobrowser.async_cancel() + await self.aiozc.async_close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + + parser = argparse.ArgumentParser() + parser.add_argument('--debug', action='store_true') + version_group = parser.add_mutually_exclusive_group() + version_group.add_argument('--target', help='Unicast target') + version_group.add_argument('--v6', action='store_true') + version_group.add_argument('--v6-only', action='store_true') + args = parser.parse_args() + + if args.debug: + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + if args.v6: + ip_version = IPVersion.All + elif args.v6_only: + ip_version = IPVersion.V6Only + else: + ip_version = IPVersion.V4Only + + loop = asyncio.get_event_loop() + runner = AsyncAppleScanner(args) + try: + loop.run_until_complete(runner.async_run()) + except KeyboardInterrupt: + loop.run_until_complete(runner.async_close()) From 5d2362825110e9f7a9c9259218a664e2e927e821 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 19:43:32 -1000 Subject: [PATCH 424/608] Update changelog (#788) --- README.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index c9e23afb..c0b12b15 100644 --- a/README.rst +++ b/README.rst @@ -258,11 +258,9 @@ you can likely not be concerned with the breaking changes below: * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco -* Switch to using an asyncio.Event for async_wait (#759) @bdraco +* Add async_apple_scanner example (#719) @bdraco - We no longer need to check for thread safety under a asyncio.Condition - as the ServiceBrowser and ServiceInfo internals schedule coroutines - in the eventloop. +* Add support for requesting QU questions to ServiceBrowser and ServiceInfo (#787) @bdraco * Ensure the queue is created before adding listeners to ServiceBrowser (#785) @bdraco @@ -287,6 +285,12 @@ you can likely not be concerned with the breaking changes below: * Improve performance of parsing DNSIncoming by caching read_utf (#769) (later reverted) @bdraco +* Switch to using an asyncio.Event for async_wait (#759) @bdraco + + We no longer need to check for thread safety under a asyncio.Condition + as the ServiceBrowser and ServiceInfo internals schedule coroutines + in the eventloop. + * Simplify ServiceBrowser callsbacks (#756) @bdraco * Revert: Fix thread safety in _ServiceBrowser.update_records_complete (#708) (#755) @bdraco From ecad4e84c44ffd21dbf15e969c08f7b3376b131c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 20 Jun 2021 22:53:32 -1000 Subject: [PATCH 425/608] Ensure outgoing ServiceBrowser questions are seen by the question history (#790) --- tests/services/test_browser.py | 11 +++++++++++ zeroconf/_services/browser.py | 1 + 2 files changed, 12 insertions(+) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 2198a33b..2dba6158 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -1045,4 +1045,15 @@ async def test_generate_service_query_suppress_duplicate_questions(): # We do not suppress QU queries ever outs = _services_browser.generate_service_query(zc, now, [name], multicast=False) assert outs + + zc.question_history.async_expire(now + 1000) + # No suppression after clearing the history + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert outs + + # The previous query we just sent is still remembered and + # the next one is suppressed + outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) + assert not outs + await aiozc.async_close() diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 30f4cf90..6a679c20 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -151,6 +151,7 @@ def generate_service_query( log.debug("Asking %s was suppressed by the question history", question) continue questions_with_known_answers[question] = known_answers + zc.question_history.add_question_at_time(question, now, cast(Set[DNSRecord], known_answers)) return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) From 6aac0eb0c1e394ec7ee21ddd6e98e446417d0e07 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 07:01:23 -1000 Subject: [PATCH 426/608] Fix test_tc_bit_defers_last_response_missing failures due to thread safety (#795) --- tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 6447bbf3..27cfb9e7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -576,7 +576,7 @@ async def make_query(): for _ in range(8): time.sleep(0.1) - if source_ip not in zc._timers: + if source_ip not in zc._timers and source_ip not in zc._deferred: break assert source_ip not in zc._deferred From 2bfbcbe9e05b9df98bba66a73deb0041c0e7c13b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 07:09:58 -1000 Subject: [PATCH 427/608] Make add_listener and remove_listener threadsafe (#794) --- tests/services/test_browser.py | 2 ++ zeroconf/_core.py | 34 ++++++++++++++++++++++++++++++---- zeroconf/_handlers.py | 18 ++++++++++++------ zeroconf/_services/browser.py | 24 ++++++++++-------------- zeroconf/_services/info.py | 4 ++-- zeroconf/aio.py | 8 +++----- 6 files changed, 59 insertions(+), 31 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 2dba6158..331e1f0d 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -315,6 +315,7 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi finally: assert len(zeroconf.listeners) == 1 service_browser.cancel() + time.sleep(0.2) assert len(zeroconf.listeners) == 0 zeroconf.remove_all_service_listeners() zeroconf.close() @@ -422,6 +423,7 @@ def _mock_get_expiration_time(self, percent): finally: assert len(zeroconf.listeners) == 1 service_browser.cancel() + time.sleep(0.2) assert len(zeroconf.listeners) == 0 zeroconf.remove_all_service_listeners() zeroconf.close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a22ede3d..8e9f3dff 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -548,12 +548,38 @@ def add_listener( ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to - answer the question(s).""" - self.record_manager.add_listener(listener, question) + answer the question(s). + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_add_listener, listener, question) def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" - self.record_manager.remove_listener(listener) + """Removes a listener. + + This function is threadsafe + """ + assert self.loop is not None + self.loop.call_soon_threadsafe(self.record_manager.async_remove_listener, listener) + + def async_add_listener( + self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] + ) -> None: + """Adds a listener for a given question. The listener will have + its update_record method called when information is available to + answer the question(s). + + This function is not threadsafe and must be called in the eventloop. + """ + self.record_manager.async_add_listener(listener, question) + + def async_remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is not threadsafe and must be called in the eventloop. + """ + self.record_manager.async_remove_listener(listener) def handle_response(self, msg: DNSIncoming) -> None: """Deal with incoming response packets. All answers diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 9caae3a2..711c84c9 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -387,12 +387,15 @@ def _async_mark_unique_cached_records_older_than_1s_to_expire( # Expire in 1s entry.set_created_ttl(now, 1) - def add_listener( + def async_add_listener( self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]] ) -> None: """Adds a listener for a given question. The listener will have its update_record method called when information is available to - answer the question(s).""" + answer the question(s). + + This function is not threadsafe and must be called in the eventloop. + """ self.listeners.append(listener) if question is None: @@ -400,7 +403,7 @@ def add_listener( questions = [question] if isinstance(question, DNSQuestion) else question assert self.zc.loop is not None - self.zc.loop.call_soon_threadsafe(self._async_update_matching_records, listener, questions) + self._async_update_matching_records(listener, questions) def _async_update_matching_records( self, listener: RecordUpdateListener, questions: List[DNSQuestion] @@ -422,10 +425,13 @@ def _async_update_matching_records( listener.async_update_records_complete() self.zc.async_notify_all() - def remove_listener(self, listener: RecordUpdateListener) -> None: - """Removes a listener.""" + def async_remove_listener(self, listener: RecordUpdateListener) -> None: + """Removes a listener. + + This function is not threadsafe and must be called in the eventloop. + """ try: self.listeners.remove(listener) - self.zc.notify_all() + self.zc.async_notify_all() except ValueError as e: log.exception('Failed to remove listener: %r', e) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 6a679c20..698a162a 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -242,14 +242,16 @@ def __init__( for h in handlers: self.service_state_changed.register_handler(h) - def _setup(self) -> None: + def _async_start(self) -> None: """Generate the next time and setup listeners. Must be called by uses of this base class after they have finished setting their properties. """ self._generate_first_next_time() - self.zc.add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) + # Only start queries after the listener is installed + self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) def _generate_first_next_time(self) -> None: """Generate the initial next query times. @@ -374,10 +376,10 @@ def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], Servic state_change=state_change, ) - def cancel(self) -> None: + def _async_cancel(self) -> None: """Cancel the browser.""" self.done = True - self.zc.remove_listener(self) + self.zc.async_remove_listener(self) def generate_ready_queries(self) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" @@ -454,20 +456,15 @@ def __init__( self.queue = get_best_available_queue() self.daemon = True self.start() - self._setup() - # Start queries after the listener is installed in _setup - zc.loop.call_soon_threadsafe(self._async_start_browser) + zc.loop.call_soon_threadsafe(self._async_start) self.name = "zeroconf-ServiceBrowser-%s-%s" % ( '-'.join([type_[:-7] for type_ in self.types]), getattr(self, 'native_id', self.ident), ) - def _async_start_browser(self) -> None: - """Start the browser from the event loop.""" - self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) - - def _async_cancel_browser_soon(self) -> None: + def _async_cancel_soon(self) -> None: """Cancel the browser from the event loop.""" + self._async_cancel() if self._browser_task: asyncio.ensure_future(self._async_cancel_browser()) @@ -476,8 +473,7 @@ def cancel(self) -> None: assert self.zc.loop is not None assert self.queue is not None self.queue.put(None) - self.zc.loop.call_soon_threadsafe(self._async_cancel_browser_soon) - super().cancel() + self.zc.loop.call_soon_threadsafe(self._async_cancel_soon) self.join() def run(self) -> None: diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 82e8b983..d8268d3e 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -421,7 +421,7 @@ async def async_request( last = now + timeout await zc.async_wait_for_start() try: - zc.add_listener(self, None) + zc.async_add_listener(self, None) while not self._is_complete: if last <= now: return False @@ -436,7 +436,7 @@ async def async_request( await zc.async_wait(min(next_, last) - now) now = current_time_millis() finally: - zc.remove_listener(self) + zc.async_remove_listener(self) return True diff --git a/zeroconf/aio.py b/zeroconf/aio.py index dc4f90cd..985a440b 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -22,7 +22,7 @@ import asyncio import contextlib from types import TracebackType # noqa # used in type hints -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from ._core import Zeroconf from ._dns import DNSQuestionType @@ -86,15 +86,13 @@ def __init__( question_type: Optional[DNSQuestionType] = None, ) -> None: super().__init__(zeroconf, type_, handlers, listener, addr, port, delay, question_type) - self._setup() - # Start queries after the listener is installed in _setup - self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + self._async_start() async def async_cancel(self) -> None: """Cancel the browser.""" + self._async_cancel() with contextlib.suppress(asyncio.CancelledError): await self._async_cancel_browser() - super().cancel() class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): From cb91484670ba76c8c453dc49502e89195561b31e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 08:42:16 -1000 Subject: [PATCH 428/608] Remove unused constant from zeroconf._handlers (#796) --- zeroconf/_handlers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 711c84c9..d8ec0a89 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -30,7 +30,6 @@ from ._protocol import DNSIncoming, DNSOutgoing from ._services import RecordUpdateListener from ._services.registry import ServiceRegistry -from ._utils.net import IPVersion from ._utils.time import current_time_millis from .const import ( _CLASS_IN, @@ -47,8 +46,6 @@ _TYPE_TXT, ) -_TYPE_TO_IP_VERSION = {_TYPE_A: IPVersion.V4Only, _TYPE_AAAA: IPVersion.V6Only, _TYPE_ANY: IPVersion.All} - if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 from ._core import Zeroconf # pylint: disable=cyclic-import From d637d67378698e0a505be90afbce4e2264b49444 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 08:56:41 -1000 Subject: [PATCH 429/608] Pass both the new and old records to async_update_records (#792) * Pass the old_record (cached) as the value and the new_record (wire) to async_update_records instead of forcing each consumer to check the cache since we will always have the old_record when generating the async_update_records call. This avoids the overhead of multiple cache lookups for each listener. --- tests/test_handlers.py | 10 +++- tests/test_services.py | 62 ----------------------- tests/test_updates.py | 92 ++++++++++++++++++++++++++++++++++ zeroconf/__init__.py | 2 +- zeroconf/_core.py | 7 ++- zeroconf/_handlers.py | 16 +++--- zeroconf/_services/__init__.py | 39 -------------- zeroconf/_services/browser.py | 16 +++--- zeroconf/_services/info.py | 23 +++++---- zeroconf/_updates.py | 77 ++++++++++++++++++++++++++++ 10 files changed, 210 insertions(+), 134 deletions(-) create mode 100644 tests/test_updates.py create mode 100644 zeroconf/_updates.py diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 522c71d6..53cd6de2 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -986,7 +986,7 @@ async def test_record_update_manager_add_listener_callsback_existing_records(): class MyListener(r.RecordUpdateListener): """A RecordUpdateListener that does not implement update_records.""" - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.DNSRecord]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.RecordUpdate]) -> None: """Update multiple records in one shot.""" updated.extend(records) @@ -1013,7 +1013,13 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.DNSRe ) await asyncio.sleep(0) # flush out the call_soon_threadsafe - assert set(updated) == set([ptr_record, a_record]) + assert set([record.new for record in updated]) == set([ptr_record, a_record]) + + # This behavior is probably wrong but should not be + # changed in this commit because the goal is to refactor + # only and not change how it functions + assert set([record.old for record in updated]) == set([ptr_record, a_record]) + await aiozc.async_close() diff --git a/tests/test_services.py b/tests/test_services.py index 684266f2..12ad95ba 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -14,9 +14,7 @@ import pytest import zeroconf as r -from zeroconf import const from zeroconf import Zeroconf -from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo from . import has_working_ipv6, _clear_cache @@ -193,66 +191,6 @@ def update_service(self, zeroconf, type, name): zeroconf_browser.close() -def test_legacy_record_update_listener(): - """Test a RecordUpdateListener that does not implement update_records.""" - - # instantiate a zeroconf instance - zc = Zeroconf(interfaces=['127.0.0.1']) - - with pytest.raises(RuntimeError): - r.RecordUpdateListener().update_record( - zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) - ) - - updates = [] - - class LegacyRecordUpdateListener(r.RecordUpdateListener): - """A RecordUpdateListener that does not implement update_records.""" - - def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: - nonlocal updates - updates.append(record) - - listener = LegacyRecordUpdateListener() - - zc.add_listener(listener, None) - - # dummy service callback - def on_service_state_change(zeroconf, service_type, state_change, name): - pass - - # start a browser - type_ = "_homeassistant._tcp.local." - name = "MyTestHome" - browser = ServiceBrowser(zc, type_, [on_service_state_change]) - - info_service = ServiceInfo( - type_, - '%s.%s' % (name, type_), - 80, - 0, - 0, - {'path': '/~paulsm/'}, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - - zc.register_service(info_service) - - zc.wait(1) - - browser.cancel() - - assert len(updates) - assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 - - zc.remove_listener(listener) - # Removing a second time should not throw - zc.remove_listener(listener) - - zc.close() - - def test_servicelisteners_raise_not_implemented(): """Verify service listeners raise when one of the methods is not implemented.""" diff --git a/tests/test_updates.py b/tests/test_updates.py new file mode 100644 index 00000000..1f6f8ad4 --- /dev/null +++ b/tests/test_updates.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + + +""" Unit tests for zeroconf._services. """ + +import logging +import socket +from threading import Event + +import pytest + +import zeroconf as r +from zeroconf import const +from zeroconf import Zeroconf +from zeroconf._services.browser import ServiceBrowser +from zeroconf._services.info import ServiceInfo + + +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + + +def test_legacy_record_update_listener(): + """Test a RecordUpdateListener that does not implement update_records.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + with pytest.raises(RuntimeError): + r.RecordUpdateListener().update_record( + zc, 0, r.DNSRecord('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL) + ) + + updates = [] + + class LegacyRecordUpdateListener(r.RecordUpdateListener): + """A RecordUpdateListener that does not implement update_records.""" + + def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None: + nonlocal updates + updates.append(record) + + listener = LegacyRecordUpdateListener() + + zc.add_listener(listener, None) + + # dummy service callback + def on_service_state_change(zeroconf, service_type, state_change, name): + pass + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + browser = ServiceBrowser(zc, type_, [on_service_state_change]) + + info_service = ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + zc.register_service(info_service) + + zc.wait(1) + + browser.cancel() + + assert len(updates) + assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1 + + zc.remove_listener(listener) + # Removing a second time should not throw + zc.remove_listener(listener) + + zc.close() diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e836195e..263043aa 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -49,7 +49,6 @@ from ._services import ( # noqa # import needed for backwards compat Signal, SignalRegistrationInterface, - RecordUpdateListener, ServiceListener, ServiceStateChange, ) @@ -62,6 +61,7 @@ ) from ._services.registry import ServiceRegistry # noqa # import needed for backwards compat from ._services.types import ZeroconfServiceTypes +from ._updates import RecordUpdate, RecordUpdateListener # noqa # import needed for backwards compat from ._utils.name import service_type_name # noqa # import needed for backwards compat from ._utils.net import ( # noqa # import needed for backwards compat add_multicast_member, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 8e9f3dff..2fa093b7 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -38,10 +38,11 @@ from ._history import QuestionHistory from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing -from ._services import RecordUpdateListener, ServiceListener +from ._services import ServiceListener from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry +from ._updates import RecordUpdate, RecordUpdateListener from ._utils.aio import get_running_loop, shutdown_loop, wait_event_or_timeout from ._utils.name import service_type_name from ._utils.net import ( @@ -136,7 +137,9 @@ async def _async_cache_cleanup(self) -> None: while not self.zc.done: now = current_time_millis() self.zc.question_history.async_expire(now) - self.zc.record_manager.async_updates(now, self.zc.cache.async_expire(now)) + self.zc.record_manager.async_updates( + now, [RecordUpdate(record, None) for record in self.zc.cache.async_expire(now)] + ) self.zc.record_manager.async_updates_complete() await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index d8ec0a89..d7900fe9 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -28,8 +28,8 @@ from ._history import QuestionHistory from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing -from ._services import RecordUpdateListener from ._services.registry import ServiceRegistry +from ._updates import RecordUpdate, RecordUpdateListener from ._utils.time import current_time_millis from .const import ( _CLASS_IN, @@ -283,7 +283,7 @@ def __init__(self, zeroconf: 'Zeroconf') -> None: self.cache = zeroconf.cache self.listeners: List[RecordUpdateListener] = [] - def async_updates(self, now: float, rec: List[DNSRecord]) -> None: + def async_updates(self, now: float, records: List[RecordUpdate]) -> None: """Used to notify listeners of new information that has updated a record. @@ -292,7 +292,7 @@ def async_updates(self, now: float, rec: List[DNSRecord]) -> None: This method will be run in the event loop. """ for listener in self.listeners: - listener.async_update_records(self.zc, now, rec) + listener.async_update_records(self.zc, now, records) def async_updates_complete(self) -> None: """Used to notify listeners of new information that has updated @@ -313,7 +313,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: This function must be run in the event loop as it is not threadsafe. """ - updates: List[DNSRecord] = [] + updates: List[RecordUpdate] = [] address_adds: List[DNSAddress] = [] other_adds: List[DNSRecord] = [] removes: List[DNSRecord] = [] @@ -333,11 +333,11 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: address_adds.append(record) else: other_adds.append(record) - updates.append(record) + updates.append(RecordUpdate(record, maybe_entry)) # This is likely a goodbye since the record is # expired and exists in the cache elif maybe_entry is not None: - updates.append(record) + updates.append(RecordUpdate(record, maybe_entry)) removes.append(record) if unique_types: @@ -410,11 +410,11 @@ def _async_update_matching_records( This function must be run from the event loop. """ now = current_time_millis() - records = [] + records: List[RecordUpdate] = [] for question in questions: for record in self.cache.async_entries_with_name(question.name): if not record.is_expired(now) and question.answered_by(record): - records.append(record) + records.append(RecordUpdate(record, record)) if not records: return diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 776d43a7..3759f1ec 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -23,7 +23,6 @@ import enum from typing import Any, Callable, List, TYPE_CHECKING -from .._dns import DNSRecord if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 @@ -72,41 +71,3 @@ def register_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationI def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistrationInterface': self._handlers.remove(handler) return self - - -class RecordUpdateListener: - def update_record( # pylint: disable=no-self-use - self, zc: 'Zeroconf', now: float, record: DNSRecord - ) -> None: - """Update a single record. - - This method is deprecated and will be removed in a future version. - update_records should be implemented instead. - """ - raise RuntimeError("update_record is deprecated and will be removed in a future version.") - - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: - """Update multiple records in one shot. - - All records that are received in a single packet are passed - to update_records. - - This implementation is a compatiblity shim to ensure older code - that uses RecordUpdateListener as a base class will continue to - get calls to update_record. This method will raise - NotImplementedError in a future version. - - At this point the cache will not have the new records - - This method will be run in the event loop. - """ - for record in records: - self.update_record(zc, now, record) - - def async_update_records_complete(self) -> None: - """Called when a record update has completed for all handlers. - - At this point the cache will have the new records. - - This method will be run in the event loop. - """ diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 698a162a..1a7caca8 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -29,17 +29,16 @@ from collections import OrderedDict from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast -from .._cache import _UniqueRecordsType from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord from .._logger import log from .._protocol import DNSOutgoing from .._services import ( - RecordUpdateListener, ServiceListener, ServiceStateChange, Signal, SignalRegistrationInterface, ) +from .._updates import RecordUpdate, RecordUpdateListener from .._utils.aio import get_best_available_queue from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds @@ -294,16 +293,15 @@ def _enqueue_callback( ): self._pending_handlers[key] = state_change - def _async_process_record_update(self, now: float, record: DNSRecord) -> None: + def _async_process_record_update( + self, now: float, record: DNSRecord, old_record: Optional[DNSRecord] + ) -> None: """Process a single record update from a batch of updates.""" expired = record.is_expired(now) if isinstance(record, DNSPointer): if record.name not in self.types: return - old_record = self.zc.cache.async_get_unique( - DNSPointer(record.name, _TYPE_PTR, _CLASS_IN, 0, record.alias) - ) if old_record is None: self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) elif expired: @@ -315,7 +313,7 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None: return # If its expired or already exists in the cache it cannot be updated. - if expired or self.zc.cache.async_get_unique(cast(_UniqueRecordsType, record)): + if expired or old_record: return if isinstance(record, DNSAddress): @@ -332,7 +330,7 @@ def _async_process_record_update(self, now: float, record: DNSRecord) -> None: if type_: self._enqueue_callback(ServiceStateChange.Updated, type_, record.name) - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: """Callback invoked by Zeroconf when new information arrives. Updates information required by browser in the Zeroconf cache. @@ -342,7 +340,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSReco This method will be run in the event loop. """ for record in records: - self._async_process_record_update(now, record) + self._async_process_record_update(now, record[0], record[1]) def async_update_records_complete(self) -> None: """Called when a record update has completed for all handlers. diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index d8268d3e..f7ab9e55 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -27,7 +27,7 @@ from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing -from .._services import RecordUpdateListener +from .._updates import RecordUpdate, RecordUpdateListener from .._utils.aio import get_running_loop from .._utils.name import service_type_name from .._utils.net import ( @@ -258,22 +258,22 @@ def update_record(self, zc: 'Zeroconf', now: float, record: Optional[DNSRecord]) This method will be run in the event loop. """ if record is not None: - self._process_records_threadsafe(zc, now, [record]) + self._process_records_threadsafe(zc, now, [RecordUpdate(record, None)]) - def async_update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: """Updates service information from a DNS record. This method will be run in the event loop. """ self._process_records_threadsafe(zc, now, records) - def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None: + def _process_records_threadsafe(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: """Thread safe record updating.""" update_addresses = False - for record in records: - if isinstance(record, DNSService): + for record_update in records: + if isinstance(record_update[0], DNSService): update_addresses = True - self._process_record_threadsafe(record, now) + self._process_record_threadsafe(record_update[0], now) # Only update addresses if the DNSService (.server) has changed if not update_addresses: @@ -374,17 +374,18 @@ def load_from_cache(self, zc: 'Zeroconf') -> bool: This method is designed to be threadsafe. """ now = current_time_millis() - record_updates = [] + record_updates: List[RecordUpdate] = [] cached_srv_record = zc.cache.get_by_details(self.name, _TYPE_SRV, _CLASS_IN) if cached_srv_record: # If there is a srv record, A and AAAA will already # be called and we do not want to do it twice - record_updates.append(cached_srv_record) + record_updates.append(RecordUpdate(cached_srv_record, None)) else: - record_updates.extend(self._get_address_records_from_cache(zc)) + for record in self._get_address_records_from_cache(zc): + record_updates.append(RecordUpdate(record, None)) cached_txt_record = zc.cache.get_by_details(self.name, _TYPE_TXT, _CLASS_IN) if cached_txt_record: - record_updates.append(cached_txt_record) + record_updates.append(RecordUpdate(cached_txt_record, None)) self._process_records_threadsafe(zc, now, record_updates) return self._is_complete diff --git a/zeroconf/_updates.py b/zeroconf/_updates.py new file mode 100644 index 00000000..d7ad56c1 --- /dev/null +++ b/zeroconf/_updates.py @@ -0,0 +1,77 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from typing import List, NamedTuple, Optional, TYPE_CHECKING + + +from ._dns import DNSRecord + + +if TYPE_CHECKING: + # https://github.com/PyCQA/pylint/issues/3525 + from ._core import Zeroconf # pylint: disable=cyclic-import + + +class RecordUpdate(NamedTuple): + new: DNSRecord + old: Optional[DNSRecord] + + +class RecordUpdateListener: + def update_record( # pylint: disable=no-self-use + self, zc: 'Zeroconf', now: float, record: DNSRecord + ) -> None: + """Update a single record. + + This method is deprecated and will be removed in a future version. + update_records should be implemented instead. + """ + raise RuntimeError("update_record is deprecated and will be removed in a future version.") + + def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None: + """Update multiple records in one shot. + + All records that are received in a single packet are passed + to update_records. + + This implementation is a compatiblity shim to ensure older code + that uses RecordUpdateListener as a base class will continue to + get calls to update_record. This method will raise + NotImplementedError in a future version. + + At this point the cache will not have the new records + + Records are passed as a list of RecordUpdate. This + allows consumers of async_update_records to avoid cache lookups. + + This method will be run in the event loop. + """ + for record in records: + self.update_record(zc, now, record[0]) + + def async_update_records_complete(self) -> None: + """Called when a record update has completed for all handlers. + + At this point the cache will have the new records. + + This method will be run in the event loop. + """ From c36099a41a71298d58e7afa42ecdc7a54d3b010a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 09:07:46 -1000 Subject: [PATCH 430/608] Update changelog (#797) --- README.rst | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index c0b12b15..5c375399 100644 --- a/README.rst +++ b/README.rst @@ -140,8 +140,23 @@ See examples directory for more. Changelog ========= -0.32.0 (Unreleased) -=================== +0.32.0 Beta 2 (Unreleased) +========================== + +* Pass both the new and old records to async_update_records (#792) @bdraco + + Pass the old_record (cached) as the value and the new_record (wire) + to async_update_records instead of forcing each consumer to + check the cache since we will always have the old_record + when generating the async_update_records call. This avoids + the overhead of multiple cache lookups for each listener. + +* Make add_listener and remove_listener threadsafe (#794) @bdraco + +* Ensure outgoing ServiceBrowser questions are seen by the question history (#790) @bdraco + +0.32.0 Beta 1 +============= Documentation for breaking changes era on the side of the caution and likely overstates the risk on many of these. If you are not accessing zeroconf internals, From 38e66ec5ba5fcb96cef17b8949385075807a2fb7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 09:20:14 -1000 Subject: [PATCH 431/608] Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) --- tests/test_aio.py | 52 +++++++++++++++++++++++++++++++++++++++++- tests/test_handlers.py | 7 +++--- zeroconf/_handlers.py | 2 +- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 4cb7af2e..4523ca1e 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -12,7 +12,7 @@ import pytest -from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes from zeroconf import Zeroconf from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered @@ -607,3 +607,53 @@ async def test_guard_against_running_serviceinfo_request_event_loop() -> None: service_info = AsyncServiceInfo("_hap._tcp.local.", "doesnotmatter._hap._tcp.local.") with pytest.raises(RuntimeError): service_info.request(aiozc.zeroconf, 3000) + await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_browser_instantiation_generates_add_events_from_cache(): + """Test that the ServiceBrowser will generate Add events with the existing cache when starting.""" + + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + type_ = "_hap._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + zc.cache.async_add_records( + [info.dns_pointer(), info.dns_service(), *info.dns_addresses(), info.dns_text()] + ) + + browser = AsyncServiceBrowser(zc, type_, None, listener) + + await asyncio.sleep(0) + + assert callbacks == [ + ('add', type_, registration_name), + ] + await browser.async_cancel() + + await aiozc.async_close() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 53cd6de2..64e44495 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1015,10 +1015,9 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.Recor assert set([record.new for record in updated]) == set([ptr_record, a_record]) - # This behavior is probably wrong but should not be - # changed in this commit because the goal is to refactor - # only and not change how it functions - assert set([record.old for record in updated]) == set([ptr_record, a_record]) + # The old records should be None so we trigger Add events + # in service browsers instead of Update events + assert set([record.old for record in updated]) == set([None]) await aiozc.async_close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index d7900fe9..128d8711 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -414,7 +414,7 @@ def _async_update_matching_records( for question in questions: for record in self.cache.async_entries_with_name(question.name): if not record.is_expired(now) and question.answered_by(record): - records.append(RecordUpdate(record, record)) + records.append(RecordUpdate(record, None)) if not records: return From 9961dce598d3c6eeda68a2f874a7a50ec33f819c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 09:21:37 -1000 Subject: [PATCH 432/608] Update changelog (#798) --- README.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.rst b/README.rst index 5c375399..5751e276 100644 --- a/README.rst +++ b/README.rst @@ -143,6 +143,10 @@ Changelog 0.32.0 Beta 2 (Unreleased) ========================== +* Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) + + This is fixing ServiceBrowser missing an add when the record is already in the cache. + * Pass both the new and old records to async_update_records (#792) @bdraco Pass the old_record (cached) as the value and the new_record (wire) From bbc91241a86f3339aa27cae7b4ea2ab9d7c1f37d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 09:50:05 -1000 Subject: [PATCH 433/608] Ensure we handle threadsafe shutdown under PyPy with multiple event loops (#800) --- tests/test_core.py | 25 +++++++++++++++++++++++++ tests/utils/test_aio.py | 15 +++++++++++++++ zeroconf/_utils/aio.py | 17 +++++++++++------ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 27cfb9e7..ae397d16 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -604,3 +604,28 @@ async def test_open_close_twice_from_async() -> None: zc = Zeroconf(interfaces=['127.0.0.1']) zc.close() zc.close() + await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_multiple_sync_instances_stared_from_async_close(): + """Test we can shutdown multiple sync instances from async.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + zc2 = Zeroconf(interfaces=['127.0.0.1']) + + assert zc.loop == zc2.loop + + zc.close() + assert zc.loop.is_running() + zc2.close() + assert zc2.loop.is_running() + + zc3 = Zeroconf(interfaces=['127.0.0.1']) + assert zc3.loop == zc2.loop + + zc3.close() + assert zc3.loop.is_running() + + await asyncio.sleep(0) diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index bd402d4b..b0fa8dbc 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -6,12 +6,27 @@ import asyncio import contextlib +import unittest.mock import pytest from zeroconf._utils import aio as aioutils +@pytest.mark.asyncio +async def test_async_get_all_tasks() -> None: + """Test we can get all tasks in the event loop. + + We make sure we handle RuntimeError here as + this is not thread safe under PyPy + """ + await aioutils._async_get_all_tasks(aioutils.get_running_loop()) + if not hasattr(asyncio, 'all_tasks'): + return + with unittest.mock.patch("zeroconf._utils.aio.asyncio.all_tasks", side_effect=RuntimeError): + await aioutils._async_get_all_tasks(aioutils.get_running_loop()) + + @pytest.mark.asyncio async def test_get_running_loop_from_async() -> None: """Test we can get the event loop.""" diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index fcf71c34..0b6d8dba 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -62,13 +62,18 @@ def _handle_wait_complete(_: asyncio.Task) -> None: await event_wait -async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]: +async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]: """Return all tasks running.""" await asyncio.sleep(0) # flush out any call_soon_threadsafe - # Make a copy of the tasks in case they change during iteration - if hasattr(asyncio, 'all_tasks'): - return list(asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member - return list(asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member + # If there are multiple event loops running, all_tasks is not + # safe EVEN WHEN CALLED FROM THE EVENTLOOP + # under PyPy so we have to try a few times. + for _ in range(3): + with contextlib.suppress(RuntimeError): + if hasattr(asyncio, 'all_tasks'): + return asyncio.all_tasks(loop) # type: ignore # pylint: disable=no-member + return asyncio.Task.all_tasks(loop) # type: ignore # pylint: disable=no-member + return [] async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: @@ -78,7 +83,7 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" - pending_tasks = set(asyncio.run_coroutine_threadsafe(_get_all_tasks(loop), loop).result()) + pending_tasks = set(asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result()) done_tasks = set(task for task in pending_tasks if not task.done()) pending_tasks -= done_tasks if pending_tasks: From 662ed6166282b9b5b6e83a596c0576a57f8962d2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 09:51:12 -1000 Subject: [PATCH 434/608] Update changelog (#801) --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index 5751e276..bf8f565b 100644 --- a/README.rst +++ b/README.rst @@ -143,6 +143,8 @@ Changelog 0.32.0 Beta 2 (Unreleased) ========================== +* Ensure we handle threadsafe shutdown under PyPy with multiple event loops (#800) @bdraco + * Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) This is fixing ServiceBrowser missing an add when the record is already in the cache. From 58ae3cf553cd925ac90f3db551f4085ea5bc8b79 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 10:04:33 -1000 Subject: [PATCH 435/608] Update changelog (#802) --- README.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index bf8f565b..a4ece3b1 100644 --- a/README.rst +++ b/README.rst @@ -140,12 +140,12 @@ See examples directory for more. Changelog ========= -0.32.0 Beta 2 (Unreleased) -========================== +0.32.0 Beta 2 +============= * Ensure we handle threadsafe shutdown under PyPy with multiple event loops (#800) @bdraco -* Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) +* Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) @bdraco This is fixing ServiceBrowser missing an add when the record is already in the cache. From 18fe341300e28ed93d7b5d7ca8e07edb119bd597 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 11:48:53 -1000 Subject: [PATCH 436/608] Add slots to DNS classes (#803) - On a busy network that receives many mDNS packets per second, we will not know the answer to most of the questions being asked. In this case the creating the DNS* objects are usually garbage collected within 1s as they are not needed. We now set __slots__ to speed up the creation and destruction of these objects --- zeroconf/_dns.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 2c0e3338..af3057bf 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -71,6 +71,8 @@ class DNSEntry: """A DNS entry""" + __slots__ = ('key', 'name', 'type', 'class_', 'unique') + def __init__(self, name: str, type_: int, class_: int) -> None: self.key = name.lower() self.name = name @@ -156,6 +158,8 @@ class DNSRecord(DNSEntry): """A DNS record - like a DNS entry, but has a TTL""" + __slots__ = ('ttl', 'created', '_expiration_time', '_stale_time', '_recent_time') + # TODO: Switch to just int ttl def __init__( self, name: str, type_: int, class_: int, ttl: Union[float, int], created: Optional[float] = None @@ -238,6 +242,8 @@ class DNSAddress(DNSRecord): """A DNS address record""" + __slots__ = ('address',) + def __init__( self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None ) -> None: @@ -274,6 +280,8 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" + __slots__ = ('cpu', 'os') + def __init__( self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None ) -> None: @@ -308,6 +316,8 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" + __slots__ = ('alias',) + def __init__( self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None ) -> None: @@ -345,6 +355,8 @@ class DNSText(DNSRecord): """A DNS text record""" + __slots__ = ('text',) + def __init__( self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None ) -> None: @@ -375,6 +387,8 @@ class DNSService(DNSRecord): """A DNS service record""" + __slots__ = ('priority', 'weight', 'port', 'server') + def __init__( self, name: str, @@ -423,6 +437,8 @@ def __repr__(self) -> str: class DNSRRSet: """A set of dns records independent of the ttl.""" + __slots__ = ('_records', '_lookup') + def __init__(self, records: Iterable[DNSRecord]) -> None: """Create an RRset from records.""" self._records = records From df66da2a943b9ff978602680b746f1edeba048dc Mon Sep 17 00:00:00 2001 From: ZLJasonG <36852337+ZLJasonG@users.noreply.github.com> Date: Tue, 22 Jun 2021 01:16:06 +0100 Subject: [PATCH 437/608] Skip network adapters that are disconnected (#327) Co-authored-by: J. Nick Koston --- zeroconf/_utils/net.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 19500e0f..80a4377b 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -268,6 +268,13 @@ def add_multicast_member( if _errno in err_einval: log.info('Interface of %s does not support multicast, ' 'it is expected in WSL', interface) return False + if _errno == errno.ENOPROTOOPT: + log.info( + 'Failed to set socket option on %s, this can happen if ' + 'the network adapter is in a disconnected state', + interface, + ) + return False raise return True From 59e4bd25347aac254700dc3a1518676042982b3a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 14:17:49 -1000 Subject: [PATCH 438/608] Update changelog (#804) --- README.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.rst b/README.rst index a4ece3b1..ef44bcec 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,20 @@ See examples directory for more. Changelog ========= + +0.32.0 Beta 3 (Unreleased) +========================== + +* Skip network adapters that are disconnected (#327) @ZLJasonG + +* Add slots to DNS classes (#803) @bdraco + + On a busy network that receives many mDNS packets per second, we + will not know the answer to most of the questions being asked. + In this case the creating the DNS* objects are usually garbage + collected within 1s as they are not needed. We now set __slots__ + to speed up the creation and destruction of these objects + 0.32.0 Beta 2 ============= From 5dccf3496a9bd4c268da4c39aab545ddcd50ac57 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 21 Jun 2021 14:21:44 -1000 Subject: [PATCH 439/608] Tag 0.32.0b3 (#805) --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index ef44bcec..02051578 100644 --- a/README.rst +++ b/README.rst @@ -141,8 +141,8 @@ Changelog ========= -0.32.0 Beta 3 (Unreleased) -========================== +0.32.0 Beta 3 +============= * Skip network adapters that are disconnected (#327) @ZLJasonG From 05bb21b9b43f171e30b48fad6a756df49162b557 Mon Sep 17 00:00:00 2001 From: ibygrave Date: Tue, 22 Jun 2021 18:01:10 +0100 Subject: [PATCH 440/608] Qualify IPv6 link-local addresses with scope_id (#343) Co-authored-by: Lokesh Prajapati Co-authored-by: de Angelis, Antonio When a service is advertised on an IPv6 address where the scope is link local, i.e. fe80::/64 (see RFC 4007) the resolved IPv6 address must be extended with the scope_id that identifies through the "%" symbol the local interface to be used when routing to that address. A new API `parsed_scoped_addresses()` is provided to return qualified addresses to avoid breaking compatibility on the existing parsed_addresses(). --- examples/async_browser.py | 2 +- examples/browser.py | 3 ++- tests/services/test_browser.py | 8 +++--- tests/services/test_info.py | 41 +++++++++++++++++++++++------- zeroconf/_core.py | 46 ++++++++++++++++++++++++++-------- zeroconf/_dns.py | 20 ++++++++++++--- zeroconf/_protocol.py | 9 ++++--- zeroconf/_services/info.py | 27 ++++++++++++++++++-- 8 files changed, 121 insertions(+), 35 deletions(-) diff --git a/examples/async_browser.py b/examples/async_browser.py index b835307c..85192e14 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -28,7 +28,7 @@ async def async_display_service_info(zeroconf: Zeroconf, service_type: str, name await info.async_request(zeroconf, 3000) print("Info from zeroconf.get_service_info: %r" % (info)) if info: - addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] print(" Name: %s" % name) print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) diff --git a/examples/browser.py b/examples/browser.py index 2f264439..8525e9b9 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -21,8 +21,9 @@ def on_service_state_change( if state_change is ServiceStateChange.Added: info = zeroconf.get_service_info(service_type, name) print("Info from zeroconf.get_service_info: %r" % (info)) + if info: - addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] + addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) print(" Server: %s" % (info.server,)) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 331e1f0d..e95a7b5f 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -447,10 +447,10 @@ def current_time_millis(): """Current system time in milliseconds""" return start_time + time_offset * 1000 - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): """Sends an outgoing packet.""" got_query.set() - old_send(out, addr=addr, port=port) + old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) # patch the zeroconf send # patch the zeroconf current_time_millis @@ -674,7 +674,7 @@ def current_time_millis(): expected_ttl = const._DNS_HOST_TTL nbr_answers = 0 - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): """Sends an outgoing packet.""" pout = r.DNSIncoming(out.packets()[0]) nonlocal nbr_answers @@ -686,7 +686,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): got_query.set() got_query.clear() - old_send(out, addr=addr, port=port) + old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) # patch the zeroconf send # patch the zeroconf current_time_millis diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 438ad819..d9ca43a7 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -201,6 +201,8 @@ def test_get_info_partial(self): service_server = 'ash-1.local.' service_text = b'path=/~matt1/' service_address = '10.0.1.2' + service_address_v6_ll = 'fe80::52e:c2f2:bc5f:e9c6' + service_scope_id = 12 service_info = None send_event = Event() @@ -208,7 +210,7 @@ def test_get_info_partial(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): """Sends an outgoing packet.""" nonlocal last_sent @@ -316,7 +318,15 @@ def get_service_info_helper(zc, type, name): const._CLASS_IN | const._CLASS_UNIQUE, ttl, socket.inet_pton(socket.AF_INET, service_address), - ) + ), + r.DNSAddress( + service_server, + const._TYPE_AAAA, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + socket.inet_pton(socket.AF_INET6, service_address_v6_ll), + scope_id=service_scope_id, + ), ] ), ) @@ -345,7 +355,7 @@ def test_get_info_single(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): """Sends an outgoing packet.""" nonlocal last_sent @@ -442,6 +452,8 @@ def test_multiple_addresses(): info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address, address]) assert info.addresses == [address, address] + assert info.parsed_addresses() == [address_parsed, address_parsed] + assert info.parsed_scoped_addresses() == [address_parsed, address_parsed] info = ServiceInfo( type_, @@ -454,10 +466,16 @@ def test_multiple_addresses(): parsed_addresses=[address_parsed, address_parsed], ) assert info.addresses == [address, address] + assert info.parsed_addresses() == [address_parsed, address_parsed] + assert info.parsed_scoped_addresses() == [address_parsed, address_parsed] if has_working_ipv6() and not os.environ.get('SKIP_IPV6'): address_v6_parsed = "2001:db8::1" address_v6 = socket.inet_pton(socket.AF_INET6, address_v6_parsed) + address_v6_ll_parsed = "fe80::52e:c2f2:bc5f:e9c6" + address_v6_ll_scoped_parsed = "fe80::52e:c2f2:bc5f:e9c6%12" + address_v6_ll = socket.inet_pton(socket.AF_INET6, address_v6_ll_parsed) + interface_index = 12 infos = [ ServiceInfo( type_, @@ -467,7 +485,8 @@ def test_multiple_addresses(): 0, desc, "ash-2.local.", - addresses=[address, address_v6], + addresses=[address, address_v6, address_v6_ll], + interface_index=interface_index, ), ServiceInfo( type_, @@ -477,17 +496,21 @@ def test_multiple_addresses(): 0, desc, "ash-2.local.", - parsed_addresses=[address_parsed, address_v6_parsed], + parsed_addresses=[address_parsed, address_v6_parsed, address_v6_ll_parsed], + interface_index=interface_index, ), ] for info in infos: assert info.addresses == [address] - assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6] + assert info.addresses_by_version(r.IPVersion.All) == [address, address_v6, address_v6_ll] assert info.addresses_by_version(r.IPVersion.V4Only) == [address] - assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6] - assert info.parsed_addresses() == [address_parsed, address_v6_parsed] + assert info.addresses_by_version(r.IPVersion.V6Only) == [address_v6, address_v6_ll] + assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed] assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed] + assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed] + assert info.parsed_scoped_addresses() == [address_v6_ll_scoped_parsed, address_parsed, address_v6_parsed] + assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed] + assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [address_v6_ll_scoped_parsed, address_v6_parsed] # This test uses asyncio because it needs to access the cache directly diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 2fa093b7..a760f404 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -191,12 +191,16 @@ def datagram_received( self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] ) -> None: assert self.transport is not None + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () if len(addrs) == 2: # https://github.com/python/mypy/issues/1178 addr, port = addrs # type: ignore + scope = None elif len(addrs) == 4: # https://github.com/python/mypy/issues/1178 - addr, port, _flow, _scope = addrs # type: ignore + addr, port, flow, scope = addrs # type: ignore + log.debug('IPv6 scope_id %d associated to the receiving interface', scope) + v6_flow_scope = (flow, scope) else: return @@ -212,7 +216,7 @@ def datagram_received( return self.data = data - msg = DNSIncoming(data) + msg = DNSIncoming(data, scope) if msg.valid: log.debug( 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', @@ -238,7 +242,7 @@ def datagram_received( self.zc.handle_response(msg) return - self.zc.handle_query(msg, addr, port) + self.zc.handle_query(msg, addr, port, v6_flow_scope) def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" @@ -589,7 +593,9 @@ def handle_response(self, msg: DNSIncoming) -> None: are held in the cache, and listeners are notified.""" self.record_manager.async_updates_from_response(msg) - def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: + def handle_query( + self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + ) -> None: """Deal with incoming query packets. Provides a response if possible.""" if not msg.truncated: @@ -606,9 +612,15 @@ def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None: assert self.loop is not None if addr in self._timers: self._timers.pop(addr).cancel() - self._timers[addr] = self.loop.call_later(delay, self._respond_query, None, addr, port) + self._timers[addr] = self.loop.call_later(delay, self._respond_query, None, addr, port, v6_flow_scope) - def _respond_query(self, msg: Optional[DNSIncoming], addr: str, port: int) -> None: + def _respond_query( + self, + msg: Optional[DNSIncoming], + addr: str, + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: """Respond to a query and reassemble any truncated deferred packets.""" if addr in self._timers: self._timers.pop(addr).cancel() @@ -618,16 +630,28 @@ def _respond_query(self, msg: Optional[DNSIncoming], addr: str, port: int) -> No unicast_out, multicast_out = self.query_handler.async_response(packets, addr, port) if unicast_out: - self.async_send(unicast_out, addr, port) + self.async_send(unicast_out, addr, port, v6_flow_scope) if multicast_out: self.async_send(multicast_out, None, _MDNS_PORT) - def send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + def send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: """Sends an outgoing packet threadsafe.""" assert self.loop is not None - self.loop.call_soon_threadsafe(self.async_send, out, addr, port) + self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope) - def async_send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _MDNS_PORT) -> None: + def async_send( + self, + out: DNSOutgoing, + addr: Optional[str] = None, + port: int = _MDNS_PORT, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: """Sends an outgoing packet.""" for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: @@ -653,7 +677,7 @@ def async_send(self, out: DNSOutgoing, addr: Optional[str] = None, port: int = _ continue else: real_addr = addr - transport.sendto(packet, (real_addr, port or _MDNS_PORT)) + transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) except OSError as exc: if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: # with IPv6 we don't have a reliable way to determine if an interface actually has diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index af3057bf..dbe009d2 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -242,13 +242,22 @@ class DNSAddress(DNSRecord): """A DNS address record""" - __slots__ = ('address',) + __slots__ = ('address', 'scope_id') def __init__( - self, name: str, type_: int, class_: int, ttl: int, address: bytes, created: Optional[float] = None + self, + name: str, + type_: int, + class_: int, + ttl: int, + address: bytes, + *, + scope_id: Optional[int] = None, + created: Optional[float] = None, ) -> None: super().__init__(name, type_, class_, ttl, created) self.address = address + self.scope_id = scope_id def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -257,12 +266,15 @@ def write(self, out: 'DNSOutgoing') -> None: def __eq__(self, other: Any) -> bool: """Tests equality on address""" return ( - isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address + isinstance(other, DNSAddress) + and DNSEntry.__eq__(self, other) + and self.address == other.address + and self.scope_id == other.scope_id ) def __hash__(self) -> int: """Hash to compare like DNSAddresses.""" - return hash((*self._entry_tuple(), self.address)) + return hash((*self._entry_tuple(), self.address, self.scope_id)) def __repr__(self) -> str: """String representation""" diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 53a7b710..50bbca28 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -79,7 +79,7 @@ class DNSIncoming(DNSMessage, QuietLogger): """Object representation of an incoming DNS packet""" - def __init__(self, data: bytes) -> None: + def __init__(self, data: bytes, scope_id: Optional[int] = None) -> None: """Constructor from string holding bytes of packet""" super().__init__(0) self.offset = 0 @@ -93,6 +93,7 @@ def __init__(self, data: bytes) -> None: self.num_additionals = 0 self.valid = False self.now = current_time_millis() + self.scope_id = scope_id try: self.read_header() @@ -169,7 +170,7 @@ def read_others(self) -> None: type_, class_, ttl, length = self.unpack(b'!HHiH') rec: Optional[DNSRecord] = None if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), self.now) + rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) elif type_ in (_TYPE_CNAME, _TYPE_PTR): rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) elif type_ == _TYPE_TXT: @@ -197,7 +198,9 @@ def read_others(self) -> None: self.now, ) elif type_ == _TYPE_AAAA: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(16), self.now) + rec = DNSAddress( + domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id + ) else: # Try to ignore types we don't know about # Skip the payload for the resource record so the next diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index f7ab9e55..4365d6ef 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -21,8 +21,9 @@ """ import asyncio +import ipaddress import socket -from typing import Dict, List, Optional, TYPE_CHECKING, Union, cast +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException @@ -85,6 +86,8 @@ class ServiceInfo(RecordUpdateListener): * `other_ttl`: ttl used for PTR/TXT records * `addresses` and `parsed_addresses`: List of IP addresses (either as bytes, network byte order, or in parsed form as text; at most one of those parameters can be provided) + * interface_index: scope_id or zone_id for IPv6 link-local addresses i.e. an identifier of the interface + where the peer is connected to """ text = b'' @@ -103,6 +106,7 @@ def __init__( *, addresses: Optional[List[bytes]] = None, parsed_addresses: Optional[List[str]] = None, + interface_index: Optional[int] = None, ) -> None: # Accept both none, or one, but not both. if addresses is not None and parsed_addresses is not None: @@ -137,6 +141,7 @@ def __init__( self._set_properties(properties) self.host_ttl = host_ttl self.other_ttl = other_ttl + self.interface_index = interface_index @property def name(self) -> str: @@ -194,6 +199,21 @@ def parsed_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: for addr in result ] + def parsed_scoped_addresses(self, version: IPVersion = IPVersion.All) -> List[str]: + """Equivalent to parsed_addresses, with the exception that IPv6 Link-Local + addresses are qualified with % when available + """ + if self.interface_index is None: + return self.parsed_addresses(version) + + def is_link_local(addr_str: str) -> Any: + addr = ipaddress.ip_address(addr_str) + return addr.version == 6 and addr.is_link_local + + ll_addrs = list(filter(is_link_local, self.parsed_addresses(version))) + other_addrs = list(filter(lambda addr: not is_link_local(addr), self.parsed_addresses(version))) + return ["{}%{}".format(addr, self.interface_index) for addr in ll_addrs] + other_addrs + def _set_properties(self, properties: Dict) -> None: """Sets properties and text of this info from a dictionary""" self._properties = properties @@ -289,6 +309,8 @@ def _process_record_threadsafe(self, record: DNSRecord, now: float) -> None: if isinstance(record, DNSAddress): if record.key == self.server_key and record.address not in self._addresses: self._addresses.append(record.address) + if record.type is _TYPE_AAAA and ipaddress.IPv6Address(record.address).is_link_local: + self.interface_index = record.scope_id return if isinstance(record, DNSService): @@ -320,7 +342,7 @@ def dns_addresses( _CLASS_IN | _CLASS_UNIQUE, override_ttl if override_ttl is not None else self.host_ttl, address, - created, + created=created, ) for address in self.addresses_by_version(version) ] @@ -474,6 +496,7 @@ def __repr__(self) -> str: 'priority', 'server', 'properties', + 'interface_index', ) ), ) From 0129ac061db4a950f7bddf1084309e44aaabdbdf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 07:28:40 -1000 Subject: [PATCH 441/608] Format tests/services/test_info.py with newer black (#809) --- tests/services/test_info.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index d9ca43a7..e55f03ce 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -508,9 +508,16 @@ def test_multiple_addresses(): assert info.parsed_addresses() == [address_parsed, address_v6_parsed, address_v6_ll_parsed] assert info.parsed_addresses(r.IPVersion.V4Only) == [address_parsed] assert info.parsed_addresses(r.IPVersion.V6Only) == [address_v6_parsed, address_v6_ll_parsed] - assert info.parsed_scoped_addresses() == [address_v6_ll_scoped_parsed, address_parsed, address_v6_parsed] + assert info.parsed_scoped_addresses() == [ + address_v6_ll_scoped_parsed, + address_parsed, + address_v6_parsed, + ] assert info.parsed_scoped_addresses(r.IPVersion.V4Only) == [address_parsed] - assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [address_v6_ll_scoped_parsed, address_v6_parsed] + assert info.parsed_scoped_addresses(r.IPVersion.V6Only) == [ + address_v6_ll_scoped_parsed, + address_v6_parsed, + ] # This test uses asyncio because it needs to access the cache directly From f9bbbce388f2c6c24109c15ef843c10eeccf008f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 07:46:10 -1000 Subject: [PATCH 442/608] Make DNSHinfo and DNSAddress use the same match order as DNSPointer and DNSText (#808) We want to check the data that is most likely to be unique first so we can reject the __eq__ as soon as possible. --- zeroconf/_dns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index dbe009d2..93db9859 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -267,9 +267,9 @@ def __eq__(self, other: Any) -> bool: """Tests equality on address""" return ( isinstance(other, DNSAddress) - and DNSEntry.__eq__(self, other) and self.address == other.address and self.scope_id == other.scope_id + and DNSEntry.__eq__(self, other) ) def __hash__(self) -> int: @@ -310,9 +310,9 @@ def __eq__(self, other: Any) -> bool: """Tests equality on cpu and os""" return ( isinstance(other, DNSHinfo) - and DNSEntry.__eq__(self, other) and self.cpu == other.cpu and self.os == other.os + and DNSEntry.__eq__(self, other) ) def __hash__(self) -> int: From d4c8f0d3ffdcdc609810aca383492a57f9e1a723 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 07:53:40 -1000 Subject: [PATCH 443/608] Simplify wait_event_or_timeout (#810) - This function always did the same thing on timeout and wait complete so we can use the same callback. This solves the CI failing due to the test coverage flapping back and forth as the timeout would rarely happen. --- zeroconf/_utils/aio.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index 0b6d8dba..7cc3b7fa 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -23,7 +23,7 @@ import asyncio import contextlib import queue -from typing import List, Optional, Set, cast +from typing import Any, List, Optional, Set, cast def get_best_available_queue() -> queue.Queue: @@ -39,18 +39,13 @@ async def wait_event_or_timeout(event: asyncio.Event, timeout: float) -> None: loop = asyncio.get_event_loop() future = loop.create_future() - def _handle_timeout() -> None: + def _handle_timeout_or_wait_complete(*_: Any) -> None: if not future.done(): future.set_result(None) - timer_handle = loop.call_later(timeout, _handle_timeout) + timer_handle = loop.call_later(timeout, _handle_timeout_or_wait_complete) event_wait = loop.create_task(event.wait()) - - def _handle_wait_complete(_: asyncio.Task) -> None: - if not future.done(): - future.set_result(None) - - event_wait.add_done_callback(_handle_wait_complete) + event_wait.add_done_callback(_handle_timeout_or_wait_complete) try: await future From 13c558cf3f40e52a13347a39b050e49a9241c269 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 07:56:36 -1000 Subject: [PATCH 444/608] Update changelog (#811) --- README.rst | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/README.rst b/README.rst index 02051578..1fd39c7f 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,31 @@ See examples directory for more. Changelog ========= +0.32.0 Beta 4 +============= + +* Simplify wait_event_or_timeout (#810) @bdraco + + This function always did the same thing on timeout and + wait complete so we can use the same callback. This + solves the CI failing due to the test coverage flapping + back and forth as the timeout would rarely happen. + +* Make DNSHinfo and DNSAddress use the same match order as DNSPointer and DNSText (#808) @bdraco + + We want to check the data that is most likely to be unique first + so we can reject the __eq__ as soon as possible. + +* Qualify IPv6 link-local addresses with scope_id (#343) @ibygrave + + When a service is advertised on an IPv6 address where + the scope is link local, i.e. fe80::/64 (see RFC 4007) + the resolved IPv6 address must be extended with the + scope_id that identifies through the "%" symbol the + local interface to be used when routing to that address. + A new API `parsed_scoped_addresses()` is provided to + return qualified addresses to avoid breaking compatibility + on the existing parsed_addresses(). 0.32.0 Beta 3 ============= From e32bb5d98be0dc7ed130224206a4de699bcd68e3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 13:24:56 -1000 Subject: [PATCH 445/608] New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) --- tests/services/test_browser.py | 19 ++++++++++++------- zeroconf/_services/browser.py | 26 +++++++++++++++++++------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index e95a7b5f..b78e0c62 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -544,8 +544,8 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser.close() -def test_asking_default_is_asking_qm_questions(): - """Verify the service browser can ask QU questions.""" +def test_asking_default_is_asking_qm_questions_after_the_first_qu(): + """Verify the service browser's first question is QU and subsequent ones are QM questions.""" type_ = "_quservice._tcp.local." zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) @@ -553,10 +553,14 @@ def test_asking_default_is_asking_qm_questions(): old_send = zeroconf_browser.async_send first_outgoing = None + second_outgoing = None def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" nonlocal first_outgoing + nonlocal second_outgoing + if first_outgoing is not None and second_outgoing is None: + second_outgoing = out if first_outgoing is None: first_outgoing = out old_send(out, addr=addr, port=port) @@ -567,10 +571,11 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): def on_service_state_change(zeroconf, service_type, state_change, name): pass - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 5)) + browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change], delay=5) + time.sleep(millis_to_seconds(_services_browser._FIRST_QUERY_DELAY_RANDOM_INTERVAL[1] + 120 + 5)) try: - assert first_outgoing.questions[0].unicast == False + assert first_outgoing.questions[0].unicast == True + assert second_outgoing.questions[0].unicast == False finally: browser.cancel() zeroconf_browser.close() @@ -1016,7 +1021,7 @@ async def test_generate_service_query_suppress_duplicate_questions(): aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) zc = aiozc.zeroconf now = current_time_millis() - name = "_hap._tcp.local." + name = "_suppresstest._tcp.local." question = r.DNSQuestion(name, const._TYPE_PTR, const._CLASS_IN) answer = r.DNSPointer( name, @@ -1048,7 +1053,7 @@ async def test_generate_service_query_suppress_duplicate_questions(): outs = _services_browser.generate_service_query(zc, now, [name], multicast=False) assert outs - zc.question_history.async_expire(now + 1000) + zc.question_history.async_expire(now + 2000) # No suppression after clearing the history outs = _services_browser.generate_service_query(zc, now, [name], multicast=True) assert outs diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 1a7caca8..40a80df3 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -146,11 +146,14 @@ def generate_service_query( for record in zc.cache.get_all_by_details(type_, _TYPE_PTR, _CLASS_IN) if not record.is_stale(now) ) - if multicast and zc.question_history.suppresses(question, now, cast(Set[DNSRecord], known_answers)): + if not qu_question and zc.question_history.suppresses( + question, now, cast(Set[DNSRecord], known_answers) + ): log.debug("Asking %s was suppressed by the question history", question) continue questions_with_known_answers[question] = known_answers - zc.question_history.add_question_at_time(question, now, cast(Set[DNSRecord], known_answers)) + if not qu_question: + zc.question_history.add_question_at_time(question, now, cast(Set[DNSRecord], known_answers)) return _group_ptr_queries_with_known_answers(now, multicast, questions_with_known_answers) @@ -379,7 +382,7 @@ def _async_cancel(self) -> None: self.done = True self.zc.async_remove_listener(self) - def generate_ready_queries(self) -> List[DNSOutgoing]: + def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" now = current_time_millis() if self._millis_to_wait(current_time_millis()): @@ -395,7 +398,13 @@ def generate_ready_queries(self) -> List[DNSOutgoing]: self._next_time[type_] = now + self._delay[type_] self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - return generate_service_query(self.zc, now, ready_types, self.multicast, self.question_type) + # If they did not specify and this is the first request, ask QU questions + # https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are + # just starting up and we know our cache is likely empty. This ensures + # the next outgoing will be sent with the known answers list. + question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type + + return generate_service_query(self.zc, now, ready_types, self.multicast, question_type) def _millis_to_wait(self, now: float) -> Optional[float]: """Returns the number of milliseconds to wait for the next event.""" @@ -406,14 +415,17 @@ def _millis_to_wait(self, now: float) -> Optional[float]: async def async_browser_task(self) -> None: """Run the browser task.""" await self.zc.async_wait_for_start() + first_request = True while True: timeout = self._millis_to_wait(current_time_millis()) if timeout: await self.zc.async_wait(timeout) - outs = self.generate_ready_queries() - for out in outs: - self.zc.async_send(out, addr=self.addr, port=self.port) + outs = self._generate_ready_queries(first_request) + if outs: + first_request = False + for out in outs: + self.zc.async_send(out, addr=self.addr, port=self.port) async def _async_cancel_browser(self) -> None: """Cancel the browser.""" From ffd2532f72a59ede86732b310512774b8fa344e7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 13:43:20 -1000 Subject: [PATCH 446/608] Turn on logging in the types test (#816) - Will be needed to track down #813 --- tests/services/test_types.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index ba355bae..7a07085f 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -4,17 +4,31 @@ """Unit tests for zeroconf._services.types.""" +import logging import os import unittest import socket import sys -import time import zeroconf as r from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes from .. import _clear_cache, has_working_ipv6 +log = logging.getLogger('zeroconf') +original_logging_level = logging.NOTSET + + +def setup_module(): + global original_logging_level + original_logging_level = log.level + log.setLevel(logging.DEBUG) + + +def teardown_module(): + if original_logging_level != logging.NOTSET: + log.setLevel(original_logging_level) + class ServiceTypesQuery(unittest.TestCase): def test_integration_with_listener(self): From f9d35299a39fee0b1632a3b2ac00170f761d53b1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 13:58:10 -1000 Subject: [PATCH 447/608] Fix default v6_flow_scope argument with tests that mock send (#819) --- tests/services/test_browser.py | 4 ++-- tests/services/test_info.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index b78e0c62..688ad18b 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -447,7 +447,7 @@ def current_time_millis(): """Current system time in milliseconds""" return start_time + time_offset * 1000 - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): """Sends an outgoing packet.""" got_query.set() old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) @@ -679,7 +679,7 @@ def current_time_millis(): expected_ttl = const._DNS_HOST_TTL nbr_answers = 0 - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): """Sends an outgoing packet.""" pout = r.DNSIncoming(out.packets()[0]) nonlocal nbr_answers diff --git a/tests/services/test_info.py b/tests/services/test_info.py index e55f03ce..348d623d 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -210,7 +210,7 @@ def test_get_info_partial(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): """Sends an outgoing packet.""" nonlocal last_sent @@ -355,7 +355,7 @@ def test_get_info_single(self): last_sent = None # type: Optional[r.DNSOutgoing] - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=None): + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): """Sends an outgoing packet.""" nonlocal last_sent From a7b4f8e070de69db1ed872e2ff7a953ec624394c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 16:48:02 -1000 Subject: [PATCH 448/608] Fix reliablity of tests that patch sending (#820) --- tests/__init__.py | 6 ++++++ tests/services/test_browser.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/__init__.py b/tests/__init__.py index 0e7aa930..d77140fd 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -41,6 +41,12 @@ async def _wait_for_response(): asyncio.run_coroutine_threadsafe(_wait_for_response(), zc.loop).result() +def _wait_for_start(zc: Zeroconf) -> None: + """Wait for all sockets to be up and running.""" + assert zc.loop is not None + asyncio.run_coroutine_threadsafe(zc.async_wait_for_start(), zc.loop).result() + + @lru_cache(maxsize=None) def has_working_ipv6(): """Return True if if the system can bind an IPv6 address.""" diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 688ad18b..35fe487f 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -22,7 +22,7 @@ from zeroconf._services.info import ServiceInfo from zeroconf.aio import AsyncZeroconf -from .. import has_working_ipv6, _inject_response +from .. import has_working_ipv6, _inject_response, _wait_for_start log = logging.getLogger('zeroconf') @@ -435,6 +435,7 @@ def test_backoff(suppresses_mock): type_ = "_http._tcp.local." zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zeroconf_browser) # we are going to patch the zeroconf send to check query transmission old_send = zeroconf_browser.async_send @@ -513,6 +514,7 @@ def test_first_query_delay(): """ type_ = "_http._tcp.local." zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zeroconf_browser) # we are going to patch the zeroconf send to check query transmission old_send = zeroconf_browser.async_send @@ -666,6 +668,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): service_removed.set() zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zeroconf_browser) # we are going to patch the zeroconf send to check packet sizes old_send = zeroconf_browser.async_send From 4062fe21d8baaad36960f8cae0f59ac7083a6b55 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 17:06:49 -1000 Subject: [PATCH 449/608] Only wake up the query loop when there is a change in the next query time (#818) The ServiceBrowser query loop (async_browser_task) was being awoken on every packet because it was using `zeroconf.async_wait` which wakes up on every new packet. We only need to awaken the loop when the next time we are going to send a query has changed. fixes #814 fixes #768 --- tests/services/test_browser.py | 97 ++++++++++++++++++++-- zeroconf/_services/browser.py | 145 ++++++++++++++++++++++----------- 2 files changed, 187 insertions(+), 55 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 35fe487f..eac26a4c 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._services.browser. """ +import asyncio import logging import socket import time @@ -476,13 +477,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): expected_query_time = 0.0 while True: sleep_count += 1 - for _ in range(2): - # If the browser thread is starting up - # its possible we notify before the initial sleep - # which means the test will fail so we need to d - # this twice to eliminate the race condition - zeroconf_browser.notify_all() - got_query.wait(0.05) + got_query.wait(0.1) if time_offset == expected_query_time: assert got_query.is_set() got_query.clear() @@ -501,6 +496,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): else: assert not got_query.is_set() time_offset += initial_query_interval + zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed) finally: browser.cancel() @@ -726,7 +722,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): while nbr_answers < test_iterations: # Increase simulated time shift by 1/4 of the TTL in seconds time_offset += expected_ttl / 4 - zeroconf_browser.notify_all() + zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed) sleep_count += 1 got_query.wait(0.5) # Prevent the test running indefinitely in an error condition @@ -1067,3 +1063,88 @@ async def test_generate_service_query_suppress_duplicate_questions(): assert not outs await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_query_scheduler(): + delay = const._BROWSER_TIME + types_ = set(["_hap._tcp.local.", "_http._tcp.local."]) + query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0)) + + now = current_time_millis() + query_scheduler.start(now) + + # Test query interval is increasing + assert query_scheduler.millis_to_wait(now - 1) == 1 + assert query_scheduler.millis_to_wait(now) is None + assert query_scheduler.millis_to_wait(now + 1) is None + + assert set(query_scheduler.process_ready_types(now)) == types_ + assert set(query_scheduler.process_ready_types(now)) == set() + assert query_scheduler.millis_to_wait(now) == delay + + assert set(query_scheduler.process_ready_types(now + delay)) == types_ + assert set(query_scheduler.process_ready_types(now + delay)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 3 + + assert set(query_scheduler.process_ready_types(now + delay * 3)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 3)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 7 + + assert set(query_scheduler.process_ready_types(now + delay * 7)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 7)) == set() + assert query_scheduler.millis_to_wait(now) == delay * 15 + + assert set(query_scheduler.process_ready_types(now + delay * 15)) == types_ + assert set(query_scheduler.process_ready_types(now + delay * 15)) == set() + + # Test if we reschedule 1 second later, the millis_to_wait goes up by 1 + query_scheduler.reschedule_type("_hap._tcp.local.", now + delay * 16) + assert query_scheduler.millis_to_wait(now) == delay * 16 + + assert set(query_scheduler.process_ready_types(now + delay * 15)) == set() + + # Test if we reschedule 1 second later... and its ready for processing + assert set(query_scheduler.process_ready_types(now + delay * 16)) == set(["_hap._tcp.local."]) + assert query_scheduler.millis_to_wait(now) == delay * 31 + assert set(query_scheduler.process_ready_types(now + delay * 20)) == set() + + assert set(query_scheduler.process_ready_types(now + delay * 31)) == set(["_http._tcp.local."]) + + +@pytest.mark.asyncio +async def test_query_scheduler_triggers_async_wait_ready_on_reschedule(): + """Test that a reschedule wakes up the async_wait_ready.""" + delay = const._BROWSER_TIME + types_ = set(["_hap._tcp.local.", "_http._tcp.local."]) + query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0)) + + now = current_time_millis() + query_scheduler.start(now) + assert set(query_scheduler.process_ready_types(now)) == types_ + assert query_scheduler.millis_to_wait(now) == delay + + task = asyncio.ensure_future(query_scheduler.async_wait_ready(now)) + await asyncio.sleep(0) # Start the task + await asyncio.sleep(0) # Make sure its waiting + assert not task.done() + assert query_scheduler.millis_to_wait(now + 1) == delay - 1 + query_scheduler.reschedule_type("_hap._tcp.local.", now + 1) + assert query_scheduler.millis_to_wait(now + 1) is None + await asyncio.wait_for(task, timeout=0.1) + assert task.done() + + task2 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000)) + assert set(query_scheduler.process_ready_types(now + 1)) == set(["_hap._tcp.local."]) + assert not task2.done() + assert query_scheduler.millis_to_wait(now + 2) == delay - 2 + query_scheduler.reschedule_type("_hap._tcp.local.", now + 2) + assert query_scheduler.millis_to_wait(now + 2) is None + await asyncio.wait_for(task2, timeout=0.1) + assert task2.done() + assert set(query_scheduler.process_ready_types(now + 10000)) == types_ + assert query_scheduler.millis_to_wait(now + 10000) == delay * 2 + + task3 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000)) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(task3, timeout=0.1) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 40a80df3..a7abca4f 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -39,7 +39,7 @@ SignalRegistrationInterface, ) from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.aio import get_best_available_queue +from .._utils.aio import get_best_available_queue, wait_event_or_timeout from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( @@ -183,6 +183,89 @@ def on_change( return on_change +class QueryScheduler: + """Schedule outgoing PTR queries for Continuous Multicast DNS Querying + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + + """ + + def __init__( + self, + types: Set[str], + delay: int, + first_random_delay_interval: Tuple[int, int], + ): + self._schedule_changed_event: Optional[asyncio.Event] = None + self._types = types + self._next_time: Dict[str, float] = {} + self._first_random_delay_interval = first_random_delay_interval + self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self._types} + + def start(self, now: float) -> None: + """Start the scheduler.""" + self._schedule_changed_event = asyncio.Event() + self._generate_first_next_time(now) + + def _generate_first_next_time(self, now: float) -> None: + """Generate the initial next query times. + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 + To avoid accidental synchronization when, for some reason, multiple + clients begin querying at exactly the same moment (e.g., because of + some common external trigger event), a Multicast DNS querier SHOULD + also delay the first query of the series by a randomly chosen amount + in the range 20-120 ms. + """ + delay = millis_to_seconds(random.randint(*self._first_random_delay_interval)) + next_time = now + delay + self._next_time = {check_type_: next_time for check_type_ in self._types} + + def millis_to_wait(self, now: float) -> Optional[float]: + """Returns the number of milliseconds to wait for the next event.""" + # Wait for the type has the smallest next time + next_time = min(self._next_time.values()) + return None if next_time <= now else next_time - now + + def reschedule_type(self, type_: str, next_time: float) -> None: + """Reschedule the query for a type to happen sooner.""" + if next_time >= self._next_time[type_]: + return + + self._next_time[type_] = next_time + self.set_schedule_changed() + + def set_schedule_changed(self) -> None: + """Set the event to unblock async_wait_ready to make sure the adjusted next time is seen.""" + assert self._schedule_changed_event is not None + self._schedule_changed_event.set() + self._schedule_changed_event.clear() + + def process_ready_types(self, now: float) -> List[str]: + """Generate a list of ready types that is due and schedule the next time.""" + if self.millis_to_wait(now): + return [] + + ready_types: List[str] = [] + + for type_, due in self._next_time.items(): + if due > now: + continue + + ready_types.append(type_) + self._next_time[type_] = now + self._delay[type_] + self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) + + return ready_types + + async def async_wait_ready(self, now: float) -> None: + """Wait for at least one query to be ready.""" + timeout = self.millis_to_wait(now) + if timeout: + assert self._schedule_changed_event is not None + await wait_event_or_timeout(self._schedule_changed_event, timeout=millis_to_seconds(timeout)) + + class _ServiceBrowserBase(RecordUpdateListener): """Base class for ServiceBrowser.""" @@ -225,10 +308,9 @@ def __init__( self.port = port self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6) self.question_type = question_type - self._next_time: Dict[str, float] = {} - self._delay: Dict[str, float] = {check_type_: delay for check_type_ in self.types} self._pending_handlers: OrderedDict[Tuple[str, str], ServiceStateChange] = OrderedDict() self._service_state_changed = Signal() + self.query_scheduler = QueryScheduler(self.types, delay, _FIRST_QUERY_DELAY_RANDOM_INTERVAL) self.queue: Optional[queue.Queue] = None self.done = False @@ -250,25 +332,11 @@ def _async_start(self) -> None: Must be called by uses of this base class after they have finished setting their properties. """ - self._generate_first_next_time() + self.query_scheduler.start(current_time_millis()) self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) # Only start queries after the listener is installed self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) - def _generate_first_next_time(self) -> None: - """Generate the initial next query times. - - https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 - To avoid accidental synchronization when, for some reason, multiple - clients begin querying at exactly the same moment (e.g., because of - some common external trigger event), a Multicast DNS querier SHOULD - also delay the first query of the series by a randomly chosen amount - in the range 20-120 ms. - """ - delay = millis_to_seconds(random.randint(*_FIRST_QUERY_DELAY_RANDOM_INTERVAL)) - next_time = current_time_millis() + delay - self._next_time = {check_type_: next_time for check_type_ in self.types} - @property def service_state_changed(self) -> SignalRegistrationInterface: return self._service_state_changed.registration_interface @@ -310,9 +378,9 @@ def _async_process_record_update( elif expired: self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: - expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) - if expires < self._next_time[record.name]: - self._next_time[record.name] = expires + self.query_scheduler.reschedule_type( + record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) + ) return # If its expired or already exists in the cache it cannot be updated. @@ -385,47 +453,30 @@ def _async_cancel(self) -> None: def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: """Generate the service browser query for any type that is due.""" now = current_time_millis() - if self._millis_to_wait(current_time_millis()): + ready_types = self.query_scheduler.process_ready_types(now) + if not ready_types: return [] - ready_types = [] - - for type_, due in self._next_time.items(): - if due > now: - continue - - ready_types.append(type_) - self._next_time[type_] = now + self._delay[type_] - self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2) - # If they did not specify and this is the first request, ask QU questions # https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 since we are # just starting up and we know our cache is likely empty. This ensures # the next outgoing will be sent with the known answers list. question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type - return generate_service_query(self.zc, now, ready_types, self.multicast, question_type) - def _millis_to_wait(self, now: float) -> Optional[float]: - """Returns the number of milliseconds to wait for the next event.""" - # Wait for the type has the smallest next time - next_time = min(self._next_time.values()) - return None if next_time <= now else next_time - now - async def async_browser_task(self) -> None: """Run the browser task.""" await self.zc.async_wait_for_start() first_request = True while True: - timeout = self._millis_to_wait(current_time_millis()) - if timeout: - await self.zc.async_wait(timeout) - + await self.query_scheduler.async_wait_ready(current_time_millis()) outs = self._generate_ready_queries(first_request) - if outs: - first_request = False - for out in outs: - self.zc.async_send(out, addr=self.addr, port=self.port) + if not outs: + continue + + first_request = False + for out in outs: + self.zc.async_send(out, addr=self.addr, port=self.port) async def _async_cancel_browser(self) -> None: """Cancel the browser.""" From 4a8276941a07188180ee31dc4ca578306c2df92b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 22 Jun 2021 18:04:35 -1000 Subject: [PATCH 450/608] Update changelog (#822) --- README.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.rst b/README.rst index 1fd39c7f..9f45c35f 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,28 @@ See examples directory for more. Changelog ========= +0.32.0 Beta 5 +============= + +* Only wake up the query loop when there is a change in the next query time (#818) @bdraco + + The ServiceBrowser query loop (async_browser_task) was being awoken on + every packet because it was using `zeroconf.async_wait` which wakes + up on every new packet. We only need to awaken the loop when the next time + we are going to send a query has changed. + +* New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco + + https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 + When we start a ServiceBrowser and zeroconf has just started up, the known + answer list will be small. By asking a QU question first, it is likely + that we have a large known answer list by the time we ask the QM question + a second later (current default which is likely too low but would be + a breaking change to increase). This reduces the amount of traffic on + the network, and has the secondary advantage that most responders will + answer a QU question without the typical delay answering QM questions. + + 0.32.0 Beta 4 ============= From 7f6d003210244b6f7df133bd474d7ddf64098422 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 23 Jun 2021 09:16:51 -1000 Subject: [PATCH 451/608] Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) * We now enforce a minimum TTL for PTR records to avoid ServiceBrowsers generating excessive queries refresh queries. Apple uses a 15s minimum TTL, however we do not have the same level of rate limit and safe guards so we use 1/4 of the recommended value. --- tests/test_core.py | 4 ++-- tests/test_handlers.py | 47 ++++++++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 20 ++++++++++++++++++ zeroconf/const.py | 5 +++++ 4 files changed, 74 insertions(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index ae397d16..0a07bd51 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -367,8 +367,8 @@ def test_register_service_with_custom_ttl(): addresses=[socket.inet_aton("10.0.1.2")], ) - zc.register_service(info_service, ttl=30) - assert zc.cache.get(info_service.dns_pointer()).ttl == 30 + zc.register_service(info_service, ttl=3000) + assert zc.cache.get(info_service.dns_pointer()).ttl == 3000 zc.close() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 64e44495..ea29c528 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1070,3 +1070,50 @@ def test_questions_query_handler_does_not_put_qu_questions_in_history(): assert not zc.question_history.suppresses(question, now, set([known_answer])) zc.close() + + +def test_guard_against_low_ptr_ttl(): + """Ensure we enforce a minimum for PTR record ttls to avoid excessive refresh queries from ServiceBrowsers. + + Some poorly designed IoT devices can set excessively low PTR + TTLs would will cause ServiceBrowsers to flood the network + with excessive refresh queries. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + # Apple uses a 15s minimum TTL, however we do not have the same + # level of rate limit and safe guards so we use 1/4 of the recommended value + answer_with_low_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 2, + 'low.local.', + ) + answer_with_normal_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'normal.local.', + ) + good_bye_answer = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 0, + 'goodbye.local.', + ) + # TTL should be adjusted to a safe value + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(answer_with_low_ttl, 0) + response.add_answer_at_time(answer_with_normal_ttl, 0) + response.add_answer_at_time(good_bye_answer, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + + incoming_answer_low = zc.cache.async_get_unique(answer_with_low_ttl) + assert incoming_answer_low.ttl == const._DNS_PTR_MIN_TTL + incoming_answer_normal = zc.cache.async_get_unique(answer_with_normal_ttl) + assert incoming_answer_normal.ttl == const._DNS_OTHER_TTL + assert zc.cache.async_get_unique(good_bye_answer) is None + zc.close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 128d8711..617be408 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -34,6 +34,7 @@ from .const import ( _CLASS_IN, _DNS_OTHER_TTL, + _DNS_PTR_MIN_TTL, _FLAGS_AA, _FLAGS_QR_RESPONSE, _MDNS_PORT, @@ -54,6 +55,23 @@ _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] +def sanitize_incoming_record(record: DNSRecord) -> None: + """Protect zeroconf from records that can cause denial of service. + + We enforce a minimum TTL for PTR records to avoid + ServiceBrowsers generating excessive queries refresh queries. + Apple uses a 15s minimum TTL, however we do not have the same + level of rate limit and safe guards so we use 1/4 of the recommended value. + """ + if record.ttl and record.ttl < _DNS_PTR_MIN_TTL and isinstance(record, DNSPointer): + log.debug( + "Increasing effective ttl of %s to minimum of %s to protect against excessive refreshes.", + record, + _DNS_PTR_MIN_TTL, + ) + record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL) + + class _QueryResponse: """A pair for unicast and multicast DNSOutgoing responses.""" @@ -321,6 +339,8 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: unique_types: Set[Tuple[str, int, int]] = set() for record in msg.answers: + sanitize_incoming_record(record) + if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2 unique_types.add((record.name, record.type, record.class_)) diff --git a/zeroconf/const.py b/zeroconf/const.py index df1ba8be..8107a2af 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -47,6 +47,11 @@ _DNS_PORT = 53 _DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 _DNS_OTHER_TTL = 4500 # 75 minutes for non-host records (PTR, TXT etc) as-per RFC6762 +# Currently we enforce a minimum TTL for PTR records to avoid +# ServiceBrowsers generating excessive queries refresh queries. +# Apple uses a 15s minimum TTL, however we do not have the same +# level of rate limit and safe guards so we use 1/4 of the recommended value +_DNS_PTR_MIN_TTL = _DNS_OTHER_TTL / 4 _DNS_PACKET_HEADER_LEN = 12 From 6298ef9078cf2408bc1e57660ee141e882d13469 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 23 Jun 2021 09:19:13 -1000 Subject: [PATCH 452/608] Drop oversize packets before processing them (#826) - Oversized packets can quickly overwhelm the system and deny service to legitimate queriers. In practice this is usually due to broken mDNS implementations rather than malicious actors. --- tests/test_core.py | 70 +++++++++++++++++++++++++++++++++++++++++++++- zeroconf/_core.py | 12 ++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 0a07bd51..9f2412f0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -17,7 +17,7 @@ from typing import cast import zeroconf as r -from zeroconf import _core, const, ServiceBrowser, Zeroconf, current_time_millis +from zeroconf import _core, _protocol, const, ServiceBrowser, Zeroconf, current_time_millis from zeroconf.aio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response @@ -629,3 +629,71 @@ async def test_multiple_sync_instances_stared_from_async_close(): assert zc3.loop.is_running() await asyncio.sleep(0) + + +def test_guard_against_oversized_packets(): + """Ensure we do not process oversized packets. + + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + + for i in range(5000): + generated.add_answer_at_time( + r.DNSText( + "packet{i}.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ), + 0, + ) + + # We are patching to generate an oversized packet + with unittest.mock.patch.object(_protocol, "_MAX_MSG_ABSOLUTE", 100000), unittest.mock.patch.object( + _protocol, "_MAX_MSG_TYPICAL", 100000 + ): + over_sized_packet = generated.packets()[0] + assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + okpacket_record = r.DNSText( + "okpacket.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + + generated.add_answer_at_time( + okpacket_record, + 0, + ) + ok_packet = generated.packets()[0] + + # We cannot test though the network interface as some operating systems + # will guard against the oversized packet and we won't see it. + listener = _core.AsyncListener(zc) + listener.transport = unittest.mock.MagicMock() + + listener.datagram_received(ok_packet, ('127.0.0.1', 5353)) + assert zc.cache.async_get_unique(okpacket_record) is not None + + listener.datagram_received(over_sized_packet, ('127.0.0.1', 5353)) + assert ( + zc.cache.async_get_unique( + r.DNSText( + "packet0.local.", + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + 500, + b'path=/~paulsm/', + ) + ) + is None + ) + + zc.close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a760f404..8237eb16 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -216,6 +216,18 @@ def datagram_received( return self.data = data + + if len(data) > _MAX_MSG_ABSOLUTE: + # Guard against oversized packets to ensure bad implementations cannot overwhelm + # the system. + log.debug( + "Discarding incoming packet with length %s, which is larger " + "than the absolute maximum size of %s", + len(data), + _MAX_MSG_ABSOLUTE, + ) + return + msg = DNSIncoming(data, scope) if msg.valid: log.debug( From 82f80c301a6324d2f1711ca751e81069e90030ec Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 23 Jun 2021 09:22:15 -1000 Subject: [PATCH 453/608] Update changelog (#827) --- README.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.rst b/README.rst index 9f45c35f..61c05e98 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,26 @@ See examples directory for more. Changelog ========= +0.32.0 Beta 6 +============= + +This beta addresses two potential areas where zeroconf can be overwhelmed and +deny service to legitimate queriers. + +* BREAKING CHANGE: Drop oversize packets before processing them (#826) @bdraco + + Oversized packets can quickly overwhelm the system and deny + service to legitimate queriers. In practice this is usually + due to broken mDNS implementations rather than malicious + actors. + +* BREAKING CHANGE: Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) @bdraco + + We now enforce a minimum TTL for PTR records to avoid + ServiceBrowsers generating excessive queries refresh queries. + Apple uses a 15s minimum TTL, however we do not have the same + level of rate limit and safe guards so we use 1/4 of the recommended value. + 0.32.0 Beta 5 ============= From 4c4b388ba125ad23a03722b30c71da86853fe05a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 17:52:46 -1000 Subject: [PATCH 454/608] Convert test_integration to asyncio to avoid testing threading races (#828) Fixes #768 --- tests/services/test_browser.py | 92 -------------------------------- tests/test_aio.py | 97 +++++++++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 93 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index eac26a4c..eb25c7e2 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -647,98 +647,6 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zeroconf_browser.close() -def test_integration(): - service_added = Event() - service_removed = Event() - unexpected_ttl = Event() - got_query = Event() - - type_ = "_http._tcp.local." - registration_name = "xxxyyy.%s" % type_ - - def on_service_state_change(zeroconf, service_type, state_change, name): - if name == registration_name: - if state_change is ServiceStateChange.Added: - service_added.set() - elif state_change is ServiceStateChange.Removed: - service_removed.set() - - zeroconf_browser = Zeroconf(interfaces=['127.0.0.1']) - _wait_for_start(zeroconf_browser) - - # we are going to patch the zeroconf send to check packet sizes - old_send = zeroconf_browser.async_send - - time_offset = 0.0 - - def current_time_millis(): - """Current system time in milliseconds""" - return time.time() * 1000 + time_offset * 1000 - - expected_ttl = const._DNS_HOST_TTL - nbr_answers = 0 - - def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): - """Sends an outgoing packet.""" - pout = r.DNSIncoming(out.packets()[0]) - nonlocal nbr_answers - for answer in pout.answers: - nbr_answers += 1 - if not answer.ttl > expected_ttl / 2: - unexpected_ttl.set() - - got_query.set() - got_query.clear() - - old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) - - # patch the zeroconf send - # patch the zeroconf current_time_millis - # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch.object( - _services_browser, "current_time_millis", current_time_millis - ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): - service_added = Event() - service_removed = Event() - - browser = ServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - - zeroconf_registrar = Zeroconf(interfaces=['127.0.0.1']) - desc = {'path': '/~paulsm/'} - info = ServiceInfo( - type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] - ) - zeroconf_registrar.register_service(info) - - try: - service_added.wait(1) - assert service_added.is_set() - - # Test that we receive queries containing answers only if the remaining TTL - # is greater than half the original TTL - sleep_count = 0 - test_iterations = 50 - - while nbr_answers < test_iterations: - # Increase simulated time shift by 1/4 of the TTL in seconds - time_offset += expected_ttl / 4 - zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed) - sleep_count += 1 - got_query.wait(0.5) - # Prevent the test running indefinitely in an error condition - assert sleep_count < test_iterations * 4 - assert not unexpected_ttl.is_set() - - # Don't remove service, allow close() to cleanup - - finally: - zeroconf_registrar.close() - service_removed.wait(1) - assert service_removed.is_set() - browser.cancel() - zeroconf_browser.close() - - def test_legacy_record_update_listener(): """Test a RecordUpdateListener that does not implement update_records.""" diff --git a/tests/test_aio.py b/tests/test_aio.py index 4523ca1e..0d587709 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -7,16 +7,18 @@ import asyncio import logging import socket +import time import threading import unittest.mock import pytest from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes -from zeroconf import Zeroconf +from zeroconf import DNSIncoming, ServiceStateChange, Zeroconf, const from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf._services import ServiceListener +import zeroconf._services.browser as _services_browser from zeroconf._services.info import ServiceInfo from zeroconf._utils.time import current_time_millis @@ -657,3 +659,96 @@ def update_service(self, zc, type_, name) -> None: await browser.async_cancel() await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_integration(): + service_added = asyncio.Event() + service_removed = asyncio.Event() + unexpected_ttl = asyncio.Event() + got_query = asyncio.Event() + + type_ = "_http._tcp.local." + registration_name = "xxxyyy.%s" % type_ + + def on_service_state_change(zeroconf, service_type, state_change, name): + if name == registration_name: + if state_change is ServiceStateChange.Added: + service_added.set() + elif state_change is ServiceStateChange.Removed: + service_removed.set() + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_browser = aiozc.zeroconf + await zeroconf_browser.async_wait_for_start() + + # we are going to patch the zeroconf send to check packet sizes + old_send = zeroconf_browser.async_send + + time_offset = 0.0 + + def current_time_millis(): + """Current system time in milliseconds""" + return (time.time() * 1000) + (time_offset * 1000) + + expected_ttl = const._DNS_HOST_TTL + nbr_answers = 0 + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): + """Sends an outgoing packet.""" + pout = DNSIncoming(out.packets()[0]) + nonlocal nbr_answers + for answer in pout.answers: + nbr_answers += 1 + if not answer.ttl > expected_ttl / 2: + unexpected_ttl.set() + + got_query.set() + got_query.clear() + + old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) + + # patch the zeroconf send + # patch the zeroconf current_time_millis + # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL + with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch( + "zeroconf._services.browser.current_time_millis", current_time_millis + ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): + service_added = asyncio.Event() + service_removed = asyncio.Event() + + browser = AsyncServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) + + aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + task = await aio_zeroconf_registrar.async_register_service(info) + await task + + try: + await asyncio.wait_for(service_added.wait(), 1) + assert service_added.is_set() + + # Test that we receive queries containing answers only if the remaining TTL + # is greater than half the original TTL + sleep_count = 0 + test_iterations = 50 + + while nbr_answers < test_iterations: + # Increase simulated time shift by 1/4 of the TTL in seconds + time_offset += expected_ttl / 4 + browser.query_scheduler.set_schedule_changed() + sleep_count += 1 + await asyncio.wait_for(got_query.wait(), 0.5) + # Prevent the test running indefinitely in an error condition + assert sleep_count < test_iterations * 4 + assert not unexpected_ttl.is_set() + # Don't remove service, allow close() to cleanup + finally: + await aio_zeroconf_registrar.async_close() + await asyncio.wait_for(service_removed.wait(), 1) + assert service_removed.is_set() + await browser.async_cancel() + await aiozc.async_close() From 10f4a7f8d607d09673be56e5709912403503d86b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 20:30:22 -1000 Subject: [PATCH 455/608] Disable duplicate question suppression for test_integration (#830) - This test waits until we get 50 known answers. It would sometimes fail because it could not ask enough unsuppressed questions in the allowed time. --- tests/test_aio.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 0d587709..0df5d297 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -662,7 +662,10 @@ def update_service(self, zc, type_, name) -> None: @pytest.mark.asyncio -async def test_integration(): +# Disable duplicate question suppression for this test as it works +# by asking the same question over and over +@unittest.mock.patch("zeroconf._core.QuestionHistory.suppresses", return_value=False) +async def test_integration(suppresses_mock): service_added = asyncio.Event() service_removed = asyncio.Event() unexpected_ttl = asyncio.Event() From 8230e3f40da5d2d152942725d67d5f8c0b8c647b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 21:33:05 -1000 Subject: [PATCH 456/608] Show 20 slowest tests on each run (#832) --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 6602d808..10e2960e 100644 --- a/Makefile +++ b/Makefile @@ -41,10 +41,10 @@ mypy: mypy --no-warn-redundant-casts --no-warn-unused-ignores examples/*.py zeroconf test: - pytest --timeout=60 -v tests + pytest --durations=20 --timeout=60 -v tests test_coverage: - pytest --timeout=60 -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests + pytest --durations=20 --timeout=60 -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf From 4039b0b755a3d0fe15e4cb1a7cb1592c35e048e1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 21:33:32 -1000 Subject: [PATCH 457/608] Annotate test failures on github (#831) --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9c6f98f1..132482b5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,6 +25,8 @@ jobs: run: | pip install --upgrade -r requirements-dev.txt pip install . + - name: Install pytest-github-actions-annotate-failures plugin + run: pip install pytest-github-actions-annotate-failures - name: Run tests run: make ci - name: Report coverage to Codecov From 0bf4f7537a042a00d9d3f815afcdf7ebe29d9f53 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 23:07:12 -1000 Subject: [PATCH 458/608] Cache dependency installs in CI (#833) --- .github/workflows/ci.yml | 45 ++++++++++++++++++++++++++-------------- Makefile | 3 ++- requirements-dev.txt | 2 +- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 132482b5..c19b6682 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,19 +15,34 @@ jobs: matrix: os: [ubuntu-latest, macos-latest, windows-latest] python-version: [3.6, 3.7, 3.8, 3.9, pypy3] + include: + - os: ubuntu-latest + venvcmd: . env/bin/activate + - os: macos-latest + venvcmd: . env/bin/activate + - os: windows-latest + venvcmd: env\Scripts\Activate.ps1 steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install --upgrade -r requirements-dev.txt - pip install . - - name: Install pytest-github-actions-annotate-failures plugin - run: pip install pytest-github-actions-annotate-failures - - name: Run tests - run: make ci - - name: Report coverage to Codecov - uses: codecov/codecov-action@v1 + - uses: actions/checkout@v2 + - uses: actions/cache@v2 + id: cache + with: + path: env + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }}-${{ hashFiles('**/Makefile') }} + restore-keys: | + ${{ runner.os }}-pip- + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m venv env + ${{ matrix.venvcmd }} + pip install --upgrade -r requirements-dev.txt pytest-github-actions-annotate-failures + - name: Run tests + run: | + ${{ matrix.venvcmd }} + make ci + - name: Report coverage to Codecov + uses: codecov/codecov-action@v1 diff --git a/Makefile b/Makefile index 10e2960e..d0335d0f 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ endif virtualenv: ./env/requirements.built env: - virtualenv env + python -m venv env ./env/requirements.built: env requirements-dev.txt ./env/bin/pip install -r requirements-dev.txt @@ -48,3 +48,4 @@ test_coverage: autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf + diff --git a/requirements-dev.txt b/requirements-dev.txt index eef93254..dc2f21de 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ autopep8 black;implementation_name=="cpython" coveralls coverage -# Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 +# Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 - is fixed flake8>=3.6.0 flake8-import-order ifaddr From 540c65218eb9d1aedc88a3d3724af97f39ccb88e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 24 Jun 2021 23:20:25 -1000 Subject: [PATCH 459/608] Wait for startup in test_integration (#834) --- tests/test_aio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_aio.py b/tests/test_aio.py index 0df5d297..e1e6de2c 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -723,6 +723,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): browser = AsyncServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + await aio_zeroconf_registrar.zeroconf.async_wait_for_start() + desc = {'path': '/~paulsm/'} info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] From 0b1abbc8f2b09235cfd44e5586024c7b82dc5289 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Jun 2021 09:58:07 -1000 Subject: [PATCH 460/608] Ensure coverage.xml is written for codecov (#837) --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index d0335d0f..378970c4 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test: pytest --durations=20 --timeout=60 -v tests test_coverage: - pytest --durations=20 --timeout=60 -v --cov=zeroconf --cov-branch --cov-report html --cov-report term-missing tests + pytest --durations=20 --timeout=60 -v --cov=zeroconf --cov-branch --cov-report xml --cov-report html --cov-report term-missing tests autopep8: autopep8 --max-line-length=$(MAX_LINE_LENGTH) -i setup.py examples zeroconf From 7297f3ef71c9984296c3e28539ce7a4b42f04a05 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Jun 2021 10:09:49 -1000 Subject: [PATCH 461/608] Make multipacket known answer suppression per interface (#836) - The suppression was happening per instance of Zeroconf instead of per interface. Since the same network can be seen on multiple interfaces (usually and wifi and ethernet), this would confuse the multi-packet known answer supression since it was not expecting to get the same data more than once Fixes #835 --- tests/test_core.py | 91 ++++++++++++++++++++++----------------------- zeroconf/_core.py | 93 +++++++++++++++++++++++++++++----------------- 2 files changed, 104 insertions(+), 80 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 9f2412f0..e9a42012 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -17,10 +17,10 @@ from typing import cast import zeroconf as r -from zeroconf import _core, _protocol, const, ServiceBrowser, Zeroconf, current_time_millis +from zeroconf import _core, _protocol, const, Zeroconf, current_time_millis from zeroconf.aio import AsyncZeroconf -from . import has_working_ipv6, _clear_cache, _inject_response +from . import has_working_ipv6, _clear_cache, _inject_response, _wait_for_start log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -37,6 +37,13 @@ def teardown_module(): log.setLevel(original_logging_level) +def threadsafe_query(zc, protocol, *args): + async def make_query(): + protocol.handle_query_or_defer(*args) + + asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() + + # This test uses asyncio because it needs to access the cache directly # which is not threadsafe @pytest.mark.asyncio @@ -408,6 +415,7 @@ def test_sending_unicast(): def test_tc_bit_defers(): zc = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zc) type_ = "_tcbitdefer._tcp.local." name = "knownname" name2 = "knownname2" @@ -435,12 +443,7 @@ def test_tc_bit_defers(): zc.registry.add(info2) zc.registry.add(info3) - def threadsafe_query(*args): - async def make_query(): - zc.handle_query(*args) - - asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() - + protocol = zc.engine.protocols[0] now = r.current_time_millis() _clear_cache(zc) @@ -459,30 +462,30 @@ async def make_query(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert source_ip not in zc._deferred - assert source_ip not in zc._timers + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert source_ip not in protocol._deferred + assert source_ip not in protocol._timers # unregister zc.unregister_service(info) @@ -491,6 +494,7 @@ async def make_query(): def test_tc_bit_defers_last_response_missing(): zc = Zeroconf(interfaces=['127.0.0.1']) + _wait_for_start(zc) type_ = "_knowndefer._tcp.local." name = "knownname" name2 = "knownname2" @@ -518,12 +522,7 @@ def test_tc_bit_defers_last_response_missing(): zc.registry.add(info2) zc.registry.add(info3) - def threadsafe_query(*args): - async def make_query(): - zc.handle_query(*args) - - asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result() - + protocol = zc.engine.protocols[0] now = r.current_time_millis() _clear_cache(zc) source_ip = '203.0.113.12' @@ -542,45 +541,45 @@ async def make_query(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - timer1 = zc._timers[source_ip] + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + timer1 = protocol._timers[source_ip] next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - timer2 = zc._timers[source_ip] + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + timer2 = protocol._timers[source_ip] if sys.version_info >= (3, 7): assert timer1.cancelled() assert timer2 != timer1 # Send the same packet again to similar multi interfaces - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers - timer3 = zc._timers[source_ip] + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + timer3 = protocol._timers[source_ip] if sys.version_info >= (3, 7): assert not timer3.cancelled() assert timer3 == timer2 next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(next_packet, source_ip, const._MDNS_PORT) - assert zc._deferred[source_ip] == expected_deferred - assert source_ip in zc._timers - timer4 = zc._timers[source_ip] + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + assert protocol._deferred[source_ip] == expected_deferred + assert source_ip in protocol._timers + timer4 = protocol._timers[source_ip] if sys.version_info >= (3, 7): assert timer3.cancelled() assert timer4 != timer3 for _ in range(8): time.sleep(0.1) - if source_ip not in zc._timers and source_ip not in zc._deferred: + if source_ip not in protocol._timers and source_ip not in protocol._deferred: break - assert source_ip not in zc._deferred - assert source_ip not in zc._timers + assert source_ip not in protocol._deferred + assert source_ip not in protocol._timers # unregister zc.registry.remove(info) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 8237eb16..d5461e6d 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -85,6 +85,7 @@ def __init__( ) -> None: self.loop: Optional[asyncio.AbstractEventLoop] = None self.zc = zeroconf + self.protocols: List[AsyncListener] = [] self.readers: List[asyncio.DatagramTransport] = [] self.senders: List[asyncio.DatagramTransport] = [] self._listen_socket = listen_socket @@ -127,7 +128,8 @@ async def _async_create_endpoints(self) -> None: sender_sockets.append(s) for s in reader_sockets: - transport, _ = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) + transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) + self.protocols.append(cast(AsyncListener, protocol)) self.readers.append(cast(asyncio.DatagramTransport, transport)) if s in sender_sockets: self.senders.append(cast(asyncio.DatagramTransport, transport)) @@ -185,6 +187,10 @@ def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.data: Optional[bytes] = None self.transport: Optional[asyncio.DatagramTransport] = None + + self._deferred: Dict[str, List[DNSIncoming]] = {} + self._timers: Dict[str, asyncio.TimerHandle] = {} + super().__init__() def datagram_received( @@ -254,7 +260,49 @@ def datagram_received( self.zc.handle_response(msg) return - self.zc.handle_query(msg, addr, port, v6_flow_scope) + self.handle_query_or_defer(msg, addr, port, v6_flow_scope) + + def handle_query_or_defer( + self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + ) -> None: + """Deal with incoming query packets. Provides a response if + possible.""" + if not msg.truncated: + self._respond_query(msg, addr, port, v6_flow_scope) + return + + deferred = self._deferred.setdefault(addr, []) + # If we get the same packet we ignore it + for incoming in reversed(deferred): + if incoming.data == msg.data: + return + deferred.append(msg) + delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) + assert self.zc.loop is not None + self._cancel_any_timers_for_addr(addr) + self._timers[addr] = self.zc.loop.call_later( + delay, self._respond_query, None, addr, port, v6_flow_scope + ) + + def _cancel_any_timers_for_addr(self, addr: str) -> None: + """Cancel any future truncated packet timers for the address.""" + if addr in self._timers: + self._timers.pop(addr).cancel() + + def _respond_query( + self, + msg: Optional[DNSIncoming], + addr: str, + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + """Respond to a query and reassemble any truncated deferred packets.""" + self._cancel_any_timers_for_addr(addr) + packets = self._deferred.pop(addr, []) + if msg: + packets.append(msg) + + self.zc.handle_assembled_query(packets, addr, port, v6_flow_scope) def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" @@ -317,9 +365,6 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None - self._deferred: Dict[str, List[DNSIncoming]] = {} - self._timers: Dict[str, asyncio.TimerHandle] = {} - self.start() def start(self) -> None: @@ -605,41 +650,21 @@ def handle_response(self, msg: DNSIncoming) -> None: are held in the cache, and listeners are notified.""" self.record_manager.async_updates_from_response(msg) - def handle_query( - self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () - ) -> None: - """Deal with incoming query packets. Provides a response if - possible.""" - if not msg.truncated: - self._respond_query(msg, addr, port) - return - - deferred = self._deferred.setdefault(addr, []) - # If we get the same packet on another iterface we ignore it - for incoming in reversed(deferred): - if incoming.data == msg.data: - return - deferred.append(msg) - delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) - assert self.loop is not None - if addr in self._timers: - self._timers.pop(addr).cancel() - self._timers[addr] = self.loop.call_later(delay, self._respond_query, None, addr, port, v6_flow_scope) - - def _respond_query( + def handle_assembled_query( self, - msg: Optional[DNSIncoming], + packets: List[DNSIncoming], addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: - """Respond to a query and reassemble any truncated deferred packets.""" - if addr in self._timers: - self._timers.pop(addr).cancel() - packets = self._deferred.pop(addr, []) - if msg: - packets.append(msg) + """Respond to a (re)assembled query. + If the protocol recieved packets with the TC bit set, it will + wait a bit for the rest of the packets and only call + handle_assembled_query once it has a complete set of packets + or the timer expires. If the TC bit is not set, a single + packet will be in packets. + """ unicast_out, multicast_out = self.query_handler.async_response(packets, addr, port) if unicast_out: self.async_send(unicast_out, addr, port, v6_flow_scope) From 3fdd8349553c160586fb6831c9466410f19a3308 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Jun 2021 10:11:40 -1000 Subject: [PATCH 462/608] Adjust restore key for CI cache (#838) --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c19b6682..54815381 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,7 +30,7 @@ jobs: path: env key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }}-${{ hashFiles('**/Makefile') }} restore-keys: | - ${{ runner.os }}-pip- + ${{ runner.os }}-pip-${{ matrix.python-version }}- - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v2 with: From 937be522a42830b27326b5253d49003b57998bc9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Jun 2021 10:40:54 -1000 Subject: [PATCH 463/608] Skip dependencies install in CI on cache hit (#839) There is no need to reinstall dependencies in the CI when we have a cache hit. --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 54815381..1e7484b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,6 +36,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' run: | python -m venv env ${{ matrix.venvcmd }} From 7fb11bfc03c06cbe9ed5a4303b3e632d69665bb1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 25 Jun 2021 21:50:30 -1000 Subject: [PATCH 464/608] Limit duplicate packet suppression to 1s intervals (#841) - Only suppress duplicate packets that happen within the same second. Legitimate queriers will retry the question if they are suppressed. The limit was reduced to one second to be in line with rfc6762: To protect the network against excessive packet flooding due to software bugs or malicious attack, a Multicast DNS responder MUST NOT (except in the one special case of answering probe queries) multicast a record on a given interface until at least one second has elapsed since the last time that record was multicast on that particular --- tests/test_core.py | 18 ++++++++++++++++++ zeroconf/_core.py | 17 +++++++++++++---- zeroconf/_protocol.py | 4 ++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index e9a42012..8e4a17cf 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -696,3 +696,21 @@ def test_guard_against_oversized_packets(): ) zc.close() + + +def test_guard_against_duplicate_packets(): + """Ensure we do not process duplicate packets. + These packets can quickly overwhelm the system. + """ + zc = Zeroconf(interfaces=['127.0.0.1']) + listener = _core.AsyncListener(zc) + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis() + 1000) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is False + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis()) is True + assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False + assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False + zc.close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index d5461e6d..a70d43e7 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -186,6 +186,7 @@ class AsyncListener(asyncio.Protocol, QuietLogger): def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.data: Optional[bytes] = None + self.last_time: float = 0 self.transport: Optional[asyncio.DatagramTransport] = None self._deferred: Dict[str, List[DNSIncoming]] = {} @@ -193,6 +194,14 @@ def __init__(self, zc: 'Zeroconf') -> None: super().__init__() + def suppress_duplicate_packet(self, data: bytes, now: float) -> bool: + """Suppress duplicate packet if the last one was the same in the last second.""" + if self.data == data and (now - 1000) < self.last_time: + return True + self.data = data + self.last_time = now + return False + def datagram_received( self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] ) -> None: @@ -210,7 +219,9 @@ def datagram_received( else: return - if self.data == data: + now = current_time_millis() + if self.suppress_duplicate_packet(data, now): + # Guard against duplicate packets log.debug( 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', addr, @@ -221,8 +232,6 @@ def datagram_received( ) return - self.data = data - if len(data) > _MAX_MSG_ABSOLUTE: # Guard against oversized packets to ensure bad implementations cannot overwhelm # the system. @@ -234,7 +243,7 @@ def datagram_received( ) return - msg = DNSIncoming(data, scope) + msg = DNSIncoming(data, scope, now) if msg.valid: log.debug( 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 50bbca28..79f483de 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -79,7 +79,7 @@ class DNSIncoming(DNSMessage, QuietLogger): """Object representation of an incoming DNS packet""" - def __init__(self, data: bytes, scope_id: Optional[int] = None) -> None: + def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None: """Constructor from string holding bytes of packet""" super().__init__(0) self.offset = 0 @@ -92,7 +92,7 @@ def __init__(self, data: bytes, scope_id: Optional[int] = None) -> None: self.num_authorities = 0 self.num_additionals = 0 self.valid = False - self.now = current_time_millis() + self.now = now or current_time_millis() self.scope_id = scope_id try: From ecd9c941810e4b413b20dc55929b3ae1a7e57b27 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 07:27:22 -1000 Subject: [PATCH 465/608] Fix ineffective patching on PyPy (#842) - Use patch in all places so its easier to find where we need to clean up --- tests/services/test_browser.py | 22 +++++++++---------- tests/services/test_info.py | 9 ++++---- tests/test_aio.py | 40 ++++++++++++++++++++++++---------- tests/test_core.py | 5 +++-- tests/test_init.py | 3 ++- tests/utils/test_aio.py | 4 ++-- 6 files changed, 52 insertions(+), 31 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index eb25c7e2..7a5f5df4 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -11,6 +11,7 @@ import os import unittest from threading import Event +from unittest.mock import patch import pytest @@ -382,7 +383,7 @@ def _mock_get_expiration_time(self, percent): return self.created + (percent * self.ttl * 10) # Set an expire time that will force a refresh - with unittest.mock.patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): + with patch("zeroconf.DNSRecord.get_expiration_time", new=_mock_get_expiration_time): _inject_response( zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0], 120), @@ -430,8 +431,7 @@ def _mock_get_expiration_time(self, percent): zeroconf.close() -@unittest.mock.patch("zeroconf._core.QuestionHistory.suppresses", return_value=False) -def test_backoff(suppresses_mock): +def test_backoff(): got_query = Event() type_ = "_http._tcp.local." @@ -457,11 +457,11 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): # patch the zeroconf send # patch the zeroconf current_time_millis # patch the backoff limit to prevent test running forever - with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch.object( - _services_browser, "current_time_millis", current_time_millis - ), unittest.mock.patch.object( + with patch.object(zeroconf_browser, "async_send", send), patch.object( + zeroconf_browser.question_history, "suppresses", return_value=False + ), patch.object(_services_browser, "current_time_millis", current_time_millis), patch.object( _services_browser, "_BROWSER_BACKOFF_LIMIT", 10 - ), unittest.mock.patch.object( + ), patch.object( _services_browser, "_FIRST_QUERY_DELAY_RANDOM_INTERVAL", (0, 0) ): # dummy service callback @@ -525,7 +525,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + with patch.object(zeroconf_browser, "async_send", send): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass @@ -564,7 +564,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + with patch.object(zeroconf_browser, "async_send", send): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass @@ -597,7 +597,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + with patch.object(zeroconf_browser, "async_send", send): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass @@ -631,7 +631,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf_browser, "async_send", send): + with patch.object(zeroconf_browser, "async_send", send): # dummy service callback def on_service_state_change(zeroconf, service_type, state_change, name): pass diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 348d623d..adca1a53 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -9,6 +9,7 @@ import threading import os import unittest +from unittest.mock import patch from threading import Event from typing import List @@ -218,7 +219,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): send_event.set() # patch the zeroconf send - with unittest.mock.patch.object(zc, "async_send", send): + with patch.object(zc, "async_send", send): def mock_incoming_msg(records) -> r.DNSIncoming: @@ -363,7 +364,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): send_event.set() # patch the zeroconf send - with unittest.mock.patch.object(zc, "async_send", send): + with patch.object(zc, "async_send", send): def mock_incoming_msg(records) -> r.DNSIncoming: @@ -675,7 +676,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf, "async_send", send): + with patch.object(zeroconf, "async_send", send): zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QU) assert first_outgoing.questions[0].unicast == True zeroconf.close() @@ -699,7 +700,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): old_send(out, addr=addr, port=port) # patch the zeroconf send - with unittest.mock.patch.object(zeroconf, "async_send", send): + with patch.object(zeroconf, "async_send", send): zeroconf.get_service_info(f"name.{type_}", type_, 500) assert first_outgoing.questions[0].unicast == False zeroconf.close() diff --git a/tests/test_aio.py b/tests/test_aio.py index e1e6de2c..5fb41a7a 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -9,7 +9,8 @@ import socket import time import threading -import unittest.mock +from unittest.mock import patch + import pytest @@ -411,7 +412,7 @@ async def test_service_info_async_request() -> None: _clear_cache(aiozc.zeroconf) # Generating the race condition is almost impossible # without patching since its a TOCTOU race - with unittest.mock.patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): + with patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): await aiosinfo.async_request(aiozc.zeroconf, 3000) assert aiosinfo is not None assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] @@ -662,10 +663,7 @@ def update_service(self, zc, type_, name) -> None: @pytest.mark.asyncio -# Disable duplicate question suppression for this test as it works -# by asking the same question over and over -@unittest.mock.patch("zeroconf._core.QuestionHistory.suppresses", return_value=False) -async def test_integration(suppresses_mock): +async def test_integration(): service_added = asyncio.Event() service_removed = asyncio.Event() unexpected_ttl = asyncio.Event() @@ -711,20 +709,40 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) + assert len(zeroconf_browser.engine.protocols) == 2 + + aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_registrar = aio_zeroconf_registrar.zeroconf + await aio_zeroconf_registrar.zeroconf.async_wait_for_start() + + assert len(zeroconf_registrar.engine.protocols) == 2 # patch the zeroconf send # patch the zeroconf current_time_millis # patch the backoff limit to ensure we always get one query every 1/4 of the DNS TTL - with unittest.mock.patch.object(zeroconf_browser, "async_send", send), unittest.mock.patch( + # Disable duplicate question suppression and duplicate packet suppression for this test as it works + # by asking the same question over and over + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_browser.question_history, "suppresses", return_value=False + ), patch.object( + zeroconf_browser, "async_send", send + ), patch( "zeroconf._services.browser.current_time_millis", current_time_millis - ), unittest.mock.patch.object(_services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4)): + ), patch.object( + _services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4) + ): service_added = asyncio.Event() service_removed = asyncio.Event() browser = AsyncServiceBrowser(zeroconf_browser, type_, [on_service_state_change]) - aio_zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) - await aio_zeroconf_registrar.zeroconf.async_wait_for_start() - desc = {'path': '/~paulsm/'} info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] diff --git a/tests/test_core.py b/tests/test_core.py index 8e4a17cf..85571ddd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -15,6 +15,7 @@ import unittest import unittest.mock from typing import cast +from unittest.mock import patch import zeroconf as r from zeroconf import _core, _protocol, const, Zeroconf, current_time_millis @@ -48,7 +49,7 @@ async def make_query(): # which is not threadsafe @pytest.mark.asyncio async def test_reaper(): - with unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): + with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): assert _core._CACHE_CLEANUP_INTERVAL == 10 aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) zeroconf = aiozc.zeroconf @@ -652,7 +653,7 @@ def test_guard_against_oversized_packets(): ) # We are patching to generate an oversized packet - with unittest.mock.patch.object(_protocol, "_MAX_MSG_ABSOLUTE", 100000), unittest.mock.patch.object( + with patch.object(_protocol, "_MAX_MSG_ABSOLUTE", 100000), patch.object( _protocol, "_MAX_MSG_TYPICAL", 100000 ): over_sized_packet = generated.packets()[0] diff --git a/tests/test_init.py b/tests/test_init.py index 0cc3baf8..0383af1a 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -10,6 +10,7 @@ import unittest import unittest.mock from typing import Optional # noqa # used in type hints +from unittest.mock import patch import zeroconf as r from zeroconf import DNSOutgoing, ServiceBrowser, ServiceInfo, Zeroconf, const @@ -90,7 +91,7 @@ def test_large_packet_exception_log_handling(self): # instantiate a zeroconf instance zc = Zeroconf(interfaces=['127.0.0.1']) - with unittest.mock.patch('zeroconf._logger.log.warning') as mocked_log_warn, unittest.mock.patch( + with patch('zeroconf._logger.log.warning') as mocked_log_warn, patch( 'zeroconf._logger.log.debug' ) as mocked_log_debug: # now that we have a long packet in our possession, let's verify the diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index b0fa8dbc..52a23dea 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -6,7 +6,7 @@ import asyncio import contextlib -import unittest.mock +from unittest.mock import patch import pytest @@ -23,7 +23,7 @@ async def test_async_get_all_tasks() -> None: await aioutils._async_get_all_tasks(aioutils.get_running_loop()) if not hasattr(asyncio, 'all_tasks'): return - with unittest.mock.patch("zeroconf._utils.aio.asyncio.all_tasks", side_effect=RuntimeError): + with patch("zeroconf._utils.aio.asyncio.all_tasks", side_effect=RuntimeError): await aioutils._async_get_all_tasks(aioutils.get_running_loop()) From 688c5184dce67e5af857c138639ced4bdcec1e57 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 07:31:20 -1000 Subject: [PATCH 466/608] Use AAAA records instead of A records in test_integration_with_listener_ipv6 (#843) --- tests/services/test_types.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 7a07085f..8b38317b 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -99,6 +99,7 @@ def test_integration_with_listener_ipv6(self): type_ = "_test-listenv6ip-type._tcp.local." name = "xxxyyy" registration_name = "%s.%s" % (name, type_) + addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com zeroconf_registrar = Zeroconf(ip_version=r.IPVersion.V6Only) desc = {'path': '/~paulsm/'} @@ -110,7 +111,7 @@ def test_integration_with_listener_ipv6(self): 0, desc, "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], + addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) zeroconf_registrar.registry.add(info) try: From dd86f2f9fee4bbaebce956b330c1837a6e9c6c99 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 07:35:12 -1000 Subject: [PATCH 467/608] Increase timeout in test_integration (#844) - The github macOS runners tend to be a bit loaded and these sometimes fail because of it --- tests/test_aio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 5fb41a7a..a146cf96 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -764,7 +764,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): time_offset += expected_ttl / 4 browser.query_scheduler.set_schedule_changed() sleep_count += 1 - await asyncio.wait_for(got_query.wait(), 0.5) + await asyncio.wait_for(got_query.wait(), 1) # Prevent the test running indefinitely in an error condition assert sleep_count < test_iterations * 4 assert not unexpected_ttl.is_set() From 72502c303a1a889cf84906b8764fd941a840e6d3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 07:52:14 -1000 Subject: [PATCH 468/608] Update changelog (#845) --- README.rst | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.rst b/README.rst index 61c05e98..00b4f35e 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,30 @@ See examples directory for more. Changelog ========= +0.32.0 Release Candidate 2 +========================== + +* Limit duplicate packet suppression to 1s intervals (#841) @bdraco + + Only suppress duplicate packets that happen within the same + second. Legitimate queriers will retry the question if they + are suppressed. The limit was reduced to one second to be + in line with rfc6762 + +* Make multipacket known answer suppression per interface (#836) @bdraco + + The suppression was happening per instance of Zeroconf instead + of per interface. Since the same network can be seen on multiple + interfaces (usually and wifi and ethernet), this would confuse the + multi-packet known answer supression since it was not expecting + to get the same data more than once + + +0.32.0 Release Candidate 1 +========================== + +No changes + 0.32.0 Beta 6 ============= From 182c68ff11ba381444a708e17560e920ae1849ef Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 16:44:09 -1000 Subject: [PATCH 469/608] Fix thread safety in handlers test (#847) --- tests/test_handlers.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ea29c528..8fe2c56d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1022,8 +1022,10 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.Recor await aiozc.async_close() -def test_questions_query_handler_populates_the_question_history_from_qm_questions(): - zc = Zeroconf(interfaces=['127.0.0.1']) +@pytest.mark.asyncio +async def test_questions_query_handler_populates_the_question_history_from_qm_questions(): + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf now = current_time_millis() _clear_cache(zc) @@ -1044,11 +1046,13 @@ def test_questions_query_handler_populates_the_question_history_from_qm_question assert multicast_out is None assert zc.question_history.suppresses(question, now, set([known_answer])) - zc.close() + await aiozc.async_close() -def test_questions_query_handler_does_not_put_qu_questions_in_history(): - zc = Zeroconf(interfaces=['127.0.0.1']) +@pytest.mark.asyncio +async def test_questions_query_handler_does_not_put_qu_questions_in_history(): + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf now = current_time_millis() _clear_cache(zc) @@ -1069,17 +1073,19 @@ def test_questions_query_handler_does_not_put_qu_questions_in_history(): assert multicast_out is None assert not zc.question_history.suppresses(question, now, set([known_answer])) - zc.close() + await aiozc.async_close() -def test_guard_against_low_ptr_ttl(): +@pytest.mark.asyncio +async def test_guard_against_low_ptr_ttl(): """Ensure we enforce a minimum for PTR record ttls to avoid excessive refresh queries from ServiceBrowsers. Some poorly designed IoT devices can set excessively low PTR TTLs would will cause ServiceBrowsers to flood the network with excessive refresh queries. """ - zc = Zeroconf(interfaces=['127.0.0.1']) + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf # Apple uses a 15s minimum TTL, however we do not have the same # level of rate limit and safe guards so we use 1/4 of the recommended value answer_with_low_ttl = r.DNSPointer( @@ -1116,4 +1122,4 @@ def test_guard_against_low_ptr_ttl(): incoming_answer_normal = zc.cache.async_get_unique(answer_with_normal_ttl) assert incoming_answer_normal.ttl == const._DNS_OTHER_TTL assert zc.cache.async_get_unique(good_bye_answer) is None - zc.close() + await aiozc.async_close() From 9f71e5b7364d4a23492cafe4f49a5c2acda4178d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 26 Jun 2021 17:03:33 -1000 Subject: [PATCH 470/608] Fix spurious failures in ZeroconfServiceTypes tests (#848) - These tests ran the same test twice in 0.5s and would trigger the duplicate packet suppression. Rather then making them run longer, we can disable the suppression for the test. --- tests/services/test_types.py | 61 ++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/tests/services/test_types.py b/tests/services/test_types.py index 8b38317b..d14a8b25 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -9,6 +9,7 @@ import unittest import socket import sys +from unittest.mock import patch import zeroconf as r from zeroconf import Zeroconf, ServiceInfo, ZeroconfServiceTypes @@ -51,11 +52,16 @@ def test_integration_with_listener(self): ) zeroconf_registrar.registry.add(info) try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types finally: zeroconf_registrar.close() @@ -83,11 +89,16 @@ def test_integration_with_listener_v6_records(self): ) zeroconf_registrar.registry.add(info) try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types finally: zeroconf_registrar.close() @@ -115,11 +126,16 @@ def test_integration_with_listener_ipv6(self): ) zeroconf_registrar.registry.add(info) try: - service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) - assert type_ in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert type_ in service_types + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) + assert type_ in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert type_ in service_types finally: zeroconf_registrar.close() @@ -146,11 +162,16 @@ def test_integration_with_subtype_and_listener(self): ) zeroconf_registrar.registry.add(info) try: - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) - assert discovery_type in service_types - _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) - assert discovery_type in service_types + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + assert discovery_type in service_types + _clear_cache(zeroconf_registrar) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + assert discovery_type in service_types finally: zeroconf_registrar.close() From a8c16231881de43adedbedbc3f1ea707c0b457f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 07:08:54 -1000 Subject: [PATCH 471/608] Switch ServiceBrowser query scheduling to use call_later instead of a loop (#849) - Simplifies scheduling as there is no more need to sleep in a loop as we now schedule future callbacks with call_later - Simplifies cancelation as there is no more coroutine to cancel, only a timer handle We no longer have to handle the canceled error and cleaning up the awaitable - Solves the infrequent test failures in test_backoff and test_integration --- tests/services/test_browser.py | 44 ++--------------- tests/test_aio.py | 9 ++-- zeroconf/_services/browser.py | 86 +++++++++++++++------------------- zeroconf/aio.py | 2 - 4 files changed, 46 insertions(+), 95 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 7a5f5df4..36f459c7 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -496,7 +496,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): else: assert not got_query.is_set() time_offset += initial_query_interval - zeroconf_browser.loop.call_soon_threadsafe(browser.query_scheduler.set_schedule_changed) + zeroconf_browser.loop.call_soon_threadsafe(browser.schedule_changed) finally: browser.cancel() @@ -984,8 +984,8 @@ async def test_query_scheduler(): # Test query interval is increasing assert query_scheduler.millis_to_wait(now - 1) == 1 - assert query_scheduler.millis_to_wait(now) is None - assert query_scheduler.millis_to_wait(now + 1) is None + assert query_scheduler.millis_to_wait(now) is 0 + assert query_scheduler.millis_to_wait(now + 1) is 0 assert set(query_scheduler.process_ready_types(now)) == types_ assert set(query_scheduler.process_ready_types(now)) == set() @@ -1018,41 +1018,3 @@ async def test_query_scheduler(): assert set(query_scheduler.process_ready_types(now + delay * 20)) == set() assert set(query_scheduler.process_ready_types(now + delay * 31)) == set(["_http._tcp.local."]) - - -@pytest.mark.asyncio -async def test_query_scheduler_triggers_async_wait_ready_on_reschedule(): - """Test that a reschedule wakes up the async_wait_ready.""" - delay = const._BROWSER_TIME - types_ = set(["_hap._tcp.local.", "_http._tcp.local."]) - query_scheduler = _services_browser.QueryScheduler(types_, delay, (0, 0)) - - now = current_time_millis() - query_scheduler.start(now) - assert set(query_scheduler.process_ready_types(now)) == types_ - assert query_scheduler.millis_to_wait(now) == delay - - task = asyncio.ensure_future(query_scheduler.async_wait_ready(now)) - await asyncio.sleep(0) # Start the task - await asyncio.sleep(0) # Make sure its waiting - assert not task.done() - assert query_scheduler.millis_to_wait(now + 1) == delay - 1 - query_scheduler.reschedule_type("_hap._tcp.local.", now + 1) - assert query_scheduler.millis_to_wait(now + 1) is None - await asyncio.wait_for(task, timeout=0.1) - assert task.done() - - task2 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000)) - assert set(query_scheduler.process_ready_types(now + 1)) == set(["_hap._tcp.local."]) - assert not task2.done() - assert query_scheduler.millis_to_wait(now + 2) == delay - 2 - query_scheduler.reschedule_type("_hap._tcp.local.", now + 2) - assert query_scheduler.millis_to_wait(now + 2) is None - await asyncio.wait_for(task2, timeout=0.1) - assert task2.done() - assert set(query_scheduler.process_ready_types(now + 10000)) == types_ - assert query_scheduler.millis_to_wait(now + 10000) == delay * 2 - - task3 = asyncio.ensure_future(query_scheduler.async_wait_ready(now + 10000)) - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(task3, timeout=0.1) diff --git a/tests/test_aio.py b/tests/test_aio.py index a146cf96..41e0e83a 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -688,7 +688,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): time_offset = 0.0 - def current_time_millis(): + def _new_current_time_millis(): """Current system time in milliseconds""" return (time.time() * 1000) + (time_offset * 1000) @@ -705,7 +705,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): unexpected_ttl.set() got_query.set() - got_query.clear() old_send(out, addr=addr, port=port, v6_flow_scope=v6_flow_scope) @@ -734,7 +733,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): ), patch.object( zeroconf_browser, "async_send", send ), patch( - "zeroconf._services.browser.current_time_millis", current_time_millis + "zeroconf._services.browser.current_time_millis", _new_current_time_millis ), patch.object( _services_browser, "_BROWSER_BACKOFF_LIMIT", int(expected_ttl / 4) ): @@ -762,9 +761,11 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): while nbr_answers < test_iterations: # Increase simulated time shift by 1/4 of the TTL in seconds time_offset += expected_ttl / 4 - browser.query_scheduler.set_schedule_changed() + now = _new_current_time_millis() + browser.reschedule_type(type_, now) sleep_count += 1 await asyncio.wait_for(got_query.wait(), 1) + got_query.clear() # Prevent the test running indefinitely in an error condition assert sleep_count < test_iterations * 4 assert not unexpected_ttl.is_set() diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index a7abca4f..bc368edb 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -21,7 +21,6 @@ """ import asyncio -import contextlib import queue import random import threading @@ -39,7 +38,7 @@ SignalRegistrationInterface, ) from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.aio import get_best_available_queue, wait_event_or_timeout +from .._utils.aio import get_best_available_queue from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( @@ -221,25 +220,17 @@ def _generate_first_next_time(self, now: float) -> None: next_time = now + delay self._next_time = {check_type_: next_time for check_type_ in self._types} - def millis_to_wait(self, now: float) -> Optional[float]: + def millis_to_wait(self, now: float) -> float: """Returns the number of milliseconds to wait for the next event.""" # Wait for the type has the smallest next time next_time = min(self._next_time.values()) - return None if next_time <= now else next_time - now + return 0 if next_time <= now else next_time - now def reschedule_type(self, type_: str, next_time: float) -> None: """Reschedule the query for a type to happen sooner.""" if next_time >= self._next_time[type_]: return - self._next_time[type_] = next_time - self.set_schedule_changed() - - def set_schedule_changed(self) -> None: - """Set the event to unblock async_wait_ready to make sure the adjusted next time is seen.""" - assert self._schedule_changed_event is not None - self._schedule_changed_event.set() - self._schedule_changed_event.clear() def process_ready_types(self, now: float) -> List[str]: """Generate a list of ready types that is due and schedule the next time.""" @@ -258,13 +249,6 @@ def process_ready_types(self, now: float) -> List[str]: return ready_types - async def async_wait_ready(self, now: float) -> None: - """Wait for at least one query to be ready.""" - timeout = self.millis_to_wait(now) - if timeout: - assert self._schedule_changed_event is not None - await wait_event_or_timeout(self._schedule_changed_event, timeout=millis_to_seconds(timeout)) - class _ServiceBrowserBase(RecordUpdateListener): """Base class for ServiceBrowser.""" @@ -302,7 +286,6 @@ def __init__( for check_type_ in self.types: # Will generate BadTypeInNameException on a bad name service_type_name(check_type_, strict=False) - self._browser_task: Optional[asyncio.Task] = None self.zc = zc self.addr = addr self.port = port @@ -313,6 +296,8 @@ def __init__( self.query_scheduler = QueryScheduler(self.types, delay, _FIRST_QUERY_DELAY_RANDOM_INTERVAL) self.queue: Optional[queue.Queue] = None self.done = False + self._first_request: bool = True + self._next_send_timer: Optional[asyncio.TimerHandle] = None if hasattr(handlers, 'add_service'): listener = cast('ServiceListener', handlers) @@ -335,7 +320,7 @@ def _async_start(self) -> None: self.query_scheduler.start(current_time_millis()) self.zc.async_add_listener(self, [DNSQuestion(type_, _TYPE_PTR, _CLASS_IN) for type_ in self.types]) # Only start queries after the listener is installed - self._browser_task = cast(asyncio.Task, asyncio.ensure_future(self.async_browser_task())) + asyncio.ensure_future(self._async_start_query_sender()) @property def service_state_changed(self) -> SignalRegistrationInterface: @@ -378,9 +363,7 @@ def _async_process_record_update( elif expired: self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: - self.query_scheduler.reschedule_type( - record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT) - ) + self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)) return # If its expired or already exists in the cache it cannot be updated. @@ -448,6 +431,7 @@ def _fire_service_state_changed_event(self, event: Tuple[Tuple[str, str], Servic def _async_cancel(self) -> None: """Cancel the browser.""" self.done = True + self._cancel_send_timer() self.zc.async_remove_listener(self) def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: @@ -464,28 +448,40 @@ def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: question_type = DNSQuestionType.QU if not self.question_type and first_request else self.question_type return generate_service_query(self.zc, now, ready_types, self.multicast, question_type) - async def async_browser_task(self) -> None: - """Run the browser task.""" + async def _async_start_query_sender(self) -> None: + """Start scheduling queries.""" await self.zc.async_wait_for_start() - first_request = True - while True: - await self.query_scheduler.async_wait_ready(current_time_millis()) - outs = self._generate_ready_queries(first_request) - if not outs: - continue + self._async_send_ready_queries_schedule_next() + + def _cancel_send_timer(self) -> None: + """Cancel the next send.""" + if self._next_send_timer: + self._next_send_timer.cancel() - first_request = False + def reschedule_type(self, type_: str, next_time: float) -> None: + """Reschedule a type to be refreshed in the future.""" + self.query_scheduler.reschedule_type(type_, next_time) + self.schedule_changed() + + def schedule_changed(self) -> None: + """Called when the schedule has changed.""" + self._cancel_send_timer() + self._async_send_ready_queries_schedule_next() + + def _async_send_ready_queries_schedule_next(self) -> None: + """Send any ready queries and scheule the next time.""" + if self.done or self.zc.done: + return + + outs = self._generate_ready_queries(self._first_request) + if outs: + self._first_request = False for out in outs: self.zc.async_send(out, addr=self.addr, port=self.port) - async def _async_cancel_browser(self) -> None: - """Cancel the browser.""" - assert self._browser_task is not None - self._browser_task.cancel() - browser_task = self._browser_task - self._browser_task = None - with contextlib.suppress(asyncio.CancelledError): - await browser_task + assert self.zc.loop is not None + delay = millis_to_seconds(self.query_scheduler.millis_to_wait(current_time_millis())) + self._next_send_timer = self.zc.loop.call_later(delay, self._async_send_ready_queries_schedule_next) class ServiceBrowser(_ServiceBrowserBase, threading.Thread): @@ -523,18 +519,12 @@ def __init__( getattr(self, 'native_id', self.ident), ) - def _async_cancel_soon(self) -> None: - """Cancel the browser from the event loop.""" - self._async_cancel() - if self._browser_task: - asyncio.ensure_future(self._async_cancel_browser()) - def cancel(self) -> None: """Cancel the browser.""" assert self.zc.loop is not None assert self.queue is not None self.queue.put(None) - self.zc.loop.call_soon_threadsafe(self._async_cancel_soon) + self.zc.loop.call_soon_threadsafe(self._async_cancel) self.join() def run(self) -> None: diff --git a/zeroconf/aio.py b/zeroconf/aio.py index 985a440b..67ff1c12 100644 --- a/zeroconf/aio.py +++ b/zeroconf/aio.py @@ -91,8 +91,6 @@ def __init__( async def async_cancel(self) -> None: """Cancel the browser.""" self._async_cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._async_cancel_browser() class AsyncZeroconfServiceTypes(ZeroconfServiceTypes): From 8c9d1d8964d9226d5d3ac38bec908e930954b369 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 07:11:18 -1000 Subject: [PATCH 472/608] Update changelog (#850) --- README.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.rst b/README.rst index 00b4f35e..d69369b9 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,20 @@ See examples directory for more. Changelog ========= + +0.32.0 Release Candidate 3 +========================== + +* Switch ServiceBrowser query scheduling to use call_later instead of a loop (#849) @bdraco + + Simplifies scheduling as there is no more need to sleep in a loop as + we now schedule future callbacks with call_later + + Simplifies cancelation as there is no more coroutine to cancel, only a timer handle + We no longer have to handle the canceled error and cleaning up the awaitable + + Solves the infrequent test failures in test_backoff and test_integration + 0.32.0 Release Candidate 2 ========================== From 76e0b05ca9c601bd638817bf68ca8d981f1d65f8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 10:04:17 -1000 Subject: [PATCH 473/608] Make ServiceInfo first question QU (#852) - We want an immediate response when making a request with ServiceInfo by asking a QU question, most responders will not delay the response and respond right away to our question. This also improves compatibility with split networks as we may not have been able to see the response otherwise. If the responder has not multicast the record recently it may still choose to do so in addition to responding via unicast - Reduces traffic when there are multiple zeroconf instances running on the network running ServiceBrowsers - If we don't get an answer on the first try, we ask a QM question in the event we can't receive a unicast response for some reason - This change puts ServiceInfo inline with ServiceBrowser which also asks the first question as QU since ServiceInfo is commonly called from ServiceBrowser callbacks closes #851 --- tests/services/test_info.py | 6 ++--- tests/test_aio.py | 47 +++++++++++++++++++++++++++++++++++++ zeroconf/_services/info.py | 6 ++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index adca1a53..37f98aa1 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -682,8 +682,8 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): zeroconf.close() -def test_asking_qm_questions_are_default(): - """Verify default is QM questions.""" +def test_asking_qm_questions(): + """Verify explictly asking QM questions.""" type_ = "_quservice._tcp.local." zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) @@ -701,6 +701,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): # patch the zeroconf send with patch.object(zeroconf, "async_send", send): - zeroconf.get_service_info(f"name.{type_}", type_, 500) + zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QM) assert first_outgoing.questions[0].unicast == False zeroconf.close() diff --git a/tests/test_aio.py b/tests/test_aio.py index 41e0e83a..f22bf966 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -776,3 +776,50 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()): assert service_removed.is_set() await browser.async_cancel() await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_info_asking_default_is_asking_qm_questions_after_the_first_qu(): + """Verify the service info first question is QU and subsequent ones are QM questions.""" + type_ = "_quservice._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf_info = aiozc.zeroconf + + name = "xxxyyy" + registration_name = "%s.%s" % (name, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + zeroconf_info.registry.add(info) + + # we are going to patch the zeroconf send to check query transmission + old_send = zeroconf_info.async_send + + first_outgoing = None + second_outgoing = None + + def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal first_outgoing + nonlocal second_outgoing + if out.questions: + if first_outgoing is not None and second_outgoing is None: + second_outgoing = out + if first_outgoing is None: + first_outgoing = out + old_send(out, addr=addr, port=port) + + # patch the zeroconf send + with patch.object(zeroconf_info, "async_send", send): + aiosinfo = AsyncServiceInfo(type_, registration_name) + # Patch _is_complete so we send multiple times + with patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): + await aiosinfo.async_request(aiozc.zeroconf, 1200) + try: + assert first_outgoing.questions[0].unicast == True + assert second_outgoing.questions[0].unicast == False + finally: + await aiozc.async_close() diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 4365d6ef..9d1c37f3 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -438,6 +438,7 @@ async def async_request( if self.load_from_cache(zc): return True + first_request = True now = current_time_millis() delay = _LISTENER_TIME next_ = now @@ -449,7 +450,10 @@ async def async_request( if last <= now: return False if next_ <= now: - out = self.generate_request_query(zc, now, question_type) + out = self.generate_request_query( + zc, now, question_type or DNSQuestionType.QU if first_request else DNSQuestionType.QM + ) + first_request = False if not out.questions: return self.load_from_cache(zc) zc.async_send(out) From 0cd876f5a42699aeb0176380ba4cca4d8a536df3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 10:04:27 -1000 Subject: [PATCH 474/608] Speed up test_verify_name_change_with_lots_of_names under PyPy (#853) fixes #840 --- tests/__init__.py | 12 +++++++++--- tests/test_init.py | 14 +++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index d77140fd..2671fe62 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -23,7 +23,7 @@ import asyncio import socket from functools import lru_cache - +from typing import List import ifaddr @@ -31,16 +31,22 @@ from zeroconf import DNSIncoming, Zeroconf -def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: +def _inject_responses(zc: Zeroconf, msgs: List[DNSIncoming]) -> None: """Inject a DNSIncoming response.""" assert zc.loop is not None async def _wait_for_response(): - zc.handle_response(msg) + for msg in msgs: + zc.handle_response(msg) asyncio.run_coroutine_threadsafe(_wait_for_response(), zc.loop).result() +def _inject_response(zc: Zeroconf, msg: DNSIncoming) -> None: + """Inject a DNSIncoming response.""" + _inject_responses(zc, [msg]) + + def _wait_for_start(zc: Zeroconf) -> None: """Wait for all sockets to be up and running.""" assert zc.loop is not None diff --git a/tests/test_init.py b/tests/test_init.py index 0383af1a..5005a75d 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -9,13 +9,12 @@ import time import unittest import unittest.mock -from typing import Optional # noqa # used in type hints from unittest.mock import patch import zeroconf as r -from zeroconf import DNSOutgoing, ServiceBrowser, ServiceInfo, Zeroconf, const +from zeroconf import ServiceInfo, Zeroconf, const -from . import _inject_response +from . import _inject_responses log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -162,14 +161,16 @@ def verify_name_change(self, zc, type_, name, number_hosts): def generate_many_hosts(self, zc, type_, name, number_hosts): block_size = 25 number_hosts = int(((number_hosts - 1) / block_size + 1)) * block_size + out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) for i in range(1, number_hosts + 1): next_name = name if i == 1 else '%s-%d' % (name, i) - self.generate_host(zc, next_name, type_) + self.generate_host(out, next_name, type_) + + _inject_responses(zc, [r.DNSIncoming(packet) for packet in out.packets()]) @staticmethod - def generate_host(zc, host_name, type_): + def generate_host(out, host_name, type_): name = '.'.join((host_name, type_)) - out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) out.add_answer_at_time( r.DNSPointer(type_, const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, name), 0 ) @@ -186,4 +187,3 @@ def generate_host(zc, host_name, type_): ), 0, ) - _inject_response(zc, r.DNSIncoming(out.packets()[0])) From 03411f35d82752d5d2633a67db132a011098d9e6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 10:59:27 -1000 Subject: [PATCH 475/608] Only run linters on Linux in CI (#855) - The github MacOS and Windows runners are slower and will have the same results as the Linux runners so there is no need to wait for them. closes #854 --- .github/workflows/ci.yml | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e7484b3..9d4f9a35 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,26 +24,46 @@ jobs: venvcmd: env\Scripts\Activate.ps1 steps: - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} - uses: actions/cache@v2 id: cache with: path: env - key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/requirements-dev.txt') }}-${{ hashFiles('**/Makefile') }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }}-${{ hashFiles('**/requirements-dev.txt') }} restore-keys: | - ${{ runner.os }}-pip-${{ matrix.python-version }}- - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/Makefile') }} - name: Install dependencies if: steps.cache.outputs.cache-hit != 'true' run: | python -m venv env ${{ matrix.venvcmd }} pip install --upgrade -r requirements-dev.txt pytest-github-actions-annotate-failures + - name: Run flake8 + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make flake8 + - name: Run mypy + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make mypy + - name: Run black_check + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make black_check + - name: Run pylint + if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} + run: | + ${{ matrix.venvcmd }} + make pylint - name: Run tests run: | ${{ matrix.venvcmd }} - make ci + make test_coverage - name: Report coverage to Codecov uses: codecov/codecov-action@v1 From cb2e237b6f1af0a83bc7352464562cdb7bbcac14 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 11:11:39 -1000 Subject: [PATCH 476/608] Update changelog (#856) --- README.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.rst b/README.rst index d69369b9..186e4430 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,27 @@ See examples directory for more. Changelog ========= +0.32.0 Release Candidate 4 +========================== + +Make ServiceInfo first question QU (#852) @bdraco + + We want an immediate response when making a request with ServiceInfo + by asking a QU question, most responders will not delay the response + and respond right away to our question. This also improves compatibility + with split networks as we may not have been able to see the response + otherwise. If the responder has not multicast the record recently + it may still choose to do so in addition to responding via unicast + + Reduces traffic when there are multiple zeroconf instances running + on the network running ServiceBrowsers + + If we don't get an answer on the first try, we ask a QM question + in the event we can't receive a unicast response for some reason + + This change puts ServiceInfo inline with ServiceBrowser which + also asks the first question as QU since ServiceInfo is commonly + called from ServiceBrowser callbacks 0.32.0 Release Candidate 3 ========================== From 59247f1c44b485bf51d4a8d3e3966b9faf40cf82 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 11:13:10 -1000 Subject: [PATCH 477/608] Fix changelog formatting (#857) --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 186e4430..47e3c584 100644 --- a/README.rst +++ b/README.rst @@ -143,7 +143,7 @@ Changelog 0.32.0 Release Candidate 4 ========================== -Make ServiceInfo first question QU (#852) @bdraco +* Make ServiceInfo first question QU (#852) @bdraco We want an immediate response when making a request with ServiceInfo by asking a QU question, most responders will not delay the response From 3eb7be95fd6cd4960f96f29aa72fc45347c57b6e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 16:04:45 -1000 Subject: [PATCH 478/608] Cleanup coverage data (#858) --- .coveragerc | 1 + zeroconf/__init__.py | 4 ++-- zeroconf/const.py | 5 ++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.coveragerc b/.coveragerc index 56ef8a32..7648cf0d 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,3 +2,4 @@ exclude_lines = pragma: no cover if TYPE_CHECKING: + if sys.version_info diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 263043aa..e3bb987f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -97,8 +97,8 @@ "ZeroconfServiceTypes", ] -if sys.version_info <= (3, 6): - raise ImportError( +if sys.version_info <= (3, 6): # pragma: no cover + raise ImportError( # pragma: no cover ''' Python version > 3.6 required for python-zeroconf. If you need support for Python 2 or Python 3.3-3.4 please use version 19.1 diff --git a/zeroconf/const.py b/zeroconf/const.py index 8107a2af..0f26d80a 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -20,6 +20,7 @@ USA """ +import contextlib import re import socket @@ -39,10 +40,8 @@ _MDNS_ADDR = '224.0.0.251' _MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) _MDNS_ADDR6 = 'ff02::fb' -try: +with contextlib.suppress(OSError): # can't use AF_INET6, IPv6 is disabled _MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) -except OSError: # can't use AF_INET6, IPv6 is disabled - pass _MDNS_PORT = 5353 _DNS_PORT = 53 _DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 From 57cccc4dcbdc9df52672297968ccb55054122049 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 16:20:34 -1000 Subject: [PATCH 479/608] Make a dispatch dict for ServiceStateChange listeners (#859) --- zeroconf/_services/browser.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index bc368edb..5f2dbd31 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -58,6 +58,12 @@ # https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 _FIRST_QUERY_DELAY_RANDOM_INTERVAL = (20, 120) # ms +_ON_CHANGE_DISPATCH = { + ServiceStateChange.Added: "add_service", + ServiceStateChange.Removed: "remove_service", + ServiceStateChange.Updated: "update_service", +} + if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 from .._core import Zeroconf # pylint: disable=cyclic-import @@ -159,25 +165,18 @@ def generate_service_query( def _service_state_changed_from_listener(listener: ServiceListener) -> Callable[..., None]: """Generate a service_state_changed handlers from a listener.""" + assert listener is not None + if not hasattr(listener, 'update_service'): + warnings.warn( + "%r has no update_service method. Provide one (it can be empty if you " + "don't care about the updates), it'll become mandatory." % (listener,), + FutureWarning, + ) def on_change( zeroconf: 'Zeroconf', service_type: str, name: str, state_change: ServiceStateChange ) -> None: - assert listener is not None - args = (zeroconf, service_type, name) - if state_change is ServiceStateChange.Added: - listener.add_service(*args) - elif state_change is ServiceStateChange.Removed: - listener.remove_service(*args) - elif state_change is ServiceStateChange.Updated: - if hasattr(listener, 'update_service'): - listener.update_service(*args) - else: - warnings.warn( - "%r has no update_service method. Provide one (it can be empty if you " - "don't care about the updates), it'll become mandatory." % (listener,), - FutureWarning, - ) + getattr(listener, _ON_CHANGE_DISPATCH[state_change])(zeroconf, service_type, name) return on_change From af83c766c2ae72bd23184c6f6300e4d620c7b3e8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 16:34:02 -1000 Subject: [PATCH 480/608] Add unit coverage for shutdown_loop (#860) --- tests/utils/test_aio.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index 52a23dea..fd33234f 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -6,6 +6,8 @@ import asyncio import contextlib +import threading +import time from unittest.mock import patch import pytest @@ -56,3 +58,28 @@ async def _async_wait_or_timeout(): task.cancel() with contextlib.suppress(asyncio.CancelledError): await task + + +def test_shutdown_loop() -> None: + """Test shutting down an event loop.""" + loop = None + loop_thread_ready = threading.Event() + + def _run_loop() -> None: + nonlocal loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop_thread_ready.set() + loop.run_forever() + + loop_thread = threading.Thread(target=_run_loop, daemon=True) + loop_thread.start() + loop_thread_ready.wait() + + aioutils.shutdown_loop(loop) + for _ in range(5): + if not loop.is_running(): + break + time.sleep(0.05) + + assert loop.is_running() is False From f5368692d7907e440ca81f0acee9744f79dbae80 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 16:50:35 -1000 Subject: [PATCH 481/608] Remove unreachable code in AsyncListener.datagram_received (#863) --- zeroconf/_core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a70d43e7..53510573 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -211,13 +211,11 @@ def datagram_received( # https://github.com/python/mypy/issues/1178 addr, port = addrs # type: ignore scope = None - elif len(addrs) == 4: + else: # https://github.com/python/mypy/issues/1178 addr, port, flow, scope = addrs # type: ignore log.debug('IPv6 scope_id %d associated to the receiving interface', scope) v6_flow_scope = (flow, scope) - else: - return now = current_time_millis() if self.suppress_duplicate_packet(data, now): From c516919064687551299f23e23bf0797888020041 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 16:50:44 -1000 Subject: [PATCH 482/608] Ensure protocol and sending errors are logged once (#862) --- zeroconf/_core.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 53510573..b5e971b4 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -22,7 +22,6 @@ import asyncio import contextlib -import errno import itertools import random import socket @@ -313,6 +312,10 @@ def _respond_query( def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" + assert self.transport is not None + self.log_warning_once( + 'Error with socket %d: %s', self.transport.get_extra_info('socket').fileno(), exc + ) def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) @@ -714,24 +717,13 @@ def async_send( if self._GLOBAL_DONE: return s = transport.get_extra_info('socket') - try: - if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR - elif not can_send_to(s, addr): - continue - else: - real_addr = addr - transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) - except OSError as exc: - if exc.errno == errno.ENETUNREACH and s.family == socket.AF_INET6: - # with IPv6 we don't have a reliable way to determine if an interface actually has - # IPV6 support, so we have to try and ignore errors. - continue - # on send errors, log the exception and keep going - self.log_exception_warning('Error sending through socket %d', s.fileno()) - except Exception: # pylint: disable=broad-except # TODO stop catching all Exceptions - # on send errors, log the exception and keep going - self.log_exception_warning('Error sending through socket %d', s.fileno()) + if addr is None: + real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR + elif not can_send_to(s, addr): + continue + else: + real_addr = addr + transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) def _close(self) -> None: """Set global done and remove all service listeners.""" From c64064ad3b38a40775637c0fd8877d9d00d2d537 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 17:16:39 -1000 Subject: [PATCH 483/608] Update changelog (#864) --- README.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.rst b/README.rst index 47e3c584..5342489f 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,15 @@ See examples directory for more. Changelog ========= +0.32.0 Release Candidate 5 +========================== + +* Ensure protocol and sending errors are logged once (#862) @bdraco + +* Remove unreachable code in AsyncListener.datagram_received (#863) @bdraco + +* Make a dispatch dict for ServiceStateChange listeners (#859) @bdraco + 0.32.0 Release Candidate 4 ========================== From 6ef65fc7cafc3d4089a2b943da224c6cb027b4b0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 21:05:29 -1000 Subject: [PATCH 484/608] Add test coverage for duplicate properties in a TXT record (#865) --- tests/services/test_info.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 37f98aa1..6a5ae428 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -441,6 +441,41 @@ def get_service_info_helper(zc, type, name): zc.remove_all_service_listeners() zc.close() + def test_service_info_duplicate_properties_txt_records(self): + """Verify the first property is always used when there are duplicates in a txt record.""" + + zc = r.Zeroconf(interfaces=['127.0.0.1']) + desc = {'path': '/~paulsm/'} + service_name = 'name._type._tcp.local.' + service_type = '_type._tcp.local.' + service_server = 'ash-1.local.' + service_address = socket.inet_aton("10.0.1.2") + ttl = 120 + now = r.current_time_millis() + info = ServiceInfo( + service_type, service_name, 22, 0, 0, desc, service_server, addresses=[service_address] + ) + info.async_update_records( + zc, + now, + [ + r.RecordUpdate( + r.DNSText( + service_name, + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + ttl, + b'\x04ff=0\x04ci=2\x04sf=0\x0bsh=6fLM5A==\x04dd=0\x04jl=2\x04qq=0\x0brr=6fLM5A==\x04ci=3', + ), + None, + ) + ], + ) + assert info.properties[b"dd"] == b"0" + assert info.properties[b"jl"] == b"2" + assert info.properties[b"ci"] == b"2" + zc.close() + def test_multiple_addresses(): type_ = "_http._tcp.local." From dcf18c8a32652c6aa70af180b6a5261f4277faa9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 21:20:32 -1000 Subject: [PATCH 485/608] Add test coverage to ensure ServiceBrowser ignores unrelated updates (#866) --- tests/test_aio.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index f22bf966..9da099fa 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -15,7 +15,16 @@ import pytest from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes -from zeroconf import DNSIncoming, ServiceStateChange, Zeroconf, const +from zeroconf import ( + DNSIncoming, + DNSOutgoing, + DNSPointer, + DNSService, + DNSAddress, + ServiceStateChange, + Zeroconf, + const, +) from zeroconf.const import _LISTENER_TIME from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf._services import ServiceListener @@ -823,3 +832,84 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): assert second_outgoing.questions[0].unicast == False finally: await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_service_browser_ignores_unrelated_updates(): + """Test that the ServiceBrowser ignores unrelated updates.""" + + # instantiate a zeroconf instance + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + type_ = "_veryuniqueone._tcp.local." + registration_name = "xxxyyy.%s" % type_ + callbacks = [] + + class MyServiceListener(ServiceListener): + def add_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("add", type_, name)) + + def remove_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("remove", type_, name)) + + def update_service(self, zc, type_, name) -> None: + nonlocal callbacks + if name == registration_name: + callbacks.append(("update", type_, name)) + + listener = MyServiceListener() + + desc = {'path': '/~paulsm/'} + address_parsed = "10.0.1.2" + address = socket.inet_aton(address_parsed) + info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) + zc.cache.async_add_records( + [info.dns_pointer(), info.dns_service(), *info.dns_addresses(), info.dns_text()] + ) + + browser = AsyncServiceBrowser(zc, type_, None, listener) + + generated = DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time( + DNSPointer( + "_unrelated._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + const._DNS_OTHER_TTL, + "zoom._unrelated._tcp.local.", + ), + 0, + ) + generated.add_answer_at_time( + DNSAddress( + "zoom._unrelated._tcp.local.", const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b"1234" + ), + 0, + ) + generated.add_answer_at_time( + DNSService( + "zoom._unrelated._tcp.local.", + const._TYPE_SRV, + const._CLASS_IN, + const._DNS_HOST_TTL, + 0, + 0, + 81, + 'unrelated.local.', + ), + 0, + ) + + zc.handle_response(DNSIncoming(generated.packets()[0])) + + await browser.async_cancel() + await asyncio.sleep(0) + + assert callbacks == [ + ('add', type_, registration_name), + ] + await aiozc.async_close() From 22ff6b56d7b6531d2af5c50dca66fd2be2b276f4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 21:47:17 -1000 Subject: [PATCH 486/608] Break apart new_socket to be testable (#867) --- tests/utils/test_net.py | 40 +++++++++++++++++++++++++++++++++ zeroconf/_utils/net.py | 49 +++++++++++++++++++++++++---------------- 2 files changed, 70 insertions(+), 19 deletions(-) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 7890f381..16c2b485 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -8,6 +8,7 @@ import errno import ifaddr import pytest +import socket import unittest from zeroconf._utils import net as netutils @@ -99,3 +100,42 @@ def test_autodetect_ip_version(): assert r.autodetect_ip_version([]) is r.IPVersion.V4Only assert r.autodetect_ip_version(["::1", "1.2.3.4"]) is r.IPVersion.All assert r.autodetect_ip_version(["::1"]) is r.IPVersion.V6Only + + +def test_disable_ipv6_only_or_raise(): + """Test that IPV6_V6ONLY failing logs a nice error message and still raises.""" + errors_logged = [] + + def _log_error(*args): + nonlocal errors_logged + errors_logged.append(args) + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with pytest.raises(OSError), patch.object(netutils.log, "error", _log_error), patch( + "socket.socket.setsockopt", side_effect=OSError + ): + netutils.disable_ipv6_only_or_raise(sock) + + assert ( + errors_logged[0][0] + == 'Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6' + ) + + +@pytest.mark.skipif(not hasattr(socket, 'SO_REUSEPORT'), reason="System does not have SO_REUSEPORT") +def test_set_so_reuseport_if_available_is_present(): + """Test that setting socket.SO_REUSEPORT only OSError errno.ENOPROTOOPT is trapped.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError): + netutils.set_so_reuseport_if_available(sock) + + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)): + netutils.set_so_reuseport_if_available(sock) + + +@pytest.mark.skipif(hasattr(socket, 'SO_REUSEPORT'), reason="System has SO_REUSEPORT") +def test_set_so_reuseport_if_available_not_present(): + """Test that we do not try to set SO_REUSEPORT if it is not present.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + with patch("socket.socket.setsockopt", side_effect=OSError): + netutils.set_so_reuseport_if_available(sock) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 80a4377b..b30c828a 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -161,6 +161,34 @@ def normalize_interface_choice( return result +def disable_ipv6_only_or_raise(s: socket.socket) -> None: + """Make V6 sockets work for both V4 and V6 (required for Windows).""" + try: + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + except OSError: + log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') + raise + + +def set_so_reuseport_if_available(s: socket.socket) -> None: + """Set SO_REUSEADDR on a socket if available.""" + # SO_REUSEADDR should be equivalent to SO_REUSEPORT for + # multicast UDP sockets (p 731, "TCP/IP Illustrated, + # Volume 2"), but some BSD-derived systems require + # SO_REUSEPORT to be specified explicitly. Also, not all + # versions of Python have SO_REUSEPORT available. + # Catch OSError and socket.error for kernel versions <3.9 because lacking + # SO_REUSEPORT support. + if not hasattr(socket, 'SO_REUSEPORT'): + return + + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member + except OSError as err: + if err.errno != errno.ENOPROTOOPT: + raise + + def new_socket( # pylint: disable=too-many-branches bind_addr: Union[Tuple[str], Tuple[str, int, int]], port: int = _MDNS_PORT, @@ -180,28 +208,11 @@ def new_socket( # pylint: disable=too-many-branches s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) if ip_version == IPVersion.All: - # make V6 sockets work for both V4 and V6 (required for Windows) - try: - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_V6ONLY, False) - except OSError: - log.error('Support for dual V4-V6 sockets is not present, use IPVersion.V4 or IPVersion.V6') - raise + disable_ipv6_only_or_raise(s) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # SO_REUSEADDR should be equivalent to SO_REUSEPORT for - # multicast UDP sockets (p 731, "TCP/IP Illustrated, - # Volume 2"), but some BSD-derived systems require - # SO_REUSEPORT to be specified explicitly. Also, not all - # versions of Python have SO_REUSEPORT available. - # Catch OSError and socket.error for kernel versions <3.9 because lacking - # SO_REUSEPORT support. - if hasattr(socket, 'SO_REUSEPORT'): - try: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) # pylint: disable=no-member - except OSError as err: - if err.errno != errno.ENOPROTOOPT: - raise + set_so_reuseport_if_available(s) if port == _MDNS_PORT: ttl = struct.pack(b'B', 255) From 4ed903698b10f434cfbbe601998f27c10d2fb9db Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 22:43:19 -1000 Subject: [PATCH 487/608] Fix deadlock when event loop is shutdown during service registration (#869) --- tests/test_core.py | 32 ++++++++++++++++++++++++++++++++ tests/utils/test_aio.py | 14 ++++++++++++++ zeroconf/_core.py | 9 +++++++-- zeroconf/_services/info.py | 4 ++-- zeroconf/_utils/aio.py | 17 ++++++++++++----- 5 files changed, 67 insertions(+), 9 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 85571ddd..2a3f368b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,6 +12,7 @@ import socket import sys import time +import threading import unittest import unittest.mock from typing import cast @@ -715,3 +716,34 @@ def test_guard_against_duplicate_packets(): assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False zc.close() + + +def test_shutdown_while_register_in_process(): + """Test we can shutdown while registering a service in another thread.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_homeassistant._tcp.local." + name = "MyTestHome" + info_service = r.ServiceInfo( + type_, + '%s.%s' % (name, type_), + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + def _background_register(): + zc.register_service(info_service) + + bgthread = threading.Thread(target=_background_register, daemon=True) + bgthread.start() + time.sleep(0.3) + + zc.close() + bgthread.join() diff --git a/tests/utils/test_aio.py b/tests/utils/test_aio.py index fd33234f..524fd973 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_aio.py @@ -64,6 +64,7 @@ def test_shutdown_loop() -> None: """Test shutting down an event loop.""" loop = None loop_thread_ready = threading.Event() + runcoro_thread_ready = threading.Event() def _run_loop() -> None: nonlocal loop @@ -76,6 +77,18 @@ def _run_loop() -> None: loop_thread.start() loop_thread_ready.wait() + async def _still_running(): + await asyncio.sleep(5) + + def _run_coro() -> None: + runcoro_thread_ready.set() + asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1) + + runcoro_thread = threading.Thread(target=_run_coro, daemon=True) + runcoro_thread.start() + runcoro_thread_ready.wait() + + time.sleep(0.1) aioutils.shutdown_loop(loop) for _ in range(5): if not loop.is_running(): @@ -83,3 +96,4 @@ def _run_loop() -> None: time.sleep(0.05) assert loop.is_running() is False + runcoro_thread.join() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index b5e971b4..ef4d4d70 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -21,6 +21,7 @@ """ import asyncio +import concurrent.futures import contextlib import itertools import random @@ -71,6 +72,7 @@ ) _TC_DELAY_RANDOM_INTERVAL = (400, 500) +_CLOSE_TIMEOUT = 3 class AsyncEngine: @@ -170,7 +172,7 @@ def close(self) -> None: return if not self.loop.is_running(): return - asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result() + asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result(_CLOSE_TIMEOUT) class AsyncListener(asyncio.Protocol, QuietLogger): @@ -416,7 +418,10 @@ def listeners(self) -> List[RecordUpdateListener]: def wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" assert self.loop is not None - asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result() + with contextlib.suppress(concurrent.futures.TimeoutError): + asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result( + millis_to_seconds(timeout) + ) async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 9d1c37f3..52dabd2b 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -37,7 +37,7 @@ _is_v6_address, ) from .._utils.struct import int2byte -from .._utils.time import current_time_millis +from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( _CLASS_IN, _CLASS_UNIQUE, @@ -427,7 +427,7 @@ def request( raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") return asyncio.run_coroutine_threadsafe( self.async_request(zc, timeout, question_type), zc.loop - ).result() + ).result(millis_to_seconds(timeout) + 1) async def async_request( self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/aio.py index 7cc3b7fa..57c1fb18 100644 --- a/zeroconf/_utils/aio.py +++ b/zeroconf/_utils/aio.py @@ -25,6 +25,10 @@ import queue from typing import Any, List, Optional, Set, cast +_TASK_AWAIT_TIMEOUT = 1 +_GET_ALL_TASKS_TIMEOUT = 1 +_WAIT_FOR_LOOP_TASKS_TIMEOUT = 2 # Must be larger than _TASK_AWAIT_TIMEOUT + def get_best_available_queue() -> queue.Queue: """Create the best available queue type.""" @@ -73,16 +77,19 @@ async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio. async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: """Wait for the event loop thread we started to shutdown.""" - await asyncio.wait(wait_tasks, timeout=1) + await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" - pending_tasks = set(asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result()) - done_tasks = set(task for task in pending_tasks if not task.done()) - pending_tasks -= done_tasks + pending_tasks = set( + asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT) + ) + pending_tasks -= set(task for task in pending_tasks if task.done()) if pending_tasks: - asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result() + asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result( + _WAIT_FOR_LOOP_TASKS_TIMEOUT + ) loop.call_soon_threadsafe(loop.stop) From 972da99e4dd9d0fe1c1e0786da45d66fd43a717a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 27 Jun 2021 22:46:04 -1000 Subject: [PATCH 488/608] Update changelog (#870) --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index 5342489f..3943d9e2 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,13 @@ See examples directory for more. Changelog ========= +0.32.0 Release Candidate 6 +========================== + +* Fix deadlock when event loop is shutdown during service registration (#869) @bdraco + +* Break apart new_socket to be testable (#867) @bdraco + 0.32.0 Release Candidate 5 ========================== From 471bacd3200aa1216054c0e52b2e5842e9760aa0 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 15:49:08 -1000 Subject: [PATCH 489/608] Add coverage to ensure unrelated A records do not generate ServiceBrowser callbacks (#874) closes #871 --- tests/test_aio.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/test_aio.py b/tests/test_aio.py index 9da099fa..fb3f07ea 100644 --- a/tests/test_aio.py +++ b/tests/test_aio.py @@ -21,6 +21,7 @@ DNSPointer, DNSService, DNSAddress, + DNSText, ServiceStateChange, Zeroconf, const, @@ -868,7 +869,22 @@ def update_service(self, zc, type_, name) -> None: address = socket.inet_aton(address_parsed) info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]) zc.cache.async_add_records( - [info.dns_pointer(), info.dns_service(), *info.dns_addresses(), info.dns_text()] + [ + info.dns_pointer(), + info.dns_service(), + *info.dns_addresses(), + info.dns_text(), + DNSService( + "zoom._unrelated._tcp.local.", + const._TYPE_SRV, + const._CLASS_IN, + const._DNS_HOST_TTL, + 0, + 0, + 81, + 'unrelated.local.', + ), + ] ) browser = AsyncServiceBrowser(zc, type_, None, listener) @@ -885,21 +901,16 @@ def update_service(self, zc, type_, name) -> None: 0, ) generated.add_answer_at_time( - DNSAddress( - "zoom._unrelated._tcp.local.", const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b"1234" - ), + DNSAddress("unrelated.local.", const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b"1234"), 0, ) generated.add_answer_at_time( - DNSService( + DNSText( "zoom._unrelated._tcp.local.", - const._TYPE_SRV, - const._CLASS_IN, - const._DNS_HOST_TTL, - 0, - 0, - 81, - 'unrelated.local.', + const._TYPE_TXT, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + b"zoom", ), 0, ) From decd8a26aa8a89ceefcd9452fe562f2eeaa3fecb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 16:14:33 -1000 Subject: [PATCH 490/608] Fix flapping test test_integration_with_listener_class (#876) --- tests/test_services.py | 194 +++++++++++++++++++++-------------------- 1 file changed, 100 insertions(+), 94 deletions(-) diff --git a/tests/test_services.py b/tests/test_services.py index 12ad95ba..b1e2d890 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -10,6 +10,7 @@ import os import unittest from threading import Event +from unittest.mock import patch import pytest @@ -95,100 +96,105 @@ def update_service(self, zeroconf, type, name): ) zeroconf_registrar.register_service(info_service) - try: - service_added.wait(1) - assert service_added.is_set() - - # short pause to allow multicast timers to expire - time.sleep(3) - - # clear the answer cache to force query - _clear_cache(zeroconf_browser) - - cached_info = ServiceInfo(type_, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties == {} - - # get service info without answer cache - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_none'] is None - assert info.properties[b'prop_string'] == properties['prop_string'] - assert info.properties[b'prop_float'] == b'1.0' - assert info.properties[b'prop_blank'] == properties['prop_blank'] - assert info.properties[b'prop_true'] == b'1' - assert info.properties[b'prop_false'] == b'0' - assert info.addresses == addresses[:1] # no V6 by default - assert set(info.addresses_by_version(r.IPVersion.All)) == set(addresses) - - cached_info = ServiceInfo(type_, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - - # Populate the cache - zeroconf_browser.get_service_info(subtype, registration_name) - - # get service info with only the cache - cached_info = ServiceInfo(subtype, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_float'] == b'1.0' - - # get service info with only the cache with the lowercase name - cached_info = ServiceInfo(subtype, registration_name.lower()) - cached_info.load_from_cache(zeroconf_browser) - # Ensure uppercase output is preserved - assert cached_info.name == registration_name - assert cached_info.key == registration_name.lower() - assert cached_info.properties is not None - assert cached_info.properties[b'prop_float'] == b'1.0' - - info = zeroconf_browser.get_service_info(subtype, registration_name) - assert info is not None - assert info.properties is not None - assert info.properties[b'prop_none'] is None - - cached_info = ServiceInfo(subtype, registration_name.lower()) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_none'] is None - - # test TXT record update - sublistener = MySubListener() - zeroconf_browser.add_service_listener(registration_name, sublistener) - properties['prop_blank'] = b'an updated string' - desc.update(properties) - info_service = ServiceInfo( - subtype, - registration_name, - 80, - 0, - 0, - desc, - "ash-2.local.", - addresses=[socket.inet_aton("10.0.1.2")], - ) - zeroconf_registrar.update_service(info_service) - service_updated.wait(1) - assert service_updated.is_set() - - info = zeroconf_browser.get_service_info(type_, registration_name) - assert info is not None - assert info.properties[b'prop_blank'] == properties['prop_blank'] - - cached_info = ServiceInfo(subtype, registration_name) - cached_info.load_from_cache(zeroconf_browser) - assert cached_info.properties is not None - assert cached_info.properties[b'prop_blank'] == properties['prop_blank'] - - zeroconf_registrar.unregister_service(info_service) - service_removed.wait(1) - assert service_removed.is_set() - - finally: - zeroconf_registrar.close() - zeroconf_browser.remove_service_listener(listener) - zeroconf_browser.close() + with patch.object( + zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False + ), patch.object( + zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False + ): + try: + service_added.wait(1) + assert service_added.is_set() + + # short pause to allow multicast timers to expire + time.sleep(3) + + # clear the answer cache to force query + _clear_cache(zeroconf_browser) + + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties == {} + + # get service info without answer cache + info = zeroconf_browser.get_service_info(type_, registration_name) + assert info is not None + assert info.properties[b'prop_none'] is None + assert info.properties[b'prop_string'] == properties['prop_string'] + assert info.properties[b'prop_float'] == b'1.0' + assert info.properties[b'prop_blank'] == properties['prop_blank'] + assert info.properties[b'prop_true'] == b'1' + assert info.properties[b'prop_false'] == b'0' + assert info.addresses == addresses[:1] # no V6 by default + assert set(info.addresses_by_version(r.IPVersion.All)) == set(addresses) + + cached_info = ServiceInfo(type_, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + + # Populate the cache + zeroconf_browser.get_service_info(subtype, registration_name) + + # get service info with only the cache + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + + # get service info with only the cache with the lowercase name + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + # Ensure uppercase output is preserved + assert cached_info.name == registration_name + assert cached_info.key == registration_name.lower() + assert cached_info.properties is not None + assert cached_info.properties[b'prop_float'] == b'1.0' + + info = zeroconf_browser.get_service_info(subtype, registration_name) + assert info is not None + assert info.properties is not None + assert info.properties[b'prop_none'] is None + + cached_info = ServiceInfo(subtype, registration_name.lower()) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_none'] is None + + # test TXT record update + sublistener = MySubListener() + zeroconf_browser.add_service_listener(registration_name, sublistener) + properties['prop_blank'] = b'an updated string' + desc.update(properties) + info_service = ServiceInfo( + subtype, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + zeroconf_registrar.update_service(info_service) + service_updated.wait(1) + assert service_updated.is_set() + + info = zeroconf_browser.get_service_info(type_, registration_name) + assert info is not None + assert info.properties[b'prop_blank'] == properties['prop_blank'] + + cached_info = ServiceInfo(subtype, registration_name) + cached_info.load_from_cache(zeroconf_browser) + assert cached_info.properties is not None + assert cached_info.properties[b'prop_blank'] == properties['prop_blank'] + + zeroconf_registrar.unregister_service(info_service) + service_removed.wait(1) + assert service_removed.is_set() + + finally: + zeroconf_registrar.close() + zeroconf_browser.remove_service_listener(listener) + zeroconf_browser.close() def test_servicelisteners_raise_not_implemented(): From f0770fea80b00f2340815fa983968f68a15c702e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 16:19:11 -1000 Subject: [PATCH 491/608] Break apart net_socket for easier testing (#875) --- tests/utils/test_net.py | 17 ++++++++++++++++ zeroconf/_utils/net.py | 45 +++++++++++++++++++++++------------------ 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 16c2b485..7c445b47 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -139,3 +139,20 @@ def test_set_so_reuseport_if_available_not_present(): sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) with patch("socket.socket.setsockopt", side_effect=OSError): netutils.set_so_reuseport_if_available(sock) + + +def test_set_mdns_port_socket_options_for_ip_version(): + """Test OSError with errno with EINVAL and bind address '' from setsockopt IP_MULTICAST_TTL does not raise.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + # Should raise on EPERM always + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only) + + # Should raise on EINVAL always when bind address is not '' + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('127.0.0.1',), r.IPVersion.V4Only) + + # Should not raise on EINVAL when bind address is '' + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index b30c828a..937dc116 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -189,6 +189,28 @@ def set_so_reuseport_if_available(s: socket.socket) -> None: raise +def set_mdns_port_socket_options_for_ip_version( + s: socket.socket, bind_addr: Union[Tuple[str], Tuple[str, int, int]], ip_version: IPVersion +) -> None: + """Set ttl/hops and loop for mdns port.""" + if ip_version != IPVersion.V6Only: + ttl = struct.pack(b'B', 255) + loop = struct.pack(b'B', 1) + # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and + # IP_MULTICAST_LOOP socket options as an unsigned char. + try: + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) + s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) + except socket.error as e: + if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS + raise + + if ip_version != IPVersion.V4Only: + # However, char doesn't work here (at least on Linux) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) + s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) + + def new_socket( # pylint: disable=too-many-branches bind_addr: Union[Tuple[str], Tuple[str, int, int]], port: int = _MDNS_PORT, @@ -202,34 +224,17 @@ def new_socket( # pylint: disable=too-many-branches apple_p2p, bind_addr, ) - if ip_version == IPVersion.V4Only: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - else: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + socket_family = socket.AF_INET if ip_version == IPVersion.V4Only else socket.AF_INET6 + s = socket.socket(socket_family, socket.SOCK_DGRAM) if ip_version == IPVersion.All: disable_ipv6_only_or_raise(s) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - set_so_reuseport_if_available(s) if port == _MDNS_PORT: - ttl = struct.pack(b'B', 255) - loop = struct.pack(b'B', 1) - if ip_version != IPVersion.V6Only: - # OpenBSD needs the ttl and loop values for the IP_MULTICAST_TTL and - # IP_MULTICAST_LOOP socket options as an unsigned char. - try: - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl) - s.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, loop) - except socket.error as e: - if bind_addr[0] != '' or get_errno(e) != errno.EINVAL: # Fails to set on MacOS - raise - if ip_version != IPVersion.V4Only: - # However, char doesn't work here (at least on Linux) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 255) - s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) + set_mdns_port_socket_options_for_ip_version(s, bind_addr, ip_version) if apple_p2p: # SO_RECV_ANYIF = 0x1104 From ab83819ad6b6ff727a894271dde3e4be6c28cb2c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 16:28:08 -1000 Subject: [PATCH 492/608] Add coverge for disconnected adapters in add_multicast_member (#877) --- tests/utils/test_net.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 7c445b47..399bd6ac 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -156,3 +156,32 @@ def test_set_mdns_port_socket_options_for_ip_version(): # Should not raise on EINVAL when bind address is '' with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): netutils.set_mdns_port_socket_options_for_ip_version(sock, ('',), r.IPVersion.V4Only) + + +def test_add_multicast_member(): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + interface = '127.0.0.1' + + # EPERM should always raise + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.EPERM, None)): + netutils.add_multicast_member(sock, interface) + + # EADDRINUSE should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRINUSE, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # EADDRNOTAVAIL should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EADDRNOTAVAIL, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # EINVAL should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.EINVAL, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # ENOPROTOOPT should return False + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)): + assert netutils.add_multicast_member(sock, interface) is False + + # No error should return True + with patch("socket.socket.setsockopt"): + assert netutils.add_multicast_member(sock, interface) is True From 86e2ab9db3c7bd47b6e81837d594280ced3b30f9 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 16:43:19 -1000 Subject: [PATCH 493/608] Add coverage to ensure loading zeroconf._logger does not override logging level (#878) --- tests/test_logger.py | 18 +++++++++++++++++- zeroconf/_logger.py | 9 +++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index 2c661cf9..205ce0ff 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -4,8 +4,24 @@ """Unit tests for logger.py.""" +import logging from unittest.mock import patch -from zeroconf._logger import QuietLogger +from zeroconf._logger import QuietLogger, set_logger_level_if_unset + + +def test_loading_logger(): + """Test loading logger does not change level unless it is unset.""" + log = logging.getLogger('zeroconf') + log.setLevel(logging.CRITICAL) + set_logger_level_if_unset() + log = logging.getLogger('zeroconf') + assert log.level == logging.CRITICAL + + log = logging.getLogger('zeroconf') + log.setLevel(logging.NOTSET) + set_logger_level_if_unset() + log = logging.getLogger('zeroconf') + assert log.level == logging.WARNING def test_log_warning_once(): diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py index 3577bb05..78c21148 100644 --- a/zeroconf/_logger.py +++ b/zeroconf/_logger.py @@ -27,8 +27,13 @@ log = logging.getLogger(__name__.split('.')[0]) log.addHandler(logging.NullHandler()) -if log.level == logging.NOTSET: - log.setLevel(logging.WARN) + +def set_logger_level_if_unset() -> None: + if log.level == logging.NOTSET: + log.setLevel(logging.WARN) + + +set_logger_level_if_unset() class QuietLogger: From be1d3bbe0ee12254d11e3d8b75c2faba950fabce Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 28 Jun 2021 21:29:50 -1000 Subject: [PATCH 494/608] Update changelog (#879) --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index 3943d9e2..5aec201d 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,13 @@ See examples directory for more. Changelog ========= +0.32.0 Release Candidate 7 +========================== + +This release offers 100% line and branch coverage + +* Break apart new_socket for easier testing (#875) @bdraco + 0.32.0 Release Candidate 6 ========================== From b9eae5a6f8f86bfe60446f133cad5fc33d072959 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 08:09:01 -1000 Subject: [PATCH 495/608] Revert name change of zeroconf.asyncio to zeroconf.aio (#885) - Now that `__init__.py` no longer needs to import `asyncio`, the name conflict is not a concern. Fixes #883 --- docs/api.rst | 2 +- examples/async_apple_scanner.py | 2 +- examples/async_browser.py | 2 +- examples/async_registration.py | 2 +- examples/async_service_info_request.py | 2 +- tests/services/test_browser.py | 2 +- tests/services/test_info.py | 2 +- tests/{test_aio.py => test_asyncio.py} | 6 +++--- tests/test_core.py | 2 +- tests/test_handlers.py | 2 +- tests/utils/{test_aio.py => test_asyncio.py} | 6 +++--- zeroconf/_core.py | 2 +- zeroconf/_services/browser.py | 2 +- zeroconf/_services/info.py | 2 +- zeroconf/_utils/{aio.py => asyncio.py} | 0 zeroconf/{aio.py => asyncio.py} | 0 16 files changed, 18 insertions(+), 18 deletions(-) rename tests/{test_aio.py => test_asyncio.py} (99%) rename tests/utils/{test_aio.py => test_asyncio.py} (93%) rename zeroconf/_utils/{aio.py => asyncio.py} (100%) rename zeroconf/{aio.py => asyncio.py} (100%) diff --git a/docs/api.rst b/docs/api.rst index 20c53727..1704db5a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -6,7 +6,7 @@ python-zeroconf API reference :undoc-members: :show-inheritance: -.. automodule:: zeroconf.aio +.. automodule:: zeroconf.asyncio :members: :undoc-members: :show-inheritance: diff --git a/examples/async_apple_scanner.py b/examples/async_apple_scanner.py index 573640b0..f10f6ef6 100644 --- a/examples/async_apple_scanner.py +++ b/examples/async_apple_scanner.py @@ -8,7 +8,7 @@ from typing import Any, Optional, cast from zeroconf import DNSQuestionType, IPVersion, ServiceStateChange, Zeroconf -from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf HOMESHARING_SERVICE: str = "_appletv-v2._tcp.local." DEVICE_SERVICE: str = "_touch-able._tcp.local." diff --git a/examples/async_browser.py b/examples/async_browser.py index 85192e14..f0e0851c 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -11,7 +11,7 @@ from typing import Any, Optional, cast from zeroconf import IPVersion, ServiceStateChange, Zeroconf -from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes def async_on_service_state_change( diff --git a/examples/async_registration.py b/examples/async_registration.py index 7e02ea7c..53d14ce1 100644 --- a/examples/async_registration.py +++ b/examples/async_registration.py @@ -9,7 +9,7 @@ from typing import List from zeroconf import IPVersion -from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf async def register_services(infos: List[AsyncServiceInfo]) -> None: diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py index 8ea961eb..885eb99c 100644 --- a/examples/async_service_info_request.py +++ b/examples/async_service_info_request.py @@ -13,7 +13,7 @@ from zeroconf import IPVersion, ServiceBrowser, ServiceStateChange, Zeroconf -from zeroconf.aio import AsyncServiceInfo, AsyncZeroconf +from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf HAP_TYPE = "_hap._tcp.local." diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 36f459c7..26684e09 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -22,7 +22,7 @@ from zeroconf._services import ServiceStateChange from zeroconf._services.browser import ServiceBrowser from zeroconf._services.info import ServiceInfo -from zeroconf.aio import AsyncZeroconf +from zeroconf.asyncio import AsyncZeroconf from .. import has_working_ipv6, _inject_response, _wait_for_start diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 6a5ae428..02ba581e 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -18,7 +18,7 @@ import zeroconf as r from zeroconf import DNSAddress, const from zeroconf._services.info import ServiceInfo -from zeroconf.aio import AsyncZeroconf +from zeroconf.asyncio import AsyncZeroconf from .. import has_working_ipv6, _inject_response diff --git a/tests/test_aio.py b/tests/test_asyncio.py similarity index 99% rename from tests/test_aio.py rename to tests/test_asyncio.py index fb3f07ea..759ab5b3 100644 --- a/tests/test_aio.py +++ b/tests/test_asyncio.py @@ -14,7 +14,7 @@ import pytest -from zeroconf.aio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes +from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf, AsyncZeroconfServiceTypes from zeroconf import ( DNSIncoming, DNSOutgoing, @@ -422,7 +422,7 @@ async def test_service_info_async_request() -> None: _clear_cache(aiozc.zeroconf) # Generating the race condition is almost impossible # without patching since its a TOCTOU race - with patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): + with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): await aiosinfo.async_request(aiozc.zeroconf, 3000) assert aiosinfo is not None assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] @@ -826,7 +826,7 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): with patch.object(zeroconf_info, "async_send", send): aiosinfo = AsyncServiceInfo(type_, registration_name) # Patch _is_complete so we send multiple times - with patch("zeroconf.aio.AsyncServiceInfo._is_complete", False): + with patch("zeroconf.asyncio.AsyncServiceInfo._is_complete", False): await aiosinfo.async_request(aiozc.zeroconf, 1200) try: assert first_outgoing.questions[0].unicast == True diff --git a/tests/test_core.py b/tests/test_core.py index 2a3f368b..d80514f7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,7 +20,7 @@ import zeroconf as r from zeroconf import _core, _protocol, const, Zeroconf, current_time_millis -from zeroconf.aio import AsyncZeroconf +from zeroconf.asyncio import AsyncZeroconf from . import has_working_ipv6, _clear_cache, _inject_response, _wait_for_start diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 8fe2c56d..bab50a55 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -17,7 +17,7 @@ from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const from zeroconf._dns import DNSRRSet -from zeroconf.aio import AsyncZeroconf +from zeroconf.asyncio import AsyncZeroconf from . import _clear_cache, _inject_response diff --git a/tests/utils/test_aio.py b/tests/utils/test_asyncio.py similarity index 93% rename from tests/utils/test_aio.py rename to tests/utils/test_asyncio.py index 524fd973..ccafb72f 100644 --- a/tests/utils/test_aio.py +++ b/tests/utils/test_asyncio.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- -"""Unit tests for zeroconf._utils.aio.""" +"""Unit tests for zeroconf._utils.asyncio.""" import asyncio import contextlib @@ -12,7 +12,7 @@ import pytest -from zeroconf._utils import aio as aioutils +from zeroconf._utils import asyncio as aioutils @pytest.mark.asyncio @@ -25,7 +25,7 @@ async def test_async_get_all_tasks() -> None: await aioutils._async_get_all_tasks(aioutils.get_running_loop()) if not hasattr(asyncio, 'all_tasks'): return - with patch("zeroconf._utils.aio.asyncio.all_tasks", side_effect=RuntimeError): + with patch("zeroconf._utils.asyncio.asyncio.all_tasks", side_effect=RuntimeError): await aioutils._async_get_all_tasks(aioutils.get_running_loop()) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index ef4d4d70..a20e5639 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -43,7 +43,7 @@ from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener -from ._utils.aio import get_running_loop, shutdown_loop, wait_event_or_timeout +from ._utils.asyncio import get_running_loop, shutdown_loop, wait_event_or_timeout from ._utils.name import service_type_name from ._utils.net import ( IPVersion, diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 5f2dbd31..fecf35c9 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -38,7 +38,7 @@ SignalRegistrationInterface, ) from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.aio import get_best_available_queue +from .._utils.asyncio import get_best_available_queue from .._utils.name import service_type_name from .._utils.time import current_time_millis, millis_to_seconds from ..const import ( diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 52dabd2b..3e371b17 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -29,7 +29,7 @@ from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.aio import get_running_loop +from .._utils.asyncio import get_running_loop from .._utils.name import service_type_name from .._utils.net import ( IPVersion, diff --git a/zeroconf/_utils/aio.py b/zeroconf/_utils/asyncio.py similarity index 100% rename from zeroconf/_utils/aio.py rename to zeroconf/_utils/asyncio.py diff --git a/zeroconf/aio.py b/zeroconf/asyncio.py similarity index 100% rename from zeroconf/aio.py rename to zeroconf/asyncio.py From b9dc12dee8b4a7f6d8e1f599948bf16e5e7fab47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 08:09:21 -1000 Subject: [PATCH 496/608] Disable pylint in the CI (#886) --- .github/workflows/ci.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d4f9a35..01b181fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,11 +56,6 @@ jobs: run: | ${{ matrix.venvcmd }} make black_check - - name: Run pylint - if: ${{ runner.os == 'Linux' && matrix.python-version != 'pypy3' }} - run: | - ${{ matrix.venvcmd }} - make pylint - name: Run tests run: | ${{ matrix.venvcmd }} From 14cf9362c9ae947bcee5911b9c593ca76f50d529 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 08:29:49 -1000 Subject: [PATCH 497/608] Collapse changelog for 0.32.0 (#887) --- README.rst | 531 +---------------------------------------------------- 1 file changed, 7 insertions(+), 524 deletions(-) diff --git a/README.rst b/README.rst index 5aec201d..536ddc5a 100644 --- a/README.rst +++ b/README.rst @@ -140,31 +140,14 @@ See examples directory for more. Changelog ========= -0.32.0 Release Candidate 7 -========================== +0.32.0 (Unreleased) +=================== -This release offers 100% line and branch coverage - -* Break apart new_socket for easier testing (#875) @bdraco - -0.32.0 Release Candidate 6 -========================== - -* Fix deadlock when event loop is shutdown during service registration (#869) @bdraco - -* Break apart new_socket to be testable (#867) @bdraco - -0.32.0 Release Candidate 5 -========================== - -* Ensure protocol and sending errors are logged once (#862) @bdraco - -* Remove unreachable code in AsyncListener.datagram_received (#863) @bdraco - -* Make a dispatch dict for ServiceStateChange listeners (#859) @bdraco +Documentation for breaking changes era on the side of the caution and likely +overstates the risk on many of these. If you are not accessing zeroconf internals, +you can likely not be concerned with the breaking changes below: -0.32.0 Release Candidate 4 -========================== +This release offers 100% line and branch coverage * Make ServiceInfo first question QU (#852) @bdraco @@ -185,22 +168,6 @@ This release offers 100% line and branch coverage also asks the first question as QU since ServiceInfo is commonly called from ServiceBrowser callbacks -0.32.0 Release Candidate 3 -========================== - -* Switch ServiceBrowser query scheduling to use call_later instead of a loop (#849) @bdraco - - Simplifies scheduling as there is no more need to sleep in a loop as - we now schedule future callbacks with call_later - - Simplifies cancelation as there is no more coroutine to cancel, only a timer handle - We no longer have to handle the canceled error and cleaning up the awaitable - - Solves the infrequent test failures in test_backoff and test_integration - -0.32.0 Release Candidate 2 -========================== - * Limit duplicate packet suppression to 1s intervals (#841) @bdraco Only suppress duplicate packets that happen within the same @@ -216,18 +183,6 @@ This release offers 100% line and branch coverage multi-packet known answer supression since it was not expecting to get the same data more than once - -0.32.0 Release Candidate 1 -========================== - -No changes - -0.32.0 Beta 6 -============= - -This beta addresses two potential areas where zeroconf can be overwhelmed and -deny service to legitimate queriers. - * BREAKING CHANGE: Drop oversize packets before processing them (#826) @bdraco Oversized packets can quickly overwhelm the system and deny @@ -242,16 +197,6 @@ deny service to legitimate queriers. Apple uses a 15s minimum TTL, however we do not have the same level of rate limit and safe guards so we use 1/4 of the recommended value. -0.32.0 Beta 5 -============= - -* Only wake up the query loop when there is a change in the next query time (#818) @bdraco - - The ServiceBrowser query loop (async_browser_task) was being awoken on - every packet because it was using `zeroconf.async_wait` which wakes - up on every new packet. We only need to awaken the loop when the next time - we are going to send a query has changed. - * New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 @@ -263,22 +208,6 @@ deny service to legitimate queriers. the network, and has the secondary advantage that most responders will answer a QU question without the typical delay answering QM questions. - -0.32.0 Beta 4 -============= - -* Simplify wait_event_or_timeout (#810) @bdraco - - This function always did the same thing on timeout and - wait complete so we can use the same callback. This - solves the CI failing due to the test coverage flapping - back and forth as the timeout would rarely happen. - -* Make DNSHinfo and DNSAddress use the same match order as DNSPointer and DNSText (#808) @bdraco - - We want to check the data that is most likely to be unique first - so we can reject the __eq__ as soon as possible. - * Qualify IPv6 link-local addresses with scope_id (#343) @ibygrave When a service is advertised on an IPv6 address where @@ -290,28 +219,8 @@ deny service to legitimate queriers. return qualified addresses to avoid breaking compatibility on the existing parsed_addresses(). -0.32.0 Beta 3 -============= - * Skip network adapters that are disconnected (#327) @ZLJasonG -* Add slots to DNS classes (#803) @bdraco - - On a busy network that receives many mDNS packets per second, we - will not know the answer to most of the questions being asked. - In this case the creating the DNS* objects are usually garbage - collected within 1s as they are not needed. We now set __slots__ - to speed up the creation and destruction of these objects - -0.32.0 Beta 2 -============= - -* Ensure we handle threadsafe shutdown under PyPy with multiple event loops (#800) @bdraco - -* Ensure fresh ServiceBrowsers see old_record as None when replaying the cache (#793) @bdraco - - This is fixing ServiceBrowser missing an add when the record is already in the cache. - * Pass both the new and old records to async_update_records (#792) @bdraco Pass the old_record (cached) as the value and the new_record (wire) @@ -320,23 +229,6 @@ deny service to legitimate queriers. when generating the async_update_records call. This avoids the overhead of multiple cache lookups for each listener. -* Make add_listener and remove_listener threadsafe (#794) @bdraco - -* Ensure outgoing ServiceBrowser questions are seen by the question history (#790) @bdraco - -0.32.0 Beta 1 -============= - -Documentation for breaking changes era on the side of the caution and likely -overstates the risk on many of these. If you are not accessing zeroconf internals, -you can likely not be concerned with the breaking changes below: - -* BREAKING CHANGE: zeroconf.asyncio has been renamed zeroconf.aio (#503) @bdraco - - The asyncio name could shadow system asyncio in some cases. If - zeroconf is in sys.path, this would result in loading zeroconf.asyncio - when system asyncio was intended. - * BREAKING CHANGE: Update internal version check to match docs (3.6+) (#491) @bdraco Python version eariler then 3.6 were likely broken with zeroconf @@ -442,71 +334,6 @@ you can likely not be concerned with the breaking changes below: * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco -* Add async_apple_scanner example (#719) @bdraco - -* Add support for requesting QU questions to ServiceBrowser and ServiceInfo (#787) @bdraco - -* Ensure the queue is created before adding listeners to ServiceBrowser (#785) @bdraco - - The callback from the listener could generate an event that would - fire in async context that should have gone to the queue which - could result in the consumer running a sync call in the event loop - and blocking it. - -* Add a guard to prevent running ServiceInfo.request in async context (#784) @bdraco - -* Inline utf8 decoding when processing incoming packets (#782) @bdraco - -* Drop utf cache from _dns (#781) (later reverted) @bdraco - -* Switch to using a simple cache instead of lru_cache (#779) (later reverted) @bdraco - -* Fix Responding to Address Queries (RFC6762 section 6.2) (#777) @bdraco - -* Fix deadlock on ServiceBrowser shutdown with PyPy (#774) @bdraco - -* Add a guard against the task list changing when shutting down (#776) @bdraco - -* Improve performance of parsing DNSIncoming by caching read_utf (#769) (later reverted) @bdraco - -* Switch to using an asyncio.Event for async_wait (#759) @bdraco - - We no longer need to check for thread safety under a asyncio.Condition - as the ServiceBrowser and ServiceInfo internals schedule coroutines - in the eventloop. - -* Simplify ServiceBrowser callsbacks (#756) @bdraco - -* Revert: Fix thread safety in _ServiceBrowser.update_records_complete (#708) (#755) @bdraco - -- This guarding is no longer needed as the ServiceBrowser loop - now runs in the event loop and the thread safety guard is no - longer needed - -* Drop AsyncServiceListener (#754) @bdraco (Never shipped) - -* Run ServiceBrowser queries in the event loop (#752) @bdraco - -* Remove unused argument from AsyncZeroconf (#751) @bdraco - -* Fix warning about Zeroconf._async_notify_all not being awaited in sync shutdown (#750) @bdraco - -* Update async_service_info_request example to ensure it runs in the right event loop (#749) @bdraco - -* Run ServiceInfo requests in the event loop (#748) @bdraco - -* Remove support for notify listeners (#733) @bdraco (Never shipped) - -* Relocate service browser tests to tests/services/test_browser.py (#745) @bdraco - -* Relocate ServiceInfo to zeroconf._services.info (#741) @bdraco - -* Run question answer callbacks from add_listener in the event loop (#740) @bdraco - - Calling async_update_records and async_update_records_complete should always - happen in the event loop to ensure implementers do not need to worry about - thread safety - * Remove second level caching from ServiceBrowsers (#737) @bdraco The ServiceBrowser had its own cache of the last time it @@ -514,23 +341,6 @@ you can likely not be concerned with the breaking changes below: presenting a source of truth problem that lead to unexpected queries when the two disagreed. -* Breakout ServiceBrowser handler from listener creation (#736) @bdraco - - Add coverage for the handler from listener - -* Add fast cache lookup functions (#732) @bdraco - - The majority of our lookups happen in the event loop so there is no need - for them to be threadsafe. Now that the codebase is more clear about what - needs to be threadsafe and what does not need to be threadsafe we can use - the much faster non-threadsafe versions in the places where we are calling - from the event loop. - -* Switch to using DNSRRSet in RecordManager (#735) @bdraco - - DNSRRSet is able to do O(1) lookups of records assuming - there are no collisions. - * Fix server cache to be case-insensitive (#731) @bdraco If the server name had uppercase chars and any of the @@ -549,33 +359,12 @@ you can likely not be concerned with the breaking changes below: unique record and we never have a source of truth problem determining the TTL of a record from the cache. -* Rename handlers and internals to make it clear what is threadsafe (#726) @bdraco - - It was too easy to get confused about what was threadsafe and - what was not threadsafe which lead to unexpected failures. - Rename functions to make it clear what will be run in the event - loop and what is expected to be threadsafe - * Fix ServiceInfo with multiple A records (#725) @bdraco If there were multiple A records for the host, ServiceInfo would always return the last one that was in the incoming packet which was usually not the one that was wanted. -* Synchronize time for fate sharing (#718) @bdraco - -* Cleanup typing in zero._core and document ignores (#714) @bdraco - -* Cleanup typing in zeroconf._logger (#715) @bdraco - -* Cleanup typing in zeroconf._utils.net (#713) @bdraco - -* Cleanup typing in zeroconf._services (#711) @bdraco - -* Cleanup typing in zeroconf._services.registry (#712) @bdraco - -* Add setter for DNSQuestion to easily make a QU question (#710) @bdraco - * Set stale unique records to expire 1s in the future instead of instant removal (#706) @bdraco tools.ietf.org/html/rfc6762#section-10.2 @@ -588,44 +377,11 @@ you can likely not be concerned with the breaking changes below: cooperating responders one second to send out their own response to "rescue" the records before they expire and are deleted. -* Fix thread safety in _ServiceBrowser.update_records_complete (#708) @bdraco - -* Split DNSOutgoing/DNSIncoming/DNSMessage into zeroconf._protocol (#705) @bdraco - -* Abstract DNSOutgoing ttl write into _write_ttl (#695) @bdraco - -* Rollback data in one call instead of poping one byte at a time in DNS Outgoing (#696) @bdraco - * Suppress additionals when answer is suppressed (#690) @bdraco -* Move setting DNS created and ttl into its own function (#692) @bdraco - -* Add truncated property to DNSMessage to lookup the TC bit (#686) @bdraco - -* Check if SO_REUSEPORT exists instead of using an exception catch (#682) @bdraco - -* Use DNSRRSet for known answer suppression (#680) @bdraco - - DNSRRSet uses hash table lookups under the hood which - is much faster than the linear searches used by - DNSRecord.suppressed_by - -* Add DNSRRSet class for quick hashtable lookups of records (#678) @bdraco - - This class will be used to do fast checks to see - if records should be suppressed by a set of answers. - * Allow unregistering a service multiple times (#679) @bdraco -* Remove unreachable BadTypeInNameException check in _ServiceBrowser (#677) @bdraco - -* Update async_browser.py example to use AsyncZeroconfServiceTypes (#665) @bdraco - -* Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.aio (#658) @bdraco - -* Remove all calls to the executor in AsyncZeroconf (#653) @bdraco - -* Set __all__ in zeroconf.aio to ensure private functions do now show in the docs (#652) @bdraco +* Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco * Ensure interface_index_to_ip6_address skips ipv4 adapters (#651) @bdraco @@ -639,137 +395,15 @@ you can likely not be concerned with the breaking changes below: time. To avoid this, we now remove the services from the registry right after we generate the goodbye packet -* Use ServiceInfo.key/ServiceInfo.server_key instead of lowering in ServiceRegistry (#647) @bdraco - -* Ensure the ServiceInfo.key gets updated when the name is changed externally (#645) @bdraco - -* Ensure AsyncZeroconf.async_close can be called multiple times like Zeroconf.close (#638) @bdraco - -* Ensure eventloop shutdown is threadsafe (#636) @bdraco - -* Return early in the shutdown/close process (#632) @bdraco - -* Remove unreachable cache check for DNSAddresses (#629) @bdraco - - The ServiceBrowser would check to see if a DNSAddress was - already in the cache and return early to avoid sending - updates when the address already was held in the cache. - This check was not needed since there is already a check - a few lines before as `self.zc.cache.get(record)` which - effectively does the same thing. This lead to the check - never being covered in the tests and 2 cache lookups when - only one was needed. - -* Add test for wait_condition_or_timeout_times_out util (#630) @bdraco - -* Return early on invalid data received (#628) @bdraco - - Improve coverage for handling invalid incoming data - -* Add test to ensure ServiceBrowser sees port change as an update (#625) @bdraco - -* Fix random test failures due to monkey patching not being undone between tests (#626) @bdraco - - Switch patching to use unitest.mock.patch to ensure the patch - is reverted when the test is completed - * Ensure zeroconf can be loaded when the system disables IPv6 (#624) @bdraco -* Eliminate aio sender thread (#622) @bdraco - -* Replace select loop with asyncio loop (#504) @bdraco - -* Add is_recent property to DNSRecord (#620) @bdraco - - RFC 6762 defines recent as not multicast within one quarter of its TTL - datatracker.ietf.org/doc/html/rfc6762#section-5.4 - -* Breakout the query response handler into its own class (#615) @bdraco - -* Add the ability for ServiceInfo.dns_addresses to filter by address type (#612) @bdraco - -* Make DNSRecords hashable (#611) @bdraco - - Allows storing them in a set for de-duplication - - Needed to be able to check for duplicates to solve #604 - * Ensure the QU bit is set for probe queries (#609) @bdraco The bit should be set per datatracker.ietf.org/doc/html/rfc6762#section-8.1 -* Log destination when sending packets (#606) @bdraco - -* Fix docs version to match readme (cpython 3.6+) (#602) @bdraco - -* Add ZeroconfServiceTypes to zeroconf.__all__ (#601) @bdraco - - This class is in the readme, but is not exported by - default - -* Add id_ param to allow setting the id in the DNSOutgoing constructor (#599) @bdraco - -* Add unicast property to DNSQuestion to determine if the QU bit is set (#593) @bdraco - -* Reduce branching in DNSOutgoing.add_answer_at_time (#592) @bdraco - -* Breakout DNSCache into zeroconf.cache (#568) @bdraco - -* Removed protected imports from zeroconf namespace (#567) @bdraco - -* Fix invalid typing in ServiceInfo._set_text (#554) @bdraco - -* Move QueryHandler and RecordManager handlers into zeroconf.handlers (#551) @bdraco - -* Move ServiceListener to zeroconf.services (#550) @bdraco - -* Move the ServiceRegistry into its own module (#549) @bdraco - -* Move ServiceStateChange to zeroconf.services (#548) @bdraco - -* Relocate core functions into zeroconf.core (#547) @bdraco - -* Breakout service classes into zeroconf.services (#544) @bdraco - -* Move service_type_name to zeroconf.utils.name (#543) @bdraco - -* Relocate DNS classes to zeroconf.dns (#541) @bdraco - -* Update zeroconf.aio import locations (#539) @bdraco - -* Move int2byte to zeroconf.utils.struct (#540) @bdraco - -* Breakout network utils into zeroconf.utils.net (#537) @bdraco - -* Move time utility functions into zeroconf.utils.time (#536) @bdraco - -* Avoid making DNSOutgoing aware of the Zeroconf object (#535) @bdraco - -* Move logger into zeroconf.logger (#533) @bdraco - -* Move exceptions into zeroconf.exceptions (#532) @bdraco - -* Move constants into const.py (#531) @bdraco - -* Move asyncio utils into zeroconf.utils.aio (#530) @bdraco - -* Move ipversion auto detection code into its own function (#524) @bdraco - * Breaking change: Update python compatibility as PyPy3 7.2 is required (#523) @bdraco -* Remove broad exception catch from RecordManager.remove_listener (#517) @bdraco - -* Small cleanups to RecordManager.add_listener (#516) @bdraco - -* Move RecordUpdateListener management into RecordManager (#514) @bdraco - -* Break out record updating into RecordManager (#512) @bdraco - -* Remove uneeded wait in the Engine thread (#511) @bdraco - -* Extract code for handling queries into QueryHandler (#507) @bdraco - * Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco * Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco @@ -783,142 +417,16 @@ you can likely not be concerned with the breaking changes below: exceeds _MAX_MSG_TYPICAL datatracker.ietf.org/doc/html/rfc6762#section-17 -* Make a base class for DNSIncoming and DNSOutgoing (#497) @bdraco - -* Remove unused __ne__ code from Python 2 era (#492) @bdraco - -* Lint before testing in the CI (#488) @bdraco - -* Add AsyncServiceBrowser example (#487) @bdraco - -* Move threading daemon property into ServiceBrowser class (#486) @bdraco - -* Enable test_integration_with_listener_class test on PyPy (#485) @bdraco - -* AsyncServiceBrowser must recheck for handlers to call when holding condition (#483) - - There was a short race condition window where the AsyncServiceBrowser - could add to _handlers_to_call in the Engine thread, have the - condition notify_all called, but since the AsyncServiceBrowser was - not yet holding the condition it would not know to stop waiting - and process the handlers to call. - -* Relocate ServiceBrowser wait time calculation to seperate function (#484) @bdraco - - Eliminate the need to duplicate code between the ServiceBrowser - and AsyncServiceBrowser to calculate the wait time. - -* Switch from using an asyncio.Event to asyncio.Condition for waiting (#482) @bdraco - -* ServiceBrowser must recheck for handlers to call when holding condition (#477) @bdraco - - There was a short race condition window where the ServiceBrowser - could add to _handlers_to_call in the Engine thread, have the - condition notify_all called, but since the ServiceBrowser was - not yet holding the condition it would not know to stop waiting - and process the handlers to call. - -* Provide a helper function to convert milliseconds to seconds (#481) @bdraco - -* Fix AsyncServiceInfo.async_request not waiting long enough (#480) @bdraco - -* Add support for updating multiple records at once to ServiceInfo (#474) @bdraco - -* Narrow exception catch in DNSAddress.__repr__ to only expected exceptions (#473) @bdraco - -* Add test coverage to ensure ServiceInfo rejects expired records (#468) @bdraco - -* Reduce branching in service_type_name (#472) @bdraco - -* Fix flakey test_update_record (#470) @bdraco - -* Reduce branching in Zeroconf.handle_response (#467) @bdraco - * Ensure PTR questions asked in uppercase are answered (#465) @bdraco -* Clear cache between ServiceTypesQuery tests (#466) @bdraco - -* Break apart Zeroconf.handle_query to reduce branching (#462) @bdraco - * Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek -* Use constant for service type enumeration (#461) @bdraco - -* Reduce branching in Zeroconf.handle_response (#459) @bdraco - -* Reduce branching in Zeroconf.handle_query (#460) @bdraco - -* Enable pylint (#438) @bdraco - -* Trap OSError directly in Zeroconf.send instead of checking isinstance (#453) @bdraco - -* Disable protected-access on the ServiceBrowser usage of _handlers_lock (#452) @bdraco - -* Mark functions with too many branches in need of refactoring (#455) @bdraco - -* Disable pylint no-self-use check on abstract methods (#451) @bdraco - -* Use unique name in test_async_service_browser test (#450) @bdraco - -* Disable no-member check for WSAEINVAL false positive (#454) @bdraco - -* Mark methods used by asyncio without self use (#447) @bdraco - -* Extract _get_queue from zeroconf.asyncio._AsyncSender (#444) @bdraco - -* Fix redefining argument with the local name 'record' in ServiceInfo.update_record (#448) @bdraco - -* Remove unneeded-not in new_socket (#445) @bdraco - -* Disable broad except checks in places we still catch broad exceptions (#443) @bdraco - -* Merge _TYPE_CNAME and _TYPE_PTR comparison in DNSIncoming.read_others (#442) @bdraco - -* Convert unnecessary use of a comprehension to a list (#441) @bdraco - -* Remove unused now argument from ServiceInfo._process_record (#440) @bdraco - -* Disable pylint too-many-branches for functions that need refactoring (#439) @bdraco - -* Cleanup unused variables (#437) @bdraco - -* Cleanup unnecessary else after returns (#436) @bdraco - -* Add zeroconf.asyncio to the docs (#434) @bdraco - -* Fix warning when generating sphinx docs (#432) @bdraco - * Implement an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco -* Seperate non-thread specific code from ServiceBrowser into _ServiceBrowserBase (#428) @bdraco - -* Remove is_type_unique as it is unused (#426) - -* Avoid checking the registry when answering requests for _services._dns-sd._udp.local. (#425) @bdraco - - _services._dns-sd._udp.local. is a special case and should never - be in the registry - -* Remove unused argument from ServiceInfo.dns_addresses (#423) @bdraco - -* Add methods to generate DNSRecords from ServiceInfo (#422) @bdraco - -* Seperate logic for consuming records in ServiceInfo (#421) @bdraco - -* Seperate query generation for ServiceBrowser (#420) @bdraco - -* Add async_request example with browse (#415) @bdraco - -* Add async_register_service/async_unregister_service example (#414) @bdraco - * Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco -* Add support for registering notify listeners (#409) @bdraco - * Allow passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco -* Use a dedicated thread for sending outgoing packets with asyncio (#404) @bdraco - * Fix IPv6 setup under MacOS when binding to "" (#392) @bdraco * Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco @@ -928,17 +436,6 @@ you can likely not be concerned with the breaking changes below: the .join() was never waited for when a new Zeroconf object was created -* Simplify DNSPointer processing in ServiceBrowser (#386) @bdraco - -* Ensure the cache is checked for name conflict after final service query with asyncio (#382) @bdraco - -* Complete ServiceInfo request as soon as all questions are answered (#380) @bdraco - - Closes a small race condition where there were no questions - to ask because the cache was populated in between checks - -* Coalesce browser questions scheduled at the same time (#379) @bdraco - * Ensure duplicate packets do not trigger duplicate updates (#376) @bdraco If TXT or SRV records update was already processed and then @@ -951,12 +448,6 @@ you can likely not be concerned with the breaking changes below: * Reduce length of ServiceBrowser thread name with many types (#373) @bdraco -* Remove Callable quoting (#371) @bdraco - -* Abstract check to see if a record matches a type the ServiceBrowser wants (#369) @bdraco - -* Reduce complexity of ServiceBrowser enqueue_callback (#368) @bdraco - * Fix empty answers being added in ServiceInfo.request (#367) @bdraco * Ensure ServiceInfo populates all AAAA records (#366) @bdraco @@ -970,10 +461,6 @@ you can likely not be concerned with the breaking changes below: Move duplicate code that checked if the ServiceInfo was complete into its own function -* Remove black python 3.5 exception block (#365) @bdraco - -* Small cleanup of ServiceInfo.update_record (#364) @bdraco - * Add new cache function get_all_by_details (#363) @bdraco When working with IPv6, multiple AAAA records can exist for a given host. get_by_details would only return the @@ -982,10 +469,6 @@ you can likely not be concerned with the breaking changes below: Fix a case where the cache list can change during iteration -* Small cleanups to asyncio tests (#362) @bdraco - -* Improve test coverage for name conflicts (#357) @bdraco - * Return task objects created by AsyncZeroconf (#360) @nocarryr 0.31.0 From d31fd103cc942574f7fbc75e5346cc3d3eaf7ee1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 08:36:03 -1000 Subject: [PATCH 498/608] Remove extra newlines between changelog entries (#888) --- README.rst | 55 ------------------------------------------------------ 1 file changed, 55 deletions(-) diff --git a/README.rst b/README.rst index 536ddc5a..4b39b591 100644 --- a/README.rst +++ b/README.rst @@ -167,14 +167,12 @@ This release offers 100% line and branch coverage This change puts ServiceInfo inline with ServiceBrowser which also asks the first question as QU since ServiceInfo is commonly called from ServiceBrowser callbacks - * Limit duplicate packet suppression to 1s intervals (#841) @bdraco Only suppress duplicate packets that happen within the same second. Legitimate queriers will retry the question if they are suppressed. The limit was reduced to one second to be in line with rfc6762 - * Make multipacket known answer suppression per interface (#836) @bdraco The suppression was happening per instance of Zeroconf instead @@ -182,21 +180,18 @@ This release offers 100% line and branch coverage interfaces (usually and wifi and ethernet), this would confuse the multi-packet known answer supression since it was not expecting to get the same data more than once - * BREAKING CHANGE: Drop oversize packets before processing them (#826) @bdraco Oversized packets can quickly overwhelm the system and deny service to legitimate queriers. In practice this is usually due to broken mDNS implementations rather than malicious actors. - * BREAKING CHANGE: Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) @bdraco We now enforce a minimum TTL for PTR records to avoid ServiceBrowsers generating excessive queries refresh queries. Apple uses a 15s minimum TTL, however we do not have the same level of rate limit and safe guards so we use 1/4 of the recommended value. - * New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 @@ -207,7 +202,6 @@ This release offers 100% line and branch coverage a breaking change to increase). This reduces the amount of traffic on the network, and has the secondary advantage that most responders will answer a QU question without the typical delay answering QM questions. - * Qualify IPv6 link-local addresses with scope_id (#343) @ibygrave When a service is advertised on an IPv6 address where @@ -218,9 +212,7 @@ This release offers 100% line and branch coverage A new API `parsed_scoped_addresses()` is provided to return qualified addresses to avoid breaking compatibility on the existing parsed_addresses(). - * Skip network adapters that are disconnected (#327) @ZLJasonG - * Pass both the new and old records to async_update_records (#792) @bdraco Pass the old_record (cached) as the value and the new_record (wire) @@ -228,12 +220,10 @@ This release offers 100% line and branch coverage check the cache since we will always have the old_record when generating the async_update_records call. This avoids the overhead of multiple cache lookups for each listener. - * BREAKING CHANGE: Update internal version check to match docs (3.6+) (#491) @bdraco Python version eariler then 3.6 were likely broken with zeroconf already, however the version is now explictly checked. - * BREAKING CHANGE: RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco This allows the listener to receive all the records that have @@ -258,7 +248,6 @@ This release offers 100% line and branch coverage I/O. Before 0.32+ these functions ran in a select() loop and should not have been doing any blocking I/O, but it was not clear to implementors that I/O would block the loop. - * BREAKING CHANGE: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco When manually creating a zeroconf.Engine object, it is no longer started automatically. @@ -266,7 +255,6 @@ This release offers 100% line and branch coverage The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. - * BREAKING CHANGE: Remove DNSOutgoing.packet backwards compatibility (#569) @bdraco DNSOutgoing.packet only returned a partial message when the @@ -275,12 +263,10 @@ This release offers 100% line and branch coverage which always returns a complete payload in #248 As packet() should not be used since it will end up missing data, it has been removed - * BREAKING CHANGE: Mark DNSOutgoing write functions as protected (#633) @bdraco These functions are not intended to be used by external callers and the API is not likely to be stable in the future - * BREAKING CHANGE: Prefix cache functions that are non threadsafe with async_ (#724) @bdraco Adding (`zc.cache.add` -> `zc.cache.async_add_records`), removing (`zc.cache.remove` -> @@ -293,36 +279,26 @@ This release offers 100% line and branch coverage We never expect these functions will be called externally, however it was possible so this is documented as a breaking change. It is highly recommended that external callers do not modify the cache directly. - * TRAFFIC REDUCTION: Add support for handling QU questions (#621) @bdraco Implements RFC 6762 sec 5.4: Questions Requesting Unicast Responses datatracker.ietf.org/doc/html/rfc6762#section-5.4 - * TRAFFIC REDUCTION: Protect the network against excessive packet flooding (#619) @bdraco - * TRAFFIC REDUCTION: Suppress additionals when they are already in the answers section (#617) @bdraco - * TRAFFIC REDUCTION: Avoid including additionals when the answer is suppressed by known-answer supression (#614) @bdraco - * TRAFFIC REDUCTION: Implement multi-packet known answer supression (#687) @bdraco Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 - * TRAFFIC REDUCTION: Efficiently bucket queries with known answers (#698) @bdraco - * TRAFFIC REDUCTION: Implement duplicate question supression (#770) @bdraco http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 - * MAJOR BUG: Ensure matching PTR queries are returned with the ANY query (#618) @bdraco - * MAJOR BUG: Fix lookup of uppercase names in registry (#597) @bdraco If the ServiceInfo was registered with an uppercase name and the query was for a lowercase name, it would not be found and vice-versa. - * MAJOR BUG: Ensure unicast responses can be sent to any source port (#598) @bdraco Unicast responses were only being sent if the source port @@ -331,22 +307,18 @@ This release offers 100% line and branch coverage dig -p 5353 @224.0.0.251 media-12.local The above query will now see a response - * MAJOR BUG: Fix queries for AAAA records (#616) @bdraco - * Remove second level caching from ServiceBrowsers (#737) @bdraco The ServiceBrowser had its own cache of the last time it saw a service which was reimplementing the DNSCache and presenting a source of truth problem that lead to unexpected queries when the two disagreed. - * Fix server cache to be case-insensitive (#731) @bdraco If the server name had uppercase chars and any of the matching records were lowercase, the server would not be found - * Fix cache handling of records with different TTLs (#729) @bdraco There should only be one unique record in the cache at @@ -358,13 +330,11 @@ This release offers 100% line and branch coverage to ensure that the newest record always replaces the same unique record and we never have a source of truth problem determining the TTL of a record from the cache. - * Fix ServiceInfo with multiple A records (#725) @bdraco If there were multiple A records for the host, ServiceInfo would always return the last one that was in the incoming packet which was usually not the one that was wanted. - * Set stale unique records to expire 1s in the future instead of instant removal (#706) @bdraco tools.ietf.org/html/rfc6762#section-10.2 @@ -376,17 +346,11 @@ This release offers 100% line and branch coverage incorrectly sends goodbye packets for its records, it gives the other cooperating responders one second to send out their own response to "rescue" the records before they expire and are deleted. - * Suppress additionals when answer is suppressed (#690) @bdraco - * Allow unregistering a service multiple times (#679) @bdraco - * Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco - * Ensure interface_index_to_ip6_address skips ipv4 adapters (#651) @bdraco - * Add async_unregister_all_services to AsyncZeroconf (#649) @bdraco - * Ensure services are removed from the registry when calling unregister_all_services (#644) @bdraco There was a race condition where a query could be answered for a service @@ -394,18 +358,14 @@ This release offers 100% line and branch coverage being broadcast after the goodbye if a query came in at just the right time. To avoid this, we now remove the services from the registry right after we generate the goodbye packet - * Ensure zeroconf can be loaded when the system disables IPv6 (#624) @bdraco - * Ensure the QU bit is set for probe queries (#609) @bdraco The bit should be set per datatracker.ietf.org/doc/html/rfc6762#section-8.1 * Breaking change: Update python compatibility as PyPy3 7.2 is required (#523) @bdraco - * Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco - * Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco Ensure that questions that exceed the max packet size are @@ -416,40 +376,27 @@ This release offers 100% line and branch coverage Ensure only one resource record is sent when a record exceeds _MAX_MSG_TYPICAL datatracker.ietf.org/doc/html/rfc6762#section-17 - * Ensure PTR questions asked in uppercase are answered (#465) @bdraco - * Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek - * Implement an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco - * Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco - * Allow passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco - * Fix IPv6 setup under MacOS when binding to "" (#392) @bdraco - * Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco There was a short window where the ServiceBrowser thread could be left running after Zeroconf is closed because the .join() was never waited for when a new Zeroconf object was created - * Ensure duplicate packets do not trigger duplicate updates (#376) @bdraco If TXT or SRV records update was already processed and then recieved again, it was possible for a second update to be called back in the ServiceBrowser - * Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) @bdraco - * Fix RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco - * Reduce length of ServiceBrowser thread name with many types (#373) @bdraco - * Fix empty answers being added in ServiceInfo.request (#367) @bdraco - * Ensure ServiceInfo populates all AAAA records (#366) @bdraco Use get_all_by_details to ensure all records are loaded @@ -460,7 +407,6 @@ This release offers 100% line and branch coverage Move duplicate code that checked if the ServiceInfo was complete into its own function - * Add new cache function get_all_by_details (#363) @bdraco When working with IPv6, multiple AAAA records can exist for a given host. get_by_details would only return the @@ -468,7 +414,6 @@ This release offers 100% line and branch coverage Fix a case where the cache list can change during iteration - * Return task objects created by AsyncZeroconf (#360) @nocarryr 0.31.0 From 9abb40cf331bc0acc5fdbb03fce5c958cec8b41e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 09:13:05 -1000 Subject: [PATCH 499/608] Reformat backwards incompatible changes to match previous versions (#889) --- README.rst | 122 ++++++++++++++++++++++++----------------------------- 1 file changed, 56 insertions(+), 66 deletions(-) diff --git a/README.rst b/README.rst index 4b39b591..44a370f4 100644 --- a/README.rst +++ b/README.rst @@ -143,10 +143,6 @@ Changelog 0.32.0 (Unreleased) =================== -Documentation for breaking changes era on the side of the caution and likely -overstates the risk on many of these. If you are not accessing zeroconf internals, -you can likely not be concerned with the breaking changes below: - This release offers 100% line and branch coverage * Make ServiceInfo first question QU (#852) @bdraco @@ -180,18 +176,6 @@ This release offers 100% line and branch coverage interfaces (usually and wifi and ethernet), this would confuse the multi-packet known answer supression since it was not expecting to get the same data more than once -* BREAKING CHANGE: Drop oversize packets before processing them (#826) @bdraco - - Oversized packets can quickly overwhelm the system and deny - service to legitimate queriers. In practice this is usually - due to broken mDNS implementations rather than malicious - actors. -* BREAKING CHANGE: Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) @bdraco - - We now enforce a minimum TTL for PTR records to avoid - ServiceBrowsers generating excessive queries refresh queries. - Apple uses a 15s minimum TTL, however we do not have the same - level of rate limit and safe guards so we use 1/4 of the recommended value. * New ServiceBrowsers now request QU in the first outgoing when unspecified (#812) @bdraco https://datatracker.ietf.org/doc/html/rfc6762#section-5.4 @@ -213,61 +197,14 @@ This release offers 100% line and branch coverage return qualified addresses to avoid breaking compatibility on the existing parsed_addresses(). * Skip network adapters that are disconnected (#327) @ZLJasonG -* Pass both the new and old records to async_update_records (#792) @bdraco - - Pass the old_record (cached) as the value and the new_record (wire) - to async_update_records instead of forcing each consumer to - check the cache since we will always have the old_record - when generating the async_update_records call. This avoids - the overhead of multiple cache lookups for each listener. -* BREAKING CHANGE: Update internal version check to match docs (3.6+) (#491) @bdraco - - Python version eariler then 3.6 were likely broken with zeroconf - already, however the version is now explictly checked. -* BREAKING CHANGE: RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco - - This allows the listener to receive all the records that have - been updated in a single transaction such as a packet or - cache expiry. - - update_record has been deprecated in favor of async_update_records - A compatibility shim exists to ensure classes that use - RecordUpdateListener as a base class continue to have - update_record called, however they should be updated - as soon as possible. - - A new method async_update_records_complete is now called on each - listener when all listeners have completed processing updates - and the cache has been updated. This allows ServiceBrowsers - to delay calling handlers until they are sure the cache - has been updated as its a common pattern to call for - ServiceInfo when a ServiceBrowser handler fires. - - The async_ prefix was choosen to make it clear that these - functions run in the eventloop and should never do blocking - I/O. Before 0.32+ these functions ran in a select() loop and - should not have been doing any blocking I/O, but it was not - clear to implementors that I/O would block the loop. -* BREAKING CHANGE: Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco +* Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco When manually creating a zeroconf.Engine object, it is no longer started automatically. It must manually be started by calling .start() on the created object. The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. -* BREAKING CHANGE: Remove DNSOutgoing.packet backwards compatibility (#569) @bdraco - - DNSOutgoing.packet only returned a partial message when the - DNSOutgoing contents exceeded _MAX_MSG_ABSOLUTE or _MAX_MSG_TYPICAL - This was a legacy function that was replaced with .packets() - which always returns a complete payload in #248 As packet() - should not be used since it will end up missing data, it has - been removed -* BREAKING CHANGE: Mark DNSOutgoing write functions as protected (#633) @bdraco - - These functions are not intended to be used by external - callers and the API is not likely to be stable in the future -* BREAKING CHANGE: Prefix cache functions that are non threadsafe with async_ (#724) @bdraco +* Prefix cache functions that are non threadsafe with async_ (#724) @bdraco Adding (`zc.cache.add` -> `zc.cache.async_add_records`), removing (`zc.cache.remove` -> `zc.cache.async_remove_records`), and expiring the cache (`zc.cache.expire` -> @@ -364,7 +301,6 @@ This release offers 100% line and branch coverage The bit should be set per datatracker.ietf.org/doc/html/rfc6762#section-8.1 -* Breaking change: Update python compatibility as PyPy3 7.2 is required (#523) @bdraco * Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco * Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco @@ -416,6 +352,60 @@ This release offers 100% line and branch coverage iteration * Return task objects created by AsyncZeroconf (#360) @nocarryr +Technically backwards incompatible: + +* Update internal version check to match docs (3.6+) (#491) @bdraco + + Python version eariler then 3.6 were likely broken with zeroconf + already, however the version is now explictly checked. +* Update python compatibility as PyPy3 7.2 is required (#523) @bdraco + +Backwards incompatible: + +* Drop oversize packets before processing them (#826) @bdraco + + Oversized packets can quickly overwhelm the system and deny + service to legitimate queriers. In practice this is usually + due to broken mDNS implementations rather than malicious + actors. +* Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) @bdraco + + We now enforce a minimum TTL for PTR records to avoid + ServiceBrowsers generating excessive queries refresh queries. + Apple uses a 15s minimum TTL, however we do not have the same + level of rate limit and safe guards so we use 1/4 of the recommended value. +* RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco + + This allows the listener to receive all the records that have + been updated in a single transaction such as a packet or + cache expiry. + + update_record has been deprecated in favor of async_update_records + A compatibility shim exists to ensure classes that use + RecordUpdateListener as a base class continue to have + update_record called, however they should be updated + as soon as possible. + + A new method async_update_records_complete is now called on each + listener when all listeners have completed processing updates + and the cache has been updated. This allows ServiceBrowsers + to delay calling handlers until they are sure the cache + has been updated as its a common pattern to call for + ServiceInfo when a ServiceBrowser handler fires. + + The async_ prefix was choosen to make it clear that these + functions run in the eventloop and should never do blocking + I/O. Before 0.32+ these functions ran in a select() loop and + should not have been doing any blocking I/O, but it was not + clear to implementors that I/O would block the loop. +* Pass both the new and old records to async_update_records (#792) @bdraco + + Pass the old_record (cached) as the value and the new_record (wire) + to async_update_records instead of forcing each consumer to + check the cache since we will always have the old_record + when generating the async_update_records call. This avoids + the overhead of multiple cache lookups for each listener. + 0.31.0 ====== From 0d911568d367f1520acb19bdf830fe188b6ffb70 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 09:25:48 -1000 Subject: [PATCH 500/608] Rewrite 0.32.0 changelog in past tense (#890) --- README.rst | 111 +++++++++++++++++++++++------------------------------ 1 file changed, 47 insertions(+), 64 deletions(-) diff --git a/README.rst b/README.rst index 44a370f4..288c6358 100644 --- a/README.rst +++ b/README.rst @@ -145,7 +145,7 @@ Changelog This release offers 100% line and branch coverage -* Make ServiceInfo first question QU (#852) @bdraco +* Made ServiceInfo first question QU (#852) @bdraco We want an immediate response when making a request with ServiceInfo by asking a QU question, most responders will not delay the response @@ -163,13 +163,13 @@ This release offers 100% line and branch coverage This change puts ServiceInfo inline with ServiceBrowser which also asks the first question as QU since ServiceInfo is commonly called from ServiceBrowser callbacks -* Limit duplicate packet suppression to 1s intervals (#841) @bdraco +* Limited duplicate packet suppression to 1s intervals (#841) @bdraco Only suppress duplicate packets that happen within the same second. Legitimate queriers will retry the question if they are suppressed. The limit was reduced to one second to be in line with rfc6762 -* Make multipacket known answer suppression per interface (#836) @bdraco +* Made multipacket known answer suppression per interface (#836) @bdraco The suppression was happening per instance of Zeroconf instead of per interface. Since the same network can be seen on multiple @@ -186,7 +186,7 @@ This release offers 100% line and branch coverage a breaking change to increase). This reduces the amount of traffic on the network, and has the secondary advantage that most responders will answer a QU question without the typical delay answering QM questions. -* Qualify IPv6 link-local addresses with scope_id (#343) @ibygrave +* IPv6 link-local addresses are now qualified with scope_id (#343) @ibygrave When a service is advertised on an IPv6 address where the scope is link local, i.e. fe80::/64 (see RFC 4007) @@ -196,47 +196,35 @@ This release offers 100% line and branch coverage A new API `parsed_scoped_addresses()` is provided to return qualified addresses to avoid breaking compatibility on the existing parsed_addresses(). -* Skip network adapters that are disconnected (#327) @ZLJasonG -* Ensure listeners do not miss initial packets if Engine starts too quickly (#387) @bdraco +* Network adapters that are disconnected are now skipped (#327) @ZLJasonG +* Fixed listeners missing initial packets if Engine starts too quickly (#387) @bdraco When manually creating a zeroconf.Engine object, it is no longer started automatically. It must manually be started by calling .start() on the created object. The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. -* Prefix cache functions that are non threadsafe with async_ (#724) @bdraco - - Adding (`zc.cache.add` -> `zc.cache.async_add_records`), removing (`zc.cache.remove` -> - `zc.cache.async_remove_records`), and expiring the cache (`zc.cache.expire` -> - `zc.cache.async_expire`) the cache is not threadsafe and must be called from the - event loop (previously the Engine select loop before 0.32) - - These functions should only be run from the event loop as they are NOT thread safe. - - We never expect these functions will be called externally, however it was possible so this - is documented as a breaking change. It is highly recommended that external callers do not - modify the cache directly. -* TRAFFIC REDUCTION: Add support for handling QU questions (#621) @bdraco +* TRAFFIC REDUCTION: Added support for handling QU questions (#621) @bdraco Implements RFC 6762 sec 5.4: Questions Requesting Unicast Responses datatracker.ietf.org/doc/html/rfc6762#section-5.4 -* TRAFFIC REDUCTION: Protect the network against excessive packet flooding (#619) @bdraco -* TRAFFIC REDUCTION: Suppress additionals when they are already in the answers section (#617) @bdraco -* TRAFFIC REDUCTION: Avoid including additionals when the answer is suppressed by known-answer supression (#614) @bdraco -* TRAFFIC REDUCTION: Implement multi-packet known answer supression (#687) @bdraco +* TRAFFIC REDUCTION: Implemented protect the network against excessive packet flooding (#619) @bdraco +* TRAFFIC REDUCTION: Additionals are now suppressed when they are already in the answers section (#617) @bdraco +* TRAFFIC REDUCTION: Additionals are no longer included when the answer is suppressed by known-answer supression (#614) @bdraco +* TRAFFIC REDUCTION: Implemented multi-packet known answer supression (#687) @bdraco Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 -* TRAFFIC REDUCTION: Efficiently bucket queries with known answers (#698) @bdraco -* TRAFFIC REDUCTION: Implement duplicate question supression (#770) @bdraco +* TRAFFIC REDUCTION: Implemented efficent bucketing of queries with known answers (#698) @bdraco +* TRAFFIC REDUCTION: Implemented duplicate question supression (#770) @bdraco http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 -* MAJOR BUG: Ensure matching PTR queries are returned with the ANY query (#618) @bdraco -* MAJOR BUG: Fix lookup of uppercase names in registry (#597) @bdraco +* MAJOR BUG: Fixed answering matching PTR queries with the ANY query (#618) @bdraco +* MAJOR BUG: Fixed lookup of uppercase names in registry (#597) @bdraco If the ServiceInfo was registered with an uppercase name and the query was for a lowercase name, it would not be found and vice-versa. -* MAJOR BUG: Ensure unicast responses can be sent to any source port (#598) @bdraco +* MAJOR BUG: Fixed unicast responses from any source port (#598) @bdraco Unicast responses were only being sent if the source port was 53, this prevented responses when testing with dig: @@ -244,19 +232,19 @@ This release offers 100% line and branch coverage dig -p 5353 @224.0.0.251 media-12.local The above query will now see a response -* MAJOR BUG: Fix queries for AAAA records (#616) @bdraco -* Remove second level caching from ServiceBrowsers (#737) @bdraco +* MAJOR BUG: Fixed queries for AAAA records not being answered (#616) @bdraco +* Removed second level caching from ServiceBrowsers (#737) @bdraco The ServiceBrowser had its own cache of the last time it saw a service which was reimplementing the DNSCache and presenting a source of truth problem that lead to unexpected queries when the two disagreed. -* Fix server cache to be case-insensitive (#731) @bdraco +* Fixed server cache not being case-insensitive (#731) @bdraco If the server name had uppercase chars and any of the matching records were lowercase, the server would not be found -* Fix cache handling of records with different TTLs (#729) @bdraco +* Fixed cache handling of records with different TTLs (#729) @bdraco There should only be one unique record in the cache at a time as having multiple unique records will different @@ -267,12 +255,14 @@ This release offers 100% line and branch coverage to ensure that the newest record always replaces the same unique record and we never have a source of truth problem determining the TTL of a record from the cache. -* Fix ServiceInfo with multiple A records (#725) @bdraco +* Fixed ServiceInfo with multiple A records (#725) @bdraco If there were multiple A records for the host, ServiceInfo would always return the last one that was in the incoming packet which was usually not the one that was wanted. -* Set stale unique records to expire 1s in the future instead of instant removal (#706) @bdraco +* Fixed stale unique records expiring too quickly (#706) @bdraco + + Recods now expire 1s in the future instead of instant removal. tools.ietf.org/html/rfc6762#section-10.2 Queriers receiving a Multicast DNS response with a TTL of zero SHOULD @@ -283,26 +273,25 @@ This release offers 100% line and branch coverage incorrectly sends goodbye packets for its records, it gives the other cooperating responders one second to send out their own response to "rescue" the records before they expire and are deleted. -* Suppress additionals when answer is suppressed (#690) @bdraco -* Allow unregistering a service multiple times (#679) @bdraco -* Add an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco -* Ensure interface_index_to_ip6_address skips ipv4 adapters (#651) @bdraco -* Add async_unregister_all_services to AsyncZeroconf (#649) @bdraco -* Ensure services are removed from the registry when calling unregister_all_services (#644) @bdraco +* Fixed exception when unregistering a service multiple times (#679) @bdraco +* Added an AsyncZeroconfServiceTypes to mirror ZeroconfServiceTypes to zeroconf.asyncio (#658) @bdraco +* Fixed interface_index_to_ip6_address not skiping ipv4 adapters (#651) @bdraco +* Added async_unregister_all_services to AsyncZeroconf (#649) @bdraco +* Fixed services not being removed from the registry when calling unregister_all_services (#644) @bdraco There was a race condition where a query could be answered for a service in the registry while goodbye packets which could result a fresh record being broadcast after the goodbye if a query came in at just the right time. To avoid this, we now remove the services from the registry right after we generate the goodbye packet -* Ensure zeroconf can be loaded when the system disables IPv6 (#624) @bdraco -* Ensure the QU bit is set for probe queries (#609) @bdraco +* Fixed zeroconf exception on load when the system disables IPv6 (#624) @bdraco +* Fixed the QU bit missing from for probe queries (#609) @bdraco The bit should be set per datatracker.ietf.org/doc/html/rfc6762#section-8.1 -* Set the TC bit for query packets where the known answers span multiple packets (#494) @bdraco -* Ensure packets are properly seperated when exceeding maximum size (#498) @bdraco +* Fixed the TC bit mising for query packets where the known answers span multiple packets (#494) @bdraco +* Fixed packets not being properly seperated when exceeding maximum size (#498) @bdraco Ensure that questions that exceed the max packet size are moved to the next packet. This fixes DNSQuestions being @@ -312,28 +301,28 @@ This release offers 100% line and branch coverage Ensure only one resource record is sent when a record exceeds _MAX_MSG_TYPICAL datatracker.ietf.org/doc/html/rfc6762#section-17 -* Ensure PTR questions asked in uppercase are answered (#465) @bdraco -* Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek -* Implement an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco -* Add async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco -* Allow passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco -* Fix IPv6 setup under MacOS when binding to "" (#392) @bdraco -* Ensure ZeroconfServiceTypes.find always cancels the ServiceBrowser (#389) @bdraco +* Fixed PTR questions asked in uppercase not being answered (#465) @bdraco +* Added Support for context managers in Zeroconf and AsyncZeroconf (#284) @shenek +* Implemented an AsyncServiceBrowser to compliment the sync ServiceBrowser (#429) @bdraco +* Added async_get_service_info to AsyncZeroconf and async_request to AsyncServiceInfo (#408) @bdraco +* Implemented allowing passing in a sync Zeroconf instance to AsyncZeroconf (#406) @bdraco +* Fixed IPv6 setup under MacOS when binding to "" (#392) @bdraco +* Fixed ZeroconfServiceTypes.find not always cancels the ServiceBrowser (#389) @bdraco There was a short window where the ServiceBrowser thread could be left running after Zeroconf is closed because the .join() was never waited for when a new Zeroconf object was created -* Ensure duplicate packets do not trigger duplicate updates (#376) @bdraco +* Fixed duplicate packets triggering duplicate updates (#376) @bdraco If TXT or SRV records update was already processed and then recieved again, it was possible for a second update to be called back in the ServiceBrowser -* Only trigger a ServiceStateChange.Updated event when an ip address is added (#375) @bdraco -* Fix RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco -* Reduce length of ServiceBrowser thread name with many types (#373) @bdraco -* Fix empty answers being added in ServiceInfo.request (#367) @bdraco -* Ensure ServiceInfo populates all AAAA records (#366) @bdraco +* Fixed ServiceStateChange.Updated event happening for IPs that already existed (#375) @bdraco +* Fixed RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco +* Reduced length of ServiceBrowser thread name with many types (#373) @bdraco +* Fixed empty answers being added in ServiceInfo.request (#367) @bdraco +* Fixed ServiceInfo not populating all AAAA records (#366) @bdraco Use get_all_by_details to ensure all records are loaded into addresses. @@ -343,13 +332,7 @@ This release offers 100% line and branch coverage Move duplicate code that checked if the ServiceInfo was complete into its own function -* Add new cache function get_all_by_details (#363) @bdraco - When working with IPv6, multiple AAAA records can exist - for a given host. get_by_details would only return the - latest record in the cache. - - Fix a case where the cache list can change during - iteration +* Fixed a case where the cache list can change during iteration (#363) @bdraco * Return task objects created by AsyncZeroconf (#360) @nocarryr Technically backwards incompatible: From ba235dd8bc65de4f461f76fd2bf4647844437e1a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 09:34:03 -1000 Subject: [PATCH 501/608] Fix spelling and grammar errors in 0.32.0 changelog (#891) --- README.rst | 59 +++++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/README.rst b/README.rst index 288c6358..8f8a36a6 100644 --- a/README.rst +++ b/README.rst @@ -143,22 +143,22 @@ Changelog 0.32.0 (Unreleased) =================== -This release offers 100% line and branch coverage +This release offers 100% line and branch coverage. * Made ServiceInfo first question QU (#852) @bdraco - We want an immediate response when making a request with ServiceInfo - by asking a QU question, most responders will not delay the response + We want an immediate response when requesting with ServiceInfo + by asking a QU question; most responders will not delay the response and respond right away to our question. This also improves compatibility with split networks as we may not have been able to see the response - otherwise. If the responder has not multicast the record recently + otherwise. If the responder has not multicast the record recently, it may still choose to do so in addition to responding via unicast Reduces traffic when there are multiple zeroconf instances running on the network running ServiceBrowsers If we don't get an answer on the first try, we ask a QM question - in the event we can't receive a unicast response for some reason + in the event, we can't receive a unicast response for some reason This change puts ServiceInfo inline with ServiceBrowser which also asks the first question as QU since ServiceInfo is commonly @@ -211,16 +211,16 @@ This release offers 100% line and branch coverage datatracker.ietf.org/doc/html/rfc6762#section-5.4 * TRAFFIC REDUCTION: Implemented protect the network against excessive packet flooding (#619) @bdraco * TRAFFIC REDUCTION: Additionals are now suppressed when they are already in the answers section (#617) @bdraco -* TRAFFIC REDUCTION: Additionals are no longer included when the answer is suppressed by known-answer supression (#614) @bdraco +* TRAFFIC REDUCTION: Additionals are no longer included when the answer is suppressed by known-answer suppression (#614) @bdraco * TRAFFIC REDUCTION: Implemented multi-packet known answer supression (#687) @bdraco Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 -* TRAFFIC REDUCTION: Implemented efficent bucketing of queries with known answers (#698) @bdraco -* TRAFFIC REDUCTION: Implemented duplicate question supression (#770) @bdraco +* TRAFFIC REDUCTION: Implemented efficient bucketing of queries with known answers (#698) @bdraco +* TRAFFIC REDUCTION: Implemented duplicate question suppression (#770) @bdraco http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 * MAJOR BUG: Fixed answering matching PTR queries with the ANY query (#618) @bdraco -* MAJOR BUG: Fixed lookup of uppercase names in registry (#597) @bdraco +* MAJOR BUG: Fixed lookup of uppercase names in the registry (#597) @bdraco If the ServiceInfo was registered with an uppercase name and the query was for a lowercase name, it would not be found and vice-versa. @@ -236,13 +236,13 @@ This release offers 100% line and branch coverage * Removed second level caching from ServiceBrowsers (#737) @bdraco The ServiceBrowser had its own cache of the last time it - saw a service which was reimplementing the DNSCache and + saw a service that was reimplementing the DNSCache and presenting a source of truth problem that lead to unexpected queries when the two disagreed. * Fixed server cache not being case-insensitive (#731) @bdraco If the server name had uppercase chars and any of the - matching records were lowercase, the server would not be + matching records were lowercase, and the server would not be found * Fixed cache handling of records with different TTLs (#729) @bdraco @@ -251,18 +251,18 @@ This release offers 100% line and branch coverage TTLs in the cache can result in unexpected behavior since some functions returned all matching records and some fetched from the right side of the list to return the - newest record. Intead we now store the records in a dict + newest record. Instead we now store the records in a dict to ensure that the newest record always replaces the same - unique record and we never have a source of truth problem + unique record, and we never have a source of truth problem determining the TTL of a record from the cache. * Fixed ServiceInfo with multiple A records (#725) @bdraco If there were multiple A records for the host, ServiceInfo would always return the last one that was in the incoming - packet which was usually not the one that was wanted. + packet, which was usually not the one that was wanted. * Fixed stale unique records expiring too quickly (#706) @bdraco - Recods now expire 1s in the future instead of instant removal. + Records now expire 1s in the future instead of instant removal. tools.ietf.org/html/rfc6762#section-10.2 Queriers receiving a Multicast DNS response with a TTL of zero SHOULD @@ -280,7 +280,7 @@ This release offers 100% line and branch coverage * Fixed services not being removed from the registry when calling unregister_all_services (#644) @bdraco There was a race condition where a query could be answered for a service - in the registry while goodbye packets which could result a fresh record + in the registry, while goodbye packets which could result in a fresh record being broadcast after the goodbye if a query came in at just the right time. To avoid this, we now remove the services from the registry right after we generate the goodbye packet @@ -290,8 +290,8 @@ This release offers 100% line and branch coverage The bit should be set per datatracker.ietf.org/doc/html/rfc6762#section-8.1 -* Fixed the TC bit mising for query packets where the known answers span multiple packets (#494) @bdraco -* Fixed packets not being properly seperated when exceeding maximum size (#498) @bdraco +* Fixed the TC bit missing for query packets where the known answers span multiple packets (#494) @bdraco +* Fixed packets not being properly separated when exceeding maximum size (#498) @bdraco Ensure that questions that exceed the max packet size are moved to the next packet. This fixes DNSQuestions being @@ -316,7 +316,7 @@ This release offers 100% line and branch coverage * Fixed duplicate packets triggering duplicate updates (#376) @bdraco If TXT or SRV records update was already processed and then - recieved again, it was possible for a second update to be + received again, it was possible for a second update to be called back in the ServiceBrowser * Fixed ServiceStateChange.Updated event happening for IPs that already existed (#375) @bdraco * Fixed RFC6762 Section 10.2 paragraph 2 compliance (#374) @bdraco @@ -327,7 +327,7 @@ This release offers 100% line and branch coverage Use get_all_by_details to ensure all records are loaded into addresses. - Only load A/AAAA records from cache once in load_from_cache + Only load A/AAAA records from the cache once in load_from_cache if there is a SRV record present Move duplicate code that checked if the ServiceInfo was complete @@ -339,8 +339,8 @@ Technically backwards incompatible: * Update internal version check to match docs (3.6+) (#491) @bdraco - Python version eariler then 3.6 were likely broken with zeroconf - already, however the version is now explictly checked. + Python version earlier then 3.6 were likely broken with zeroconf + already, however, the version is now explicitly checked. * Update python compatibility as PyPy3 7.2 is required (#523) @bdraco Backwards incompatible: @@ -348,15 +348,14 @@ Backwards incompatible: * Drop oversize packets before processing them (#826) @bdraco Oversized packets can quickly overwhelm the system and deny - service to legitimate queriers. In practice this is usually - due to broken mDNS implementations rather than malicious - actors. -* Guard against excessive ServiceBrowser queries from PTR records significantly lower than recommended (#824) @bdraco + service to legitimate queriers. In practice, this is usually due to broken mDNS + implementations rather than malicious actors. +* Guard against excessive ServiceBrowser queries from PTR records significantly lowerthan recommended (#824) @bdraco We now enforce a minimum TTL for PTR records to avoid ServiceBrowsers generating excessive queries refresh queries. - Apple uses a 15s minimum TTL, however we do not have the same - level of rate limit and safe guards so we use 1/4 of the recommended value. + Apple uses a 15s minimum TTL, however, we do not have the same + level of rate limit and safeguards, so we use 1/4 of the recommended value. * RecordUpdateListener now uses async_update_records instead of update_record (#419, #726) @bdraco This allows the listener to receive all the records that have @@ -366,7 +365,7 @@ Backwards incompatible: update_record has been deprecated in favor of async_update_records A compatibility shim exists to ensure classes that use RecordUpdateListener as a base class continue to have - update_record called, however they should be updated + update_record called, however, they should be updated as soon as possible. A new method async_update_records_complete is now called on each @@ -376,7 +375,7 @@ Backwards incompatible: has been updated as its a common pattern to call for ServiceInfo when a ServiceBrowser handler fires. - The async_ prefix was choosen to make it clear that these + The async_ prefix was chosen to make it clear that these functions run in the eventloop and should never do blocking I/O. Before 0.32+ these functions ran in a select() loop and should not have been doing any blocking I/O, but it was not From 34f6e498dec18b84dab1c27c75348916bceef8e6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 29 Jun 2021 09:37:34 -1000 Subject: [PATCH 502/608] Reformat changelog to match prior versions (#892) --- README.rst | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/README.rst b/README.rst index 8f8a36a6..54e20e27 100644 --- a/README.rst +++ b/README.rst @@ -204,27 +204,12 @@ This release offers 100% line and branch coverage. The Engine thread is now started after all the listeners have been added to avoid a race condition where packets could be missed at startup. -* TRAFFIC REDUCTION: Added support for handling QU questions (#621) @bdraco - - Implements RFC 6762 sec 5.4: - Questions Requesting Unicast Responses - datatracker.ietf.org/doc/html/rfc6762#section-5.4 -* TRAFFIC REDUCTION: Implemented protect the network against excessive packet flooding (#619) @bdraco -* TRAFFIC REDUCTION: Additionals are now suppressed when they are already in the answers section (#617) @bdraco -* TRAFFIC REDUCTION: Additionals are no longer included when the answer is suppressed by known-answer suppression (#614) @bdraco -* TRAFFIC REDUCTION: Implemented multi-packet known answer supression (#687) @bdraco - - Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 -* TRAFFIC REDUCTION: Implemented efficient bucketing of queries with known answers (#698) @bdraco -* TRAFFIC REDUCTION: Implemented duplicate question suppression (#770) @bdraco - - http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 -* MAJOR BUG: Fixed answering matching PTR queries with the ANY query (#618) @bdraco -* MAJOR BUG: Fixed lookup of uppercase names in the registry (#597) @bdraco +* Fixed answering matching PTR queries with the ANY query (#618) @bdraco +* Fixed lookup of uppercase names in the registry (#597) @bdraco If the ServiceInfo was registered with an uppercase name and the query was for a lowercase name, it would not be found and vice-versa. -* MAJOR BUG: Fixed unicast responses from any source port (#598) @bdraco +* Fixed unicast responses from any source port (#598) @bdraco Unicast responses were only being sent if the source port was 53, this prevented responses when testing with dig: @@ -232,7 +217,7 @@ This release offers 100% line and branch coverage. dig -p 5353 @224.0.0.251 media-12.local The above query will now see a response -* MAJOR BUG: Fixed queries for AAAA records not being answered (#616) @bdraco +* Fixed queries for AAAA records not being answered (#616) @bdraco * Removed second level caching from ServiceBrowsers (#737) @bdraco The ServiceBrowser had its own cache of the last time it @@ -335,6 +320,24 @@ This release offers 100% line and branch coverage. * Fixed a case where the cache list can change during iteration (#363) @bdraco * Return task objects created by AsyncZeroconf (#360) @nocarryr +Traffic Reduction: + +* Added support for handling QU questions (#621) @bdraco + + Implements RFC 6762 sec 5.4: + Questions Requesting Unicast Responses + datatracker.ietf.org/doc/html/rfc6762#section-5.4 +* Implemented protect the network against excessive packet flooding (#619) @bdraco +* Additionals are now suppressed when they are already in the answers section (#617) @bdraco +* Additionals are no longer included when the answer is suppressed by known-answer suppression (#614) @bdraco +* Implemented multi-packet known answer supression (#687) @bdraco + + Implements datatracker.ietf.org/doc/html/rfc6762#section-7.2 +* Implemented efficient bucketing of queries with known answers (#698) @bdraco +* Implemented duplicate question suppression (#770) @bdraco + + http://datatracker.ietf.org/doc/html/rfc6762#section-7.3 + Technically backwards incompatible: * Update internal version check to match docs (3.6+) (#491) @bdraco From ea7bc8592e418332e5b9973007698d3cd79754d9 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 30 Jun 2021 03:23:58 +0200 Subject: [PATCH 503/608] Release version 0.32.0 --- README.rst | 4 ++-- zeroconf/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 54e20e27..af33baad 100644 --- a/README.rst +++ b/README.rst @@ -140,8 +140,8 @@ See examples directory for more. Changelog ========= -0.32.0 (Unreleased) -=================== +0.32.0 +====== This release offers 100% line and branch coverage. diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e3bb987f..b39d7436 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.31.0' +__version__ = '0.32.0' __license__ = 'LGPL' From 82ff150e0a72a7e20823a0c805f48f117bf1e274 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Wed, 30 Jun 2021 03:29:06 +0200 Subject: [PATCH 504/608] Fix readme formatting It wasn't proper reStructuredText before: % twine check dist/* Checking dist/zeroconf-0.32.0-py3-none-any.whl: FAILED `long_description` has syntax errors in markup and would not be rendered on PyPI. line 381: Error: Unknown target name: "async". warning: `long_description_content_type` missing. defaulting to `text/x-rst`. Checking dist/zeroconf-0.32.0.tar.gz: FAILED `long_description` has syntax errors in markup and would not be rendered on PyPI. line 381: Error: Unknown target name: "async". warning: `long_description_content_type` missing. defaulting to `text/x-rst`. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index af33baad..2b38fa7a 100644 --- a/README.rst +++ b/README.rst @@ -378,7 +378,7 @@ Backwards incompatible: has been updated as its a common pattern to call for ServiceInfo when a ServiceBrowser handler fires. - The async_ prefix was chosen to make it clear that these + The async\_ prefix was chosen to make it clear that these functions run in the eventloop and should never do blocking I/O. Before 0.32+ these functions ran in a select() loop and should not have been doing any blocking I/O, but it was not From 90bc8ca8dce1af26ea81c5d6ecb17cf6ea664a71 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Sun, 4 Jul 2021 15:15:22 +0100 Subject: [PATCH 505/608] Add test for running sync code within executor (#894) --- tests/test_asyncio.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 759ab5b3..2e504b68 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -101,6 +101,17 @@ async def test_async_with_sync_passed_in_closed_in_async() -> None: await aiozc.async_close() +@pytest.mark.asyncio +async def test_sync_within_event_loop_executor() -> None: + """Test sync version still works from an executor within an event loop.""" + def sync_code(): + zc = Zeroconf(interfaces=['127.0.0.1']) + assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None + zc.close() + + await asyncio.get_event_loop().run_in_executor(None, sync_code) + + @pytest.mark.asyncio async def test_async_service_registration() -> None: """Test registering services broadcasts the registration by default.""" From 56c7d692d67b7f56c386a7f1f4e45ebfc4e8366a Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 4 Jul 2021 09:16:02 -0500 Subject: [PATCH 506/608] Increase timeout in ServiceInfo.request to handle loaded systems (#895) It can take a few seconds for a loaded system to run the `async_request` coroutine when the event loop is busy or the system is CPU bound (example being Home Assistant startup). We now add an additional `_LOADED_SYSTEM_TIMEOUT` (10s) to the `run_coroutine_threadsafe` calls to ensure the coroutine has the total amount of time to run up to its internal timeout (default of 3000ms). Ten seconds is a bit large of a timeout; however, its only unused in cases where we wrap other timeouts. We now expect the only instance the `run_coroutine_threadsafe` result timeout will happen in a production circumstance is when someone is running a `ServiceInfo.request()` in a thread and another thread calls `Zeroconf.close()` at just the right moment that the future is never completed unless the system is so loaded that it is nearly unresponsive. The timeout for `run_coroutine_threadsafe` is the maximum time a thread can cleanly shut down when zeroconf is closed out in another thread, which should always be longer than the underlying thread operation. --- tests/services/test_info.py | 12 ++++++++++++ tests/test_asyncio.py | 14 ++++++++++++++ tests/utils/test_asyncio.py | 17 ++++++++++++++++- zeroconf/_core.py | 5 ++++- zeroconf/_services/info.py | 3 ++- zeroconf/_utils/asyncio.py | 4 ++-- zeroconf/const.py | 6 ++++++ 7 files changed, 56 insertions(+), 5 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 02ba581e..8ac8beda 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -739,3 +739,15 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): zeroconf.get_service_info(f"name.{type_}", type_, 500, question_type=r.DNSQuestionType.QM) assert first_outgoing.questions[0].unicast == False zeroconf.close() + + +def test_request_timeout(): + """Test that the timeout does not throw an exception and finishes close to the actual timeout.""" + zeroconf = r.Zeroconf(interfaces=['127.0.0.1']) + start_time = r.current_time_millis() + assert zeroconf.get_service_info("_notfound.local.", "notthere._notfound.local.") is None + end_time = r.current_time_millis() + zeroconf.close() + # 3000ms for the default timeout + # 1000ms for loaded systems + schedule overhead + assert (end_time - start_time) < 3000 + 1000 diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 2e504b68..f4722389 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -935,3 +935,17 @@ def update_service(self, zc, type_, name) -> None: ('add', type_, registration_name), ] await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_async_request_timeout(): + """Test that the timeout does not throw an exception and finishes close to the actual timeout.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + start_time = current_time_millis() + assert await aiozc.async_get_service_info("_notfound.local.", "notthere._notfound.local.") is None + end_time = current_time_millis() + await aiozc.async_close() + # 3000ms for the default timeout + # 1000ms for loaded systems + schedule overhead + assert (end_time - start_time) < 3000 + 1000 diff --git a/tests/utils/test_asyncio.py b/tests/utils/test_asyncio.py index ccafb72f..2939b5ab 100644 --- a/tests/utils/test_asyncio.py +++ b/tests/utils/test_asyncio.py @@ -5,6 +5,7 @@ """Unit tests for zeroconf._utils.asyncio.""" import asyncio +import concurrent.futures import contextlib import threading import time @@ -12,7 +13,9 @@ import pytest +from zeroconf._core import _CLOSE_TIMEOUT from zeroconf._utils import asyncio as aioutils +from zeroconf.const import _LOADED_SYSTEM_TIMEOUT @pytest.mark.asyncio @@ -82,7 +85,8 @@ async def _still_running(): def _run_coro() -> None: runcoro_thread_ready.set() - asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1) + with contextlib.suppress(concurrent.futures.TimeoutError): + asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1) runcoro_thread = threading.Thread(target=_run_coro, daemon=True) runcoro_thread.start() @@ -97,3 +101,14 @@ def _run_coro() -> None: assert loop.is_running() is False runcoro_thread.join() + + +def test_cumulative_timeouts_less_than_close_plus_buffer(): + """Test that the combined async timeouts are shorter than the close timeout with the buffer. + + We want to make sure that the close timeout is the one that gets + raised if something goes wrong. + """ + assert ( + aioutils._TASK_AWAIT_TIMEOUT + aioutils._GET_ALL_TASKS_TIMEOUT + aioutils._WAIT_FOR_LOOP_TASKS_TIMEOUT + ) < 1 + _CLOSE_TIMEOUT + _LOADED_SYSTEM_TIMEOUT diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a20e5639..2f5ef507 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -62,6 +62,7 @@ _FLAGS_AA, _FLAGS_QR_QUERY, _FLAGS_QR_RESPONSE, + _LOADED_SYSTEM_TIMEOUT, _MAX_MSG_ABSOLUTE, _MDNS_ADDR, _MDNS_ADDR6, @@ -172,7 +173,9 @@ def close(self) -> None: return if not self.loop.is_running(): return - asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result(_CLOSE_TIMEOUT) + asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result( + _CLOSE_TIMEOUT + _LOADED_SYSTEM_TIMEOUT + ) class AsyncListener(asyncio.Protocol, QuietLogger): diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 3e371b17..7bc81c8d 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -45,6 +45,7 @@ _DNS_OTHER_TTL, _FLAGS_QR_QUERY, _LISTENER_TIME, + _LOADED_SYSTEM_TIMEOUT, _TYPE_A, _TYPE_AAAA, _TYPE_PTR, @@ -427,7 +428,7 @@ def request( raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") return asyncio.run_coroutine_threadsafe( self.async_request(zc, timeout, question_type), zc.loop - ).result(millis_to_seconds(timeout) + 1) + ).result(millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT) async def async_request( self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py index 57c1fb18..c68c0f00 100644 --- a/zeroconf/_utils/asyncio.py +++ b/zeroconf/_utils/asyncio.py @@ -26,8 +26,8 @@ from typing import Any, List, Optional, Set, cast _TASK_AWAIT_TIMEOUT = 1 -_GET_ALL_TASKS_TIMEOUT = 1 -_WAIT_FOR_LOOP_TASKS_TIMEOUT = 2 # Must be larger than _TASK_AWAIT_TIMEOUT +_GET_ALL_TASKS_TIMEOUT = 3 +_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT def get_best_available_queue() -> queue.Queue: diff --git a/zeroconf/const.py b/zeroconf/const.py index 0f26d80a..afdcb2d4 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -34,6 +34,12 @@ _DUPLICATE_QUESTION_INTERVAL = _BROWSER_TIME - 1 # ms _BROWSER_BACKOFF_LIMIT = 3600 # s _CACHE_CLEANUP_INTERVAL = 10000 # ms +_LOADED_SYSTEM_TIMEOUT = 10 # s +# If the system is loaded or the event +# loop was blocked by another task that was doing I/O in the loop +# (shouldn't happen but it does in practice) we need to give +# a buffer timeout to ensure a coroutine can finish before +# the future times out # Some DNS constants From a93301d0fd493bf18147187bf8efed1a4ea02214 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 4 Jul 2021 09:32:32 -0500 Subject: [PATCH 507/608] Update changelog (#899) --- README.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/README.rst b/README.rst index 2b38fa7a..76fd7744 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,29 @@ See examples directory for more. Changelog ========= + +0.32.1 (Unreleased) +=================== + +* Increase timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco + + It can take a few seconds for a loaded system to run the `async_request` + coroutine when the event loop is busy, or the system is CPU bound (example being + Home Assistant startup). We now add an additional `_LOADED_SYSTEM_TIMEOUT` (10s) + to the `run_coroutine_threadsafe` calls to ensure the coroutine has the total + amount of time to run up to its internal timeout (default of 3000ms). + + Ten seconds is a bit large of a timeout; however, it is only used in cases + where we wrap other timeouts. We now expect the only instance the + `run_coroutine_threadsafe` result timeout will happen in a production + circumstance is when someone is running a `ServiceInfo.request()` in a thread and + another thread calls `Zeroconf.close()` at just the right moment that the future + is never completed unless the system is so loaded that it is nearly unresponsive. + + The timeout for `run_coroutine_threadsafe` is the maximum time a thread can + cleanly shut down when zeroconf is closed out in another thread, which should + always be longer than the underlying thread operation. + 0.32.0 ====== From fc089be1f412d991f44daeecd0944198d3a638a5 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 5 Jul 2021 09:43:30 +0200 Subject: [PATCH 508/608] Fix the changelog's one sentence's tense --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 76fd7744..10f28890 100644 --- a/README.rst +++ b/README.rst @@ -144,7 +144,7 @@ Changelog 0.32.1 (Unreleased) =================== -* Increase timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco +* Increased timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco It can take a few seconds for a loaded system to run the `async_request` coroutine when the event loop is busy, or the system is CPU bound (example being From 675fd6fc959e76e4e3690e5c7a02db269ca9ef60 Mon Sep 17 00:00:00 2001 From: Jakub Stasiak Date: Mon, 5 Jul 2021 09:45:11 +0200 Subject: [PATCH 509/608] Release version 0.32.1 --- README.rst | 4 ++-- zeroconf/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 10f28890..7cefb379 100644 --- a/README.rst +++ b/README.rst @@ -141,8 +141,8 @@ Changelog ========= -0.32.1 (Unreleased) -=================== +0.32.1 +====== * Increased timeout in ServiceInfo.request to handle loaded systems (#895) @bdraco diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b39d7436..56c61ff3 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.32.0' +__version__ = '0.32.1' __license__ = 'LGPL' From f8af0fb251938dcb410127b2af2b8b407989aa08 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 09:40:17 -1000 Subject: [PATCH 510/608] Disable N818 in flake8 (#905) - We cannot rename these exceptions now without a breaking change as they have existed for many years --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index e9dc052f..e208561b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,7 +5,7 @@ testpaths = tests show-source = 1 application-import-names=zeroconf max-line-length=110 -ignore=E203,W503 +ignore=E203,W503,N818 [mypy] ignore_missing_imports = true From e417fc0f5ed7eaa47a0dcaffdbc6fe335bfcc058 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 10:16:28 -1000 Subject: [PATCH 511/608] Reduce duplicate code between zeroconf.asyncio and zeroconf._core (#904) --- zeroconf/_core.py | 101 +++++++++++++++++++++++++------------ zeroconf/_utils/asyncio.py | 9 +++- zeroconf/asyncio.py | 54 ++++---------------- 3 files changed, 87 insertions(+), 77 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 2f5ef507..a8986211 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -29,7 +29,7 @@ import sys import threading from types import TracebackType # noqa # used in type hints -from typing import Dict, List, Optional, Tuple, Type, Union, cast +from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType @@ -43,7 +43,7 @@ from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener -from ._utils.asyncio import get_running_loop, shutdown_loop, wait_event_or_timeout +from ._utils.asyncio import await_awaitable, get_running_loop, shutdown_loop, wait_event_or_timeout from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -74,6 +74,7 @@ _TC_DELAY_RANDOM_INTERVAL = (400, 500) _CLOSE_TIMEOUT = 3 +_REGISTER_BROADCASTS = 3 class AsyncEngine: @@ -478,6 +479,27 @@ def register_service( allow_name_change: bool = False, cooperating_responders: bool = False, ) -> None: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service. The name of the service may be changed if needed to make + it unique on the network. Additionally multiple cooperating responders + can register the same service on the network for resilience + (if you want this behavior set `cooperating_responders` to `True`).""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe( + await_awaitable( + self.async_register_service(info, ttl, allow_name_change, cooperating_responders) + ), + self.loop, + ).result(millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + + async def async_register_service( + self, + info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, + cooperating_responders: bool = False, + ) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service. The name of the service may be changed if needed to make @@ -489,36 +511,34 @@ def register_service( # Setting TTLs via ServiceInfo is preferred info.host_ttl = ttl info.other_ttl = ttl - self.check_service(info, allow_name_change, cooperating_responders) + + await self.async_wait_for_start() + await self.async_check_service(info, allow_name_change, cooperating_responders) self.registry.add(info) - self._broadcast_service(info, _REGISTER_TIME, None) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) def update_service(self, info: ServiceInfo) -> None: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service.""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe(await_awaitable(self.async_update_service(info)), self.loop).result( + millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT + ) + async def async_update_service(self, info: ServiceInfo) -> Awaitable: + """Registers service information to the network with a default TTL. + Zeroconf will then respond to requests for information for that + service.""" self.registry.update(info) - self._broadcast_service(info, _REGISTER_TIME, None) + return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: """Send a broadcasts to announce a service at intervals.""" - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - - self.send_service_broadcast(info, ttl) - i += 1 - next_time += interval - - def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None: - """Send a broadcast to announce a service.""" - self.send(self.generate_service_broadcast(info, ttl)) + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(interval)) + self.async_send(self.generate_service_broadcast(info, ttl)) def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: """Generate a broadcast to announce a service.""" @@ -526,10 +546,6 @@ def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> D self._add_broadcast_answer(out, info, ttl) return out - def send_service_query(self, info: ServiceInfo) -> None: - """Send a query to lookup a service.""" - self.send(self.generate_service_query(info)) - def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use """Generate a query to lookup a service.""" out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA) @@ -559,9 +575,16 @@ def _add_broadcast_answer( # pylint: disable=no-self-use out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: + """Unregister a service.""" + assert self.loop is not None + asyncio.run_coroutine_threadsafe( + await_awaitable(self.async_unregister_service(info)), self.loop + ).result(millis_to_seconds(_UNREGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + + async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service.""" self.registry.remove(info) - self._broadcast_service(info, _UNREGISTER_TIME, 0) + return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" @@ -574,6 +597,22 @@ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: self.registry.remove(service_infos) return out + async def async_unregister_all_services(self) -> None: + """Unregister all registered services. + + Unlike async_register_service and async_unregister_service, this + method does not return a future and is always expected to be + awaited since its only called at shutdown. + """ + # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 + out = self.generate_unregister_all_services() + if not out: + return + for i in range(_REGISTER_BROADCASTS): + if i != 0: + await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) + self.async_send(out) + def unregister_all_services(self) -> None: """Unregister all registered services.""" # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 @@ -592,7 +631,7 @@ def unregister_all_services(self) -> None: i += 1 next_time += _UNREGISTER_TIME - def check_service( + async def async_check_service( self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False ) -> None: """Checks the network for a unique service name, modifying the @@ -603,7 +642,7 @@ def check_service( next_instance_number = 2 next_time = now = current_time_millis() i = 0 - while i < 3: + while i < _REGISTER_BROADCASTS: # check for a name conflict while self.cache.current_entry_with_name_and_alias(info.type, info.name): if not allow_name_change: @@ -617,11 +656,11 @@ def check_service( i = 0 if now < next_time: - self.wait(next_time - now) + await self.async_wait(next_time - now) now = current_time_millis() continue - self.send_service_query(info) + self.async_send(self.generate_service_query(info)) i += 1 next_time += _CHECK_TIME diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py index c68c0f00..395c331b 100644 --- a/zeroconf/_utils/asyncio.py +++ b/zeroconf/_utils/asyncio.py @@ -23,8 +23,9 @@ import asyncio import contextlib import queue -from typing import Any, List, Optional, Set, cast +from typing import Any, Awaitable, List, Optional, Set, cast +# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT _TASK_AWAIT_TIMEOUT = 1 _GET_ALL_TASKS_TIMEOUT = 3 _WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT @@ -80,6 +81,12 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None: await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT) +async def await_awaitable(aw: Awaitable) -> None: + """Wait on an awaitable and the task it returns.""" + task = await aw + await task + + def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" pending_tasks = set( diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 67ff1c12..08478044 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -26,20 +26,15 @@ from ._core import Zeroconf from ._dns import DNSQuestionType -from ._exceptions import NonUniqueNameException from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase -from ._services.info import ServiceInfo, instance_name_from_service_info +from ._services.info import ServiceInfo from ._services.types import ZeroconfServiceTypes from ._utils.net import IPVersion, InterfaceChoice, InterfacesType -from ._utils.time import millis_to_seconds from .const import ( _BROWSER_TIME, - _CHECK_TIME, _MDNS_PORT, - _REGISTER_TIME, _SERVICE_TYPE_ENUMERATION_NAME, - _UNREGISTER_TIME, ) @@ -172,16 +167,11 @@ def __init__( ) self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {} - async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: - """Send a broadcasts to announce a service at intervals.""" - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(interval)) - self.zeroconf.async_send(self.zeroconf.generate_service_broadcast(info, ttl)) - async def async_register_service( self, info: ServiceInfo, + ttl: Optional[int] = None, + allow_name_change: bool = False, cooperating_responders: bool = False, ) -> Awaitable: """Registers service information to the network with a default TTL. @@ -194,10 +184,9 @@ async def async_register_service( The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - await self.zeroconf.async_wait_for_start() - await self.async_check_service(info, cooperating_responders) - self.zeroconf.registry.add(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return await self.zeroconf.async_register_service( + info, ttl, allow_name_change, cooperating_responders + ) async def async_unregister_all_services(self) -> None: """Unregister all registered services. @@ -206,30 +195,7 @@ async def async_unregister_all_services(self) -> None: method does not return a future and is always expected to be awaited since its only called at shutdown. """ - out = self.zeroconf.generate_unregister_all_services() - if not out: - return - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME)) - self.zeroconf.async_send(out) - - async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None: - """Checks the network for a unique service name.""" - instance_name_from_service_info(info) - if cooperating_responders: - return - self._raise_on_name_conflict(info) - for i in range(3): - if i != 0: - await asyncio.sleep(millis_to_seconds(_CHECK_TIME)) - self.zeroconf.async_send(self.zeroconf.generate_service_query(info)) - self._raise_on_name_conflict(info) - - def _raise_on_name_conflict(self, info: ServiceInfo) -> None: - """Raise NonUniqueNameException if the ServiceInfo has a conflict.""" - if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name): - raise NonUniqueNameException + await self.zeroconf.async_unregister_all_services() async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service. @@ -237,8 +203,7 @@ async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - self.zeroconf.registry.remove(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + return await self.zeroconf.async_unregister_service(info) async def async_update_service(self, info: ServiceInfo) -> Awaitable: """Registers service information to the network with a default TTL. @@ -248,8 +213,7 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: The service will be broadcast in a task. This task is returned and therefore can be awaited if necessary. """ - self.zeroconf.registry.update(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + return await self.zeroconf.async_update_service(info) async def async_close(self) -> None: """Ends the background threads, and prevent this instance from From 9399c57bb2b280c7b433e7fbea7cca2c2f4417ee Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 10:40:47 -1000 Subject: [PATCH 512/608] Centralize running coroutines from threads (#906) - Cleanup to ensure all coros we run from a thread use _LOADED_SYSTEM_TIMEOUT --- zeroconf/_core.py | 30 +++++++++++++++++------------- zeroconf/_services/info.py | 10 +++------- zeroconf/_utils/asyncio.py | 12 +++++++++++- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index a8986211..aadaa290 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -43,7 +43,13 @@ from ._services.info import ServiceInfo, instance_name_from_service_info from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener -from ._utils.asyncio import await_awaitable, get_running_loop, shutdown_loop, wait_event_or_timeout +from ._utils.asyncio import ( + await_awaitable, + get_running_loop, + run_coro_with_timeout, + shutdown_loop, + wait_event_or_timeout, +) from ._utils.name import service_type_name from ._utils.net import ( IPVersion, @@ -62,7 +68,6 @@ _FLAGS_AA, _FLAGS_QR_QUERY, _FLAGS_QR_RESPONSE, - _LOADED_SYSTEM_TIMEOUT, _MAX_MSG_ABSOLUTE, _MDNS_ADDR, _MDNS_ADDR6, @@ -73,7 +78,7 @@ ) _TC_DELAY_RANDOM_INTERVAL = (400, 500) -_CLOSE_TIMEOUT = 3 +_CLOSE_TIMEOUT = 3000 # ms _REGISTER_BROADCASTS = 3 @@ -174,9 +179,7 @@ def close(self) -> None: return if not self.loop.is_running(): return - asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result( - _CLOSE_TIMEOUT + _LOADED_SYSTEM_TIMEOUT - ) + run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT) class AsyncListener(asyncio.Protocol, QuietLogger): @@ -486,12 +489,13 @@ def register_service( can register the same service on the network for resilience (if you want this behavior set `cooperating_responders` to `True`).""" assert self.loop is not None - asyncio.run_coroutine_threadsafe( + run_coro_with_timeout( await_awaitable( self.async_register_service(info, ttl, allow_name_change, cooperating_responders) ), self.loop, - ).result(millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + _REGISTER_TIME * _REGISTER_BROADCASTS, + ) async def async_register_service( self, @@ -522,8 +526,8 @@ def update_service(self, info: ServiceInfo) -> None: Zeroconf will then respond to requests for information for that service.""" assert self.loop is not None - asyncio.run_coroutine_threadsafe(await_awaitable(self.async_update_service(info)), self.loop).result( - millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT + run_coro_with_timeout( + await_awaitable(self.async_update_service(info)), self.loop, _REGISTER_TIME * _REGISTER_BROADCASTS ) async def async_update_service(self, info: ServiceInfo) -> Awaitable: @@ -577,9 +581,9 @@ def _add_broadcast_answer( # pylint: disable=no-self-use def unregister_service(self, info: ServiceInfo) -> None: """Unregister a service.""" assert self.loop is not None - asyncio.run_coroutine_threadsafe( - await_awaitable(self.async_unregister_service(info)), self.loop - ).result(millis_to_seconds(_UNREGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT) + run_coro_with_timeout( + self.async_unregister_service(info), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS + ) async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service.""" diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 7bc81c8d..d1bf17e9 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -20,7 +20,6 @@ USA """ -import asyncio import ipaddress import socket from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast @@ -29,7 +28,7 @@ from .._exceptions import BadTypeInNameException from .._protocol import DNSOutgoing from .._updates import RecordUpdate, RecordUpdateListener -from .._utils.asyncio import get_running_loop +from .._utils.asyncio import get_running_loop, run_coro_with_timeout from .._utils.name import service_type_name from .._utils.net import ( IPVersion, @@ -37,7 +36,7 @@ _is_v6_address, ) from .._utils.struct import int2byte -from .._utils.time import current_time_millis, millis_to_seconds +from .._utils.time import current_time_millis from ..const import ( _CLASS_IN, _CLASS_UNIQUE, @@ -45,7 +44,6 @@ _DNS_OTHER_TTL, _FLAGS_QR_QUERY, _LISTENER_TIME, - _LOADED_SYSTEM_TIMEOUT, _TYPE_A, _TYPE_AAAA, _TYPE_PTR, @@ -426,9 +424,7 @@ def request( assert zc.loop is not None and zc.loop.is_running() if zc.loop == get_running_loop(): raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop") - return asyncio.run_coroutine_threadsafe( - self.async_request(zc, timeout, question_type), zc.loop - ).result(millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT) + return bool(run_coro_with_timeout(self.async_request(zc, timeout, question_type), zc.loop, timeout)) async def async_request( self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None diff --git a/zeroconf/_utils/asyncio.py b/zeroconf/_utils/asyncio.py index 395c331b..10b8b3d9 100644 --- a/zeroconf/_utils/asyncio.py +++ b/zeroconf/_utils/asyncio.py @@ -23,7 +23,10 @@ import asyncio import contextlib import queue -from typing import Any, Awaitable, List, Optional, Set, cast +from typing import Any, Awaitable, Coroutine, List, Optional, Set, cast + +from .time import millis_to_seconds +from ..const import _LOADED_SYSTEM_TIMEOUT # The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT _TASK_AWAIT_TIMEOUT = 1 @@ -87,6 +90,13 @@ async def await_awaitable(aw: Awaitable) -> None: await task +def run_coro_with_timeout(aw: Coroutine, loop: asyncio.AbstractEventLoop, timeout: float) -> Any: + """Run a coroutine with a timeout.""" + return asyncio.run_coroutine_threadsafe(aw, loop).result( + millis_to_seconds(timeout) + _LOADED_SYSTEM_TIMEOUT + ) + + def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None: """Wait for pending tasks and stop an event loop.""" pending_tasks = set( From bc9e9cf8a5b997ca924730ed091a829f4f961ca3 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 11:12:57 -1000 Subject: [PATCH 513/608] Implement NSEC record parsing (#903) - This is needed for negative responses https://datatracker.ietf.org/doc/html/rfc6762#section-6.1 --- tests/test_asyncio.py | 1 + tests/test_dns.py | 25 +++++++++++++++++++++++ tests/test_protocol.py | 15 ++++++++++++++ zeroconf/__init__.py | 1 + zeroconf/_dns.py | 46 ++++++++++++++++++++++++++++++++++++------ zeroconf/_protocol.py | 28 ++++++++++++++++++++++++- zeroconf/const.py | 2 ++ 7 files changed, 111 insertions(+), 7 deletions(-) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index f4722389..ac8f99f8 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -104,6 +104,7 @@ async def test_async_with_sync_passed_in_closed_in_async() -> None: @pytest.mark.asyncio async def test_sync_within_event_loop_executor() -> None: """Test sync version still works from an executor within an event loop.""" + def sync_code(): zc = Zeroconf(interfaces=['127.0.0.1']) assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None diff --git a/tests/test_dns.py b/tests/test_dns.py index 19735706..c55e0f6a 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -288,6 +288,31 @@ def test_dns_service_record_hashablity(): assert len(record_set) == 4 +def test_dns_nsec_record_hashablity(): + """Test DNSNsec are hashable.""" + nsec1 = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2, 3] + ) + nsec2 = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2] + ) + + record_set = set([nsec1, nsec2]) + assert len(record_set) == 2 + + record_set.add(nsec1) + assert len(record_set) == 2 + + nsec2_dupe = r.DNSNsec( + 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2] + ) + assert nsec2 == nsec2_dupe + assert nsec2.__hash__() == nsec2_dupe.__hash__() + + record_set.add(nsec2_dupe) + assert len(record_set) == 2 + + def test_rrset_does_not_consider_ttl(): """Test DNSRRSet does not consider the ttl in the hash.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index ebdb7110..a0805960 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -722,6 +722,21 @@ def test_qu_packet_parser(): assert ",QU," in str(parsed.questions[0]) +def test_parse_packet_with_nsec_record(): + """Test we can parse a packet with an NSEC record.""" + nsec_packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x03\x08_meshcop\x04_udp\x05local\x00\x00\x0c\x00" + b"\x01\x00\x00\x11\x94\x00\x0f\x0cMyHome54 (2)\xc0\x0c\xc0+\x00\x10\x80\x01\x00\x00\x11\x94\x00" + b")\x0bnn=MyHome54\x13xp=695034D148CC4784\x08tv=0.0.0\xc0+\x00!\x80\x01\x00\x00\x00x\x00\x15\x00" + b"\x00\x00\x00\xc0'\x0cMaster-Bed-2\xc0\x1a\xc0+\x00/\x80\x01\x00\x00\x11\x94\x00\t\xc0+\x00\x05" + b"\x00\x00\x80\x00@" + ) + parsed = DNSIncoming(nsec_packet) + nsec_record = parsed.answers[3] + assert "nsec," in str(nsec_record) + assert nsec_record.rdtypes == [16, 33] + + def test_records_same_packet_share_fate(): """Test records in the same packet all have the same created time.""" out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 56c61ff3..666914b2 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -28,6 +28,7 @@ DNSAddress, DNSEntry, DNSHinfo, + DNSNsec, DNSPointer, DNSQuestion, DNSRecord, diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 93db9859..31a2da3a 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -22,7 +22,7 @@ import enum import socket -from typing import Any, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address @@ -116,11 +116,7 @@ class DNSQuestion(DNSEntry): def answered_by(self, rec: 'DNSRecord') -> bool: """Returns true if the question is answered by the record""" - return ( - self.class_ == rec.class_ - and (self.type == rec.type or self.type == _TYPE_ANY) - and self.name == rec.name - ) + return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name def __hash__(self) -> int: return hash((self.name, self.class_, self.type)) @@ -446,6 +442,44 @@ def __repr__(self) -> str: return self.to_string("%s:%s" % (self.server, self.port)) +class DNSNsec(DNSRecord): + + """A DNS NSEC record""" + + __slots__ = ('next', 'rdtypes') + + def __init__( + self, + name: str, + type_: int, + class_: int, + ttl: int, + next: str, + rdtypes: List[int], + created: Optional[float] = None, + ) -> None: + super().__init__(name, type_, class_, ttl, created) + self.next = next + self.rdtypes = rdtypes + + def __eq__(self, other: Any) -> bool: + """Tests equality on cpu and os""" + return ( + isinstance(other, DNSNsec) + and self.next == other.next + and self.rdtypes == other.rdtypes + and DNSEntry.__eq__(self, other) + ) + + def __hash__(self) -> int: + """Hash to compare like DNSNSec.""" + return hash((*self._entry_tuple(), self.next, *self.rdtypes)) + + def __repr__(self) -> str: + """String representation""" + return self.to_string(self.next + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes])) + + class DNSRRSet: """A set of dns records independent of the ttl.""" diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 79f483de..ae2f43a3 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast -from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ._exceptions import IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log from ._utils.struct import int2byte @@ -43,6 +43,7 @@ _TYPE_AAAA, _TYPE_CNAME, _TYPE_HINFO, + _TYPE_NSEC, _TYPE_PTR, _TYPE_SRV, _TYPE_TXT, @@ -201,6 +202,18 @@ def read_others(self) -> None: rec = DNSAddress( domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id ) + elif type_ == _TYPE_NSEC: + name_start = self.offset + name = self.read_name() + rec = DNSNsec( + domain, + type_, + class_, + ttl, + name, + self.read_bitmap(name_start + length), + self.now, + ) else: # Try to ignore types we don't know about # Skip the payload for the resource record so the next @@ -210,6 +223,19 @@ def read_others(self) -> None: if rec is not None: self.answers.append(rec) + def read_bitmap(self, end: int) -> List[int]: + """Reads an NSEC bitmap from the packet.""" + rdtypes = [] + while self.offset < end: + window = self.data[self.offset] + bitmap_length = self.data[self.offset + 1] + for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]): + for bit in range(0, 8): + if byte & (0x80 >> bit): + rdtypes.append(bit + window * 256 + i * 8) + self.offset += 2 + bitmap_length + return rdtypes + def read_name(self) -> str: """Reads a domain name from the packet""" result = '' diff --git a/zeroconf/const.py b/zeroconf/const.py index afdcb2d4..76a75dbd 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -103,6 +103,7 @@ _TYPE_TXT = 16 _TYPE_AAAA = 28 _TYPE_SRV = 33 +_TYPE_NSEC = 47 _TYPE_ANY = 255 # Mapping constants to names @@ -136,6 +137,7 @@ _TYPE_AAAA: "quada", _TYPE_SRV: "srv", _TYPE_ANY: "any", + _TYPE_NSEC: "nsec", } _HAS_A_TO_Z = re.compile(r'[A-Za-z]') From 057873128ff05a0b2d6eae07510e23d705d10bae Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 11:23:14 -1000 Subject: [PATCH 514/608] Upgrade syntax to python 3.6 (#907) --- examples/async_apple_scanner.py | 6 ++-- examples/async_browser.py | 6 ++-- examples/async_service_info_request.py | 4 +-- examples/browser.py | 6 ++-- examples/self_test.py | 4 +-- tests/conftest.py | 1 - tests/test_asyncio.py | 35 +++++++++--------- tests/test_cache.py | 21 ++++++----- tests/test_core.py | 43 +++++++++++----------- tests/test_dns.py | 23 ++++++------ tests/test_exceptions.py | 7 ++-- tests/test_handlers.py | 49 +++++++++++++------------- tests/test_history.py | 37 ++++++++----------- tests/test_init.py | 7 ++-- tests/test_logger.py | 1 - tests/test_protocol.py | 5 ++- tests/test_services.py | 3 +- tests/test_updates.py | 3 +- zeroconf/_core.py | 2 +- zeroconf/_dns.py | 8 ++--- zeroconf/_handlers.py | 18 +++++----- zeroconf/_protocol.py | 6 ++-- zeroconf/asyncio.py | 2 +- 23 files changed, 136 insertions(+), 161 deletions(-) diff --git a/examples/async_apple_scanner.py b/examples/async_apple_scanner.py index f10f6ef6..88b54e4a 100644 --- a/examples/async_apple_scanner.py +++ b/examples/async_apple_scanner.py @@ -36,7 +36,7 @@ def async_on_service_state_change( zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange ) -> None: - print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + print(f"Service {name} of type {service_type} state changed: {state_change}") if state_change is not ServiceStateChange.Added: return base_name = name[: -len(service_type) - 1] @@ -55,11 +55,11 @@ async def _async_show_service_info(zeroconf: Zeroconf, service_type: str, name: print(" Name: %s" % name) print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) - print(" Server: %s" % (info.server,)) + print(f" Server: {info.server}") if info.properties: print(" Properties are:") for key, value in info.properties.items(): - print(" %s: %s" % (key, value)) + print(f" {key}: {value}") else: print(" No properties") else: diff --git a/examples/async_browser.py b/examples/async_browser.py index f0e0851c..1cce5c20 100644 --- a/examples/async_browser.py +++ b/examples/async_browser.py @@ -17,7 +17,7 @@ def async_on_service_state_change( zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange ) -> None: - print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + print(f"Service {name} of type {service_type} state changed: {state_change}") if state_change is not ServiceStateChange.Added: return asyncio.ensure_future(async_display_service_info(zeroconf, service_type, name)) @@ -32,11 +32,11 @@ async def async_display_service_info(zeroconf: Zeroconf, service_type: str, name print(" Name: %s" % name) print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) - print(" Server: %s" % (info.server,)) + print(f" Server: {info.server}") if info.properties: print(" Properties are:") for key, value in info.properties.items(): - print(" %s: %s" % (key, value)) + print(f" {key}: {value}") else: print(" No properties") else: diff --git a/examples/async_service_info_request.py b/examples/async_service_info_request.py index 885eb99c..dd8265b7 100644 --- a/examples/async_service_info_request.py +++ b/examples/async_service_info_request.py @@ -36,11 +36,11 @@ async def async_watch_services(aiozc: AsyncZeroconf) -> None: addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_addresses()] print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) - print(" Server: %s" % (info.server,)) + print(f" Server: {info.server}") if info.properties: print(" Properties are:") for key, value in info.properties.items(): - print(" %s: %s" % (key, value)) + print(f" {key}: {value}") else: print(" No properties") else: diff --git a/examples/browser.py b/examples/browser.py index 8525e9b9..8c50e409 100755 --- a/examples/browser.py +++ b/examples/browser.py @@ -16,7 +16,7 @@ def on_service_state_change( zeroconf: Zeroconf, service_type: str, name: str, state_change: ServiceStateChange ) -> None: - print("Service %s of type %s state changed: %s" % (name, service_type, state_change)) + print(f"Service {name} of type {service_type} state changed: {state_change}") if state_change is ServiceStateChange.Added: info = zeroconf.get_service_info(service_type, name) @@ -26,11 +26,11 @@ def on_service_state_change( addresses = ["%s:%d" % (addr, cast(int, info.port)) for addr in info.parsed_scoped_addresses()] print(" Addresses: %s" % ", ".join(addresses)) print(" Weight: %d, priority: %d" % (info.weight, info.priority)) - print(" Server: %s" % (info.server,)) + print(f" Server: {info.server}") if info.properties: print(" Properties are:") for key, value in info.properties.items(): - print(" %s: %s" % (key, value)) + print(f" {key}: {value}") else: print(" No properties") else: diff --git a/examples/self_test.py b/examples/self_test.py index 35007db1..2178629b 100755 --- a/examples/self_test.py +++ b/examples/self_test.py @@ -14,7 +14,7 @@ # Test a few module features, including service registration, service # query (for Zoe), and service unregistration. - print("Multicast DNS Service Discovery for Python, version %s" % (__version__,)) + print(f"Multicast DNS Service Discovery for Python, version {__version__}") r = Zeroconf() print("1. Testing registration of a service...") desc = {'version': '0.10', 'a': 'test value', 'b': 'another value'} @@ -40,7 +40,7 @@ queried_info = r.get_service_info("_http._tcp.local.", "My Service Name._http._tcp.local.") assert queried_info assert set(queried_info.parsed_addresses()) == expected - print(" Getting self: %s" % (queried_info,)) + print(f" Getting self: {queried_info}") print(" Query done.") print("4. Testing unregister of service information...") r.unregister_service(info) diff --git a/tests/conftest.py b/tests/conftest.py index c05c4b9b..d4ea1632 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ conftest for zeroconf tests. """ diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index ac8f99f8..34709d85 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Unit tests for aio.py.""" @@ -57,11 +56,9 @@ def verify_threads_ended(): yield threads_after = frozenset(threading.enumerate()) non_executor_threads = frozenset( - [ - thread - for thread in threads_after - if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name - ] + thread + for thread in threads_after + if "asyncio" not in thread.name and "ThreadPoolExecutor" not in thread.name ) threads = non_executor_threads - threads_before assert not threads @@ -119,7 +116,7 @@ async def test_async_service_registration() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test1-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" calls = [] @@ -179,7 +176,7 @@ async def test_async_service_registration_name_conflict() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test-srvc2-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -227,7 +224,7 @@ async def test_async_service_registration_name_does_not_match_type() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test-srvc3-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -254,7 +251,7 @@ async def test_async_tasks() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test-srvc4-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" calls = [] @@ -320,7 +317,7 @@ async def test_async_wait_unblocks_on_update() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test-srvc4-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -356,8 +353,8 @@ async def test_service_info_async_request() -> None: type_ = "_test1-srvc-type._tcp.local." name = "xxxyyy" name2 = "abc" - registration_name = "%s.%s" % (name, type_) - registration_name2 = "%s.%s" % (name2, type_) + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" # Start a tasks BEFORE the registration that will keep trying # and see the registration a bit later @@ -454,7 +451,7 @@ async def test_async_service_browser() -> None: aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test9-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" calls = [] @@ -513,7 +510,7 @@ async def test_async_context_manager() -> None: """Test using an async context manager.""" type_ = "_test10-sr-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" async with AsyncZeroconf(interfaces=['127.0.0.1']) as aiozc: info = ServiceInfo( @@ -539,8 +536,8 @@ async def test_async_unregister_all_services() -> None: type_ = "_test1-srvc-type._tcp.local." name = "xxxyyy" name2 = "abc" - registration_name = "%s.%s" % (name, type_) - registration_name2 = "%s.%s" % (name2, type_) + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -594,7 +591,7 @@ async def test_async_unregister_all_services() -> None: async def test_async_zeroconf_service_types(): type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" zeroconf_registrar = AsyncZeroconf(interfaces=['127.0.0.1']) desc = {'path': '/~paulsm/'} @@ -808,7 +805,7 @@ async def test_info_asking_default_is_asking_qm_questions_after_the_first_qu(): zeroconf_info = aiozc.zeroconf name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( diff --git a/tests/test_cache.py b/tests/test_cache.py index 4b3a8a18..559b4357 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._cache. """ @@ -98,7 +97,7 @@ def test_async_all_by_details(self): record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) + assert set(cache.async_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2} def test_async_entries_with_server(self): record1 = r.DNSService( @@ -109,8 +108,8 @@ def test_async_entries_with_server(self): ) cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.async_entries_with_server('ab')) == set([record1, record2]) - assert set(cache.async_entries_with_server('AB')) == set([record1, record2]) + assert set(cache.async_entries_with_server('ab')) == {record1, record2} + assert set(cache.async_entries_with_server('AB')) == {record1, record2} def test_async_entries_with_name(self): record1 = r.DNSService( @@ -121,8 +120,8 @@ def test_async_entries_with_name(self): ) cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.async_entries_with_name('irrelevant')) == set([record1, record2]) - assert set(cache.async_entries_with_name('Irrelevant')) == set([record1, record2]) + assert set(cache.async_entries_with_name('irrelevant')) == {record1, record2} + assert set(cache.async_entries_with_name('Irrelevant')) == {record1, record2} # These functions have been seen in other projects so @@ -152,7 +151,7 @@ def test_get_all_by_details(self): record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b') cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2]) + assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == {record1, record2} def test_entries_with_server(self): record1 = r.DNSService( @@ -163,8 +162,8 @@ def test_entries_with_server(self): ) cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.entries_with_server('ab')) == set([record1, record2]) - assert set(cache.entries_with_server('AB')) == set([record1, record2]) + assert set(cache.entries_with_server('ab')) == {record1, record2} + assert set(cache.entries_with_server('AB')) == {record1, record2} def test_entries_with_name(self): record1 = r.DNSService( @@ -175,8 +174,8 @@ def test_entries_with_name(self): ) cache = r.DNSCache() cache.async_add_records([record1, record2]) - assert set(cache.entries_with_name('irrelevant')) == set([record1, record2]) - assert set(cache.entries_with_name('Irrelevant')) == set([record1, record2]) + assert set(cache.entries_with_name('irrelevant')) == {record1, record2} + assert set(cache.entries_with_name('Irrelevant')) == {record1, record2} def test_current_entry_with_name_and_alias(self): record1 = r.DNSPointer( diff --git a/tests/test_core.py b/tests/test_core.py index d80514f7..fee9c79d 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._core """ @@ -55,28 +54,26 @@ async def test_reaper(): aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) zeroconf = aiozc.zeroconf cache = zeroconf.cache - original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + original_entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) now = r.current_time_millis() - other_known_answers = set( - [ - r.DNSPointer( - "_hap._tcp.local.", - const._TYPE_PTR, - const._CLASS_IN, - 10000, - 'known-to-other._hap._tcp.local.', - ) - ] - ) + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN, + 10000, + 'known-to-other._hap._tcp.local.', + ) + } zeroconf.question_history.add_question_at_time(question, now, other_known_answers) assert zeroconf.question_history.suppresses(question, now, other_known_answers) - entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + entries_with_cache = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) await asyncio.sleep(1.2) - entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()])) + entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) await aiozc.async_close() assert not zeroconf.question_history.suppresses(question, now, other_known_answers) assert entries != original_entries @@ -367,7 +364,7 @@ def test_register_service_with_custom_ttl(): name = "MyTestHome" info_service = r.ServiceInfo( type_, - '%s.%s' % (name, type_), + f'{name}.{type_}', 80, 0, 0, @@ -423,9 +420,9 @@ def test_tc_bit_defers(): name2 = "knownname2" name3 = "knownname3" - registration_name = "%s.%s" % (name, type_) - registration2_name = "%s.%s" % (name2, type_) - registration3_name = "%s.%s" % (name3, type_) + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." @@ -502,9 +499,9 @@ def test_tc_bit_defers_last_response_missing(): name2 = "knownname2" name3 = "knownname3" - registration_name = "%s.%s" % (name, type_) - registration2_name = "%s.%s" % (name2, type_) - registration3_name = "%s.%s" % (name3, type_) + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." @@ -729,7 +726,7 @@ def test_shutdown_while_register_in_process(): name = "MyTestHome" info_service = r.ServiceInfo( type_, - '%s.%s' % (name, type_), + f'{name}.{type_}', 80, 0, 0, diff --git a/tests/test_dns.py b/tests/test_dns.py index c55e0f6a..071e1f65 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._dns. """ @@ -101,7 +100,7 @@ def test_dns_record_reset_ttl(self): def test_service_info_dunder(self): type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" info = ServiceInfo( type_, registration_name, @@ -119,7 +118,7 @@ def test_service_info_dunder(self): def test_service_info_text_properties_not_given(self): type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" info = ServiceInfo( type_=type_, name=registration_name, @@ -166,7 +165,7 @@ def test_dns_record_hashablity_does_not_consider_ttl(): record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same') record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same') - record_set = set([record1, record2]) + record_set = {record1, record2} assert len(record_set) == 1 record_set.add(record1) @@ -187,7 +186,7 @@ def test_dns_address_record_hashablity(): address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c') address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c') - record_set = set([address1, address2, address3, address4]) + record_set = {address1, address2, address3, address4} assert len(record_set) == 4 record_set.add(address1) @@ -199,9 +198,9 @@ def test_dns_address_record_hashablity(): assert len(record_set) == 4 # Verify we can remove records - additional_set = set([address1, address2]) + additional_set = {address1, address2} record_set -= additional_set - assert record_set == set([address3, address4]) + assert record_set == {address3, address4} def test_dns_hinfo_record_hashablity(): @@ -209,7 +208,7 @@ def test_dns_hinfo_record_hashablity(): hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os') hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os') - record_set = set([hinfo1, hinfo2]) + record_set = {hinfo1, hinfo2} assert len(record_set) == 2 record_set.add(hinfo1) @@ -228,7 +227,7 @@ def test_dns_pointer_record_hashablity(): ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456') - record_set = set([ptr1, ptr2]) + record_set = {ptr1, ptr2} assert len(record_set) == 2 record_set.add(ptr1) @@ -249,7 +248,7 @@ def test_dns_text_record_hashablity(): text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901') text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK') - record_set = set([text1, text2, text3, text4]) + record_set = {text1, text2, text3, text4} assert len(record_set) == 4 @@ -271,7 +270,7 @@ def test_dns_service_record_hashablity(): srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a') srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab') - record_set = set([srv1, srv2, srv3, srv4]) + record_set = {srv1, srv2, srv3, srv4} assert len(record_set) == 4 @@ -297,7 +296,7 @@ def test_dns_nsec_record_hashablity(): 'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2] ) - record_set = set([nsec1, nsec2]) + record_set = {nsec1, nsec2} assert len(record_set) == 2 record_set.add(nsec1) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index aa2f74f6..47e68b75 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._exceptions """ @@ -107,7 +106,7 @@ def test_bad_types(self): bad_names_to_try = ( '._x._tcp.local.', 'a' * 64 + '._sub._http._tcp.local.', - 'a' * 62 + u'â._sub._http._tcp.local.', + 'a' * 62 + 'â._sub._http._tcp.local.', ) for name in bad_names_to_try: self.assertRaises(r.BadTypeInNameException, r.service_type_name, name) @@ -129,7 +128,7 @@ def test_good_service_names(self): ('_12345-67890-abc._udp.local.', '_12345-67890-abc._udp.local.'), ('x._sub._http._tcp.local.', '_http._tcp.local.'), ('a' * 63 + '._sub._http._tcp.local.', '_http._tcp.local.'), - ('a' * 61 + u'â._sub._http._tcp.local.', '_http._tcp.local.'), + ('a' * 61 + 'â._sub._http._tcp.local.', '_http._tcp.local.'), ) for name, result in good_names_to_try: @@ -140,7 +139,7 @@ def test_good_service_names(self): def test_invalid_addresses(self): type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" bad = ('127.0.0.1', '::1', 42) for addr in bad: diff --git a/tests/test_handlers.py b/tests/test_handlers.py index bab50a55..e90a74bd 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._handlers """ @@ -46,7 +45,7 @@ def test_ttl(self): # service definition type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -160,7 +159,7 @@ def test_name_conflicts(self): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_homeassistant._tcp.local." name = "Home" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" info = ServiceInfo( type_, @@ -189,7 +188,7 @@ def test_register_and_lookup_type_by_uppercase_name(self): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_mylowertype._tcp.local." name = "Home" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" info = ServiceInfo( type_, @@ -224,7 +223,7 @@ def test_ptr_optimization(): # service definition type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( @@ -280,7 +279,7 @@ def test_any_query_for_ptr(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_anyptr._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") @@ -307,7 +306,7 @@ def test_aaaa_query(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_knownaaaservice._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") @@ -332,7 +331,7 @@ def test_a_and_aaaa_record_fate_sharing(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_a-and-aaaa-service._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") @@ -388,7 +387,7 @@ def test_unicast_response(): # service definition type_ = "_test-srvc-type._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] @@ -434,8 +433,8 @@ def test_qu_response(): type_ = "_test-srvc-type._tcp.local." other_type_ = "_notthesame._tcp.local." name = "xxxyyy" - registration_name = "%s.%s" % (name, type_) - registration_name2 = "%s.%s" % (name, other_type_) + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{other_type_}" desc = {'path': '/~paulsm/'} info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] @@ -530,7 +529,7 @@ def test_known_answer_supression(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_knownanswersv8._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." info = ServiceInfo( @@ -643,9 +642,9 @@ def test_multi_packet_known_answer_supression(): name2 = "knownname2" name3 = "knownname3" - registration_name = "%s.%s" % (name, type_) - registration2_name = "%s.%s" % (name2, type_) - registration3_name = "%s.%s" % (name3, type_) + registration_name = f"{name}.{type_}" + registration2_name = f"{name2}.{type_}" + registration3_name = f"{name3}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." @@ -694,7 +693,7 @@ def test_known_answer_supression_service_type_enumeration_query(): zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_otherknown._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." info = ServiceInfo( @@ -704,7 +703,7 @@ def test_known_answer_supression_service_type_enumeration_query(): type_2 = "_otherknown2._tcp.local." name = "knownname" - registration_name2 = "%s.%s" % (name, type_2) + registration_name2 = f"{name}.{type_2}" desc = {'path': '/~paulsm/'} server_name2 = "ash-3.local." info2 = ServiceInfo( @@ -772,7 +771,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): type_ = "_addtest1._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "ash-2.local." info = ServiceInfo( @@ -782,7 +781,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): type_2 = "_addtest2._tcp.local." name = "knownname" - registration_name2 = "%s.%s" % (name, type_2) + registration_name2 = f"{name}.{type_2}" desc = {'path': '/~paulsm/'} server_name2 = "ash-3.local." info2 = ServiceInfo( @@ -904,7 +903,7 @@ async def test_cache_flush_bit(): type_ = "_cacheflush._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "server-uu1.local." info = ServiceInfo( @@ -992,7 +991,7 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.Recor type_ = "_cacheflush._tcp.local." name = "knownname" - registration_name = "%s.%s" % (name, type_) + registration_name = f"{name}.{type_}" desc = {'path': '/~paulsm/'} server_name = "server-uu1.local." info = ServiceInfo( @@ -1013,11 +1012,11 @@ def async_update_records(self, zc: 'Zeroconf', now: float, records: List[r.Recor ) await asyncio.sleep(0) # flush out the call_soon_threadsafe - assert set([record.new for record in updated]) == set([ptr_record, a_record]) + assert {record.new for record in updated} == {ptr_record, a_record} # The old records should be None so we trigger Add events # in service browsers instead of Update events - assert set([record.old for record in updated]) == set([None]) + assert {record.old for record in updated} == {None} await aiozc.async_close() @@ -1044,7 +1043,7 @@ async def test_questions_query_handler_populates_the_question_history_from_qm_qu ) assert unicast_out is None assert multicast_out is None - assert zc.question_history.suppresses(question, now, set([known_answer])) + assert zc.question_history.suppresses(question, now, {known_answer}) await aiozc.async_close() @@ -1071,7 +1070,7 @@ async def test_questions_query_handler_does_not_put_qu_questions_in_history(): ) assert unicast_out is None assert multicast_out is None - assert not zc.question_history.suppresses(question, now, set([known_answer])) + assert not zc.question_history.suppresses(question, now, {known_answer}) await aiozc.async_close() diff --git a/tests/test_history.py b/tests/test_history.py index 89159dff..9da6b567 100644 --- a/tests/test_history.py +++ b/tests/test_history.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Unit tests for _history.py.""" @@ -14,20 +13,16 @@ def test_question_suppression(): question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) now = r.current_time_millis() - other_known_answers = set( - [ - r.DNSPointer( - "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' - ) - ] - ) - our_known_answers = set( - [ - r.DNSPointer( - "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-us._hap._tcp.local.' - ) - ] - ) + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + } + our_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-us._hap._tcp.local.' + ) + } history.add_question_at_time(question, now, other_known_answers) @@ -52,13 +47,11 @@ def test_question_expire(): question = r.DNSQuestion("_hap._tcp._local.", const._TYPE_PTR, const._CLASS_IN) now = r.current_time_millis() - other_known_answers = set( - [ - r.DNSPointer( - "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' - ) - ] - ) + other_known_answers = { + r.DNSPointer( + "_hap._tcp.local.", const._TYPE_PTR, const._CLASS_IN, 10000, 'known-to-other._hap._tcp.local.' + ) + } history.add_question_at_time(question, now, other_known_answers) # Verify the question is suppressed if the known answers are the same diff --git a/tests/test_init.py b/tests/test_init.py index 5005a75d..1d1f7086 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf.py """ @@ -127,7 +126,7 @@ def verify_name_change(self, zc, type_, name, number_hosts): desc = {'path': '/~paulsm/'} info_service = ServiceInfo( type_, - '%s.%s' % (name, type_), + f'{name}.{type_}', 80, 0, 0, @@ -147,7 +146,7 @@ def verify_name_change(self, zc, type_, name, number_hosts): # in the registry info_service2 = ServiceInfo( type_, - '%s.%s' % (name, type_), + f'{name}.{type_}', 80, 0, 0, @@ -160,7 +159,7 @@ def verify_name_change(self, zc, type_, name, number_hosts): def generate_many_hosts(self, zc, type_, name, number_hosts): block_size = 25 - number_hosts = int(((number_hosts - 1) / block_size + 1)) * block_size + number_hosts = int((number_hosts - 1) / block_size + 1) * block_size out = r.DNSOutgoing(const._FLAGS_QR_RESPONSE | const._FLAGS_AA) for i in range(1, number_hosts + 1): next_name = name if i == 1 else '%s-%d' % (name, i) diff --git a/tests/test_logger.py b/tests/test_logger.py index 205ce0ff..cedda7e9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """Unit tests for logger.py.""" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a0805960..e9063475 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._protocol """ @@ -195,8 +194,8 @@ def test_dns_hinfo(self): generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'os')) parsed = r.DNSIncoming(generated.packets()[0]) answer = cast(r.DNSHinfo, parsed.answers[0]) - assert answer.cpu == u'cpu' - assert answer.os == u'os' + assert answer.cpu == 'cpu' + assert answer.os == 'os' generated = r.DNSOutgoing(0) generated.add_additional_answer(DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu', 'x' * 257)) diff --git a/tests/test_services.py b/tests/test_services.py index b1e2d890..7994cbdc 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._services. """ @@ -48,7 +47,7 @@ def test_integration_with_listener_class(self): type_ = "_http._tcp.local." subtype = subtype_name + "._sub." + type_ name = "UPPERxxxyyyæøå" - registration_name = "%s.%s" % (name, subtype) + registration_name = f"{name}.{subtype}" class MyListener(r.ServiceListener): def add_service(self, zeroconf, type, name): diff --git a/tests/test_updates.py b/tests/test_updates.py index 1f6f8ad4..b1d7f1b7 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- """ Unit tests for zeroconf._services. """ @@ -67,7 +66,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): info_service = ServiceInfo( type_, - '%s.%s' % (name, type_), + f'{name}.{type_}', 80, 0, 0, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index aadaa290..31bf2b32 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -653,7 +653,7 @@ async def async_check_service( raise NonUniqueNameException # change the name and look for a conflict - info.name = '%s-%s.%s' % (instance_name, next_instance_number, info.type) + info.name = f'{instance_name}-{next_instance_number}.{info.type}' next_instance_number += 1 service_type_name(info.name) next_time = now diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 31a2da3a..33484bfb 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -100,7 +100,7 @@ def get_type(t: int) -> str: def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: """String representation with additional information""" - return "%s[%s,%s%s,%s]%s" % ( + return "{}[{},{}{},{}]{}".format( hdr, self.get_type(self.type), self.get_class_(self.class_), @@ -142,7 +142,7 @@ def unicast(self, value: bool) -> None: def __repr__(self) -> str: """String representation""" - return "%s[question,%s,%s,%s]" % ( + return "{}[question,{},{},{}]".format( self.get_type(self.type), "QU" if self.unicast else "QM", self.get_class_(self.class_), @@ -230,7 +230,7 @@ def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use def to_string(self, other: Union[bytes, str]) -> str: """String representation with additional information""" - arg = "%s/%s,%s" % (self.ttl, int(self.get_remaining_ttl(current_time_millis())), cast(Any, other)) + arg = f"{self.ttl}/{int(self.get_remaining_ttl(current_time_millis()))},{cast(Any, other)}" return DNSEntry.entry_to_string(self, "record", arg) @@ -439,7 +439,7 @@ def __hash__(self) -> int: def __repr__(self) -> str: """String representation""" - return self.to_string("%s:%s" % (self.server, self.port)) + return self.to_string(f"{self.server}:{self.port}") class DNSNsec(DNSRecord): diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 617be408..f9bdab74 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -140,11 +140,11 @@ def _construct_outgoing_from_record_set( def _additionals_from_answers_rrset(self, rrset: Set[DNSRecord]) -> Set[DNSRecord]: additionals: Set[DNSRecord] = set() - return additionals.union(*[self._additionals[record] for record in rrset]) + return additionals.union(*(self._additionals[record] for record in rrset)) def _suppress_mcasts_from_last_second(self, rrset: Set[DNSRecord]) -> None: """Remove any records that were already sent in the last second.""" - rrset -= set(record for record in rrset if self._has_mcast_record_in_last_second(record)) + rrset -= {record for record in rrset if self._has_mcast_record_in_last_second(record)} def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: """Check to see if a record has been mcasted recently. @@ -201,13 +201,11 @@ def _add_pointer_answers( # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) if not known_answers.suppresses(dns_pointer): - answer_set[dns_pointer] = set( - [ - service.dns_service(created=now), - service.dns_text(created=now), - *service.dns_addresses(created=now), - ] - ) + answer_set[dns_pointer] = { + service.dns_service(created=now), + service.dns_text(created=now), + *service.dns_addresses(created=now), + } def _add_address_answers( self, @@ -271,7 +269,7 @@ def async_response( # pylint: disable=unused-argument threadsafe. """ ucast_source = port != _MDNS_PORT - known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs])) + known_answers = DNSRRSet(itertools.chain(*(msg.answers for msg in msgs))) query_res = _QueryResponse(self.cache, msgs[0], ucast_source) for msg in msgs: diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index ae2f43a3..f987ce2e 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -259,10 +259,10 @@ def read_name(self) -> str: next_ = off + 1 off = ((length & 0x3F) << 8) | self.data[off] if off >= first: - raise IncomingDecodeError("Bad domain name (circular) at %s" % (off,)) + raise IncomingDecodeError(f"Bad domain name (circular) at {off}") first = off else: - raise IncomingDecodeError("Bad domain name at %s" % (off,)) + raise IncomingDecodeError(f"Bad domain name at {off}") if next_ >= 0: self.offset = next_ @@ -523,7 +523,7 @@ def _write_record(self, record: DNSRecord, now: float) -> bool: self.write_short(0) # Will get replaced with the actual size record.write(self) # Adjust size for the short we will write before this record - length = sum((len(d) for d in self.data[index + 1 :])) + length = sum(len(d) for d in self.data[index + 1 :]) # Here we replace the 0 length short we wrote # before with the actual length self._replace_short(index, length) diff --git a/zeroconf/asyncio.py b/zeroconf/asyncio.py index 08478044..ef7e7f64 100644 --- a/zeroconf/asyncio.py +++ b/zeroconf/asyncio.py @@ -251,7 +251,7 @@ async def async_remove_service_listener(self, listener: ServiceListener) -> None async def async_remove_all_service_listeners(self) -> None: """Removes a listener from the set that is currently listening.""" await asyncio.gather( - *[self.async_remove_service_listener(listener) for listener in list(self.async_browsers)] + *(self.async_remove_service_listener(listener) for listener in list(self.async_browsers)) ) async def __aenter__(self) -> 'AsyncZeroconf': From 69942d5bfb4d92c6a312aea7c17f63fce0401e23 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 11:31:40 -1000 Subject: [PATCH 515/608] Rename DNSNsec.next to DNSNsec.next_name (#908) --- tests/test_protocol.py | 1 + zeroconf/_dns.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e9063475..75a69d5e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -734,6 +734,7 @@ def test_parse_packet_with_nsec_record(): nsec_record = parsed.answers[3] assert "nsec," in str(nsec_record) assert nsec_record.rdtypes == [16, 33] + assert nsec_record.next_name == "MyHome54 (2)._meshcop._udp.local." def test_records_same_packet_share_fate(): diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 33484bfb..7e06cff4 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -446,7 +446,7 @@ class DNSNsec(DNSRecord): """A DNS NSEC record""" - __slots__ = ('next', 'rdtypes') + __slots__ = ('next_name', 'rdtypes') def __init__( self, @@ -454,30 +454,32 @@ def __init__( type_: int, class_: int, ttl: int, - next: str, + next_name: str, rdtypes: List[int], created: Optional[float] = None, ) -> None: super().__init__(name, type_, class_, ttl, created) - self.next = next + self.next_name = next_name self.rdtypes = rdtypes def __eq__(self, other: Any) -> bool: """Tests equality on cpu and os""" return ( isinstance(other, DNSNsec) - and self.next == other.next + and self.next_name == other.next_name and self.rdtypes == other.rdtypes and DNSEntry.__eq__(self, other) ) def __hash__(self) -> int: """Hash to compare like DNSNSec.""" - return hash((*self._entry_tuple(), self.next, *self.rdtypes)) + return hash((*self._entry_tuple(), self.next_name, *self.rdtypes)) def __repr__(self) -> str: """String representation""" - return self.to_string(self.next + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes])) + return self.to_string( + self.next_name + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes]) + ) class DNSRRSet: From e63ca518c91cda7b9f460436aee4fdac1a7b9567 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 19:55:02 -1000 Subject: [PATCH 516/608] Remove duplicate unregister_all_services code (#910) --- zeroconf/_core.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 31bf2b32..1d3e6cba 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -619,21 +619,10 @@ async def async_unregister_all_services(self) -> None: def unregister_all_services(self) -> None: """Unregister all registered services.""" - # Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1 - out = self.generate_unregister_all_services() - if not out: - return - now = current_time_millis() - next_time = now - i = 0 - while i < 3: - if now < next_time: - self.wait(next_time - now) - now = current_time_millis() - continue - self.send(out) - i += 1 - next_time += _UNREGISTER_TIME + assert self.loop is not None + run_coro_with_timeout( + self.async_unregister_all_services(), self.loop, _UNREGISTER_TIME * _REGISTER_BROADCASTS + ) async def async_check_service( self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False @@ -799,7 +788,14 @@ def close(self) -> None: This method is idempotent and irreversible. """ - self.unregister_all_services() + assert self.loop is not None + if self.loop.is_running(): + if self.loop == get_running_loop(): + log.warning( + "unregister_all_services skipped as it does blocking i/o; use AsyncZeroconf with asyncio" + ) + else: + self.unregister_all_services() self._close() self.engine.close() self._shutdown_threads() From 2d3da7a77699f88bd90ebc09d36b333690385f85 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 20:20:02 -1000 Subject: [PATCH 517/608] Remove locking from ServiceRegistry (#911) - All calls to the ServiceRegistry are now done in async context which makes them thread safe. Locking is no longer needed. --- tests/services/test_registry.py | 52 +++++++++++++++---------------- tests/services/test_types.py | 8 ++--- tests/test_asyncio.py | 2 +- tests/test_core.py | 20 ++++++------ tests/test_handlers.py | 54 ++++++++++++++++----------------- zeroconf/_core.py | 10 +++--- zeroconf/_handlers.py | 8 ++--- zeroconf/_services/registry.py | 51 +++++++++++-------------------- 8 files changed, 95 insertions(+), 110 deletions(-) diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 496cc629..87c048d5 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -23,10 +23,10 @@ def test_only_register_once(self): ) registry = r.ServiceRegistry() - registry.add(info) - self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) - registry.remove(info) - registry.add(info) + registry.async_add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info) + registry.async_remove(info) + registry.async_add(info) def test_unregister_multiple_times(self): """Verify we can unregister a service multiple times. @@ -46,10 +46,10 @@ def test_unregister_multiple_times(self): ) registry = r.ServiceRegistry() - registry.add(info) - self.assertRaises(r.ServiceNameAlreadyRegistered, registry.add, info) - registry.remove(info) - registry.remove(info) + registry.async_add(info) + self.assertRaises(r.ServiceNameAlreadyRegistered, registry.async_add, info) + registry.async_remove(info) + registry.async_remove(info) def test_lookups(self): type_ = "_test-srvc-type._tcp.local." @@ -62,13 +62,13 @@ def test_lookups(self): ) registry = r.ServiceRegistry() - registry.add(info) + registry.async_add(info) - assert registry.get_service_infos() == [info] - assert registry.get_info_name(registration_name) == info - assert registry.get_infos_type(type_) == [info] - assert registry.get_infos_server("ash-2.local.") == [info] - assert registry.get_types() == [type_] + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name) == info + assert registry.async_get_infos_type(type_) == [info] + assert registry.async_get_infos_server("ash-2.local.") == [info] + assert registry.async_get_types() == [type_] def test_lookups_upper_case_by_lower_case(self): type_ = "_test-SRVC-type._tcp.local." @@ -81,13 +81,13 @@ def test_lookups_upper_case_by_lower_case(self): ) registry = r.ServiceRegistry() - registry.add(info) + registry.async_add(info) - assert registry.get_service_infos() == [info] - assert registry.get_info_name(registration_name.lower()) == info - assert registry.get_infos_type(type_.lower()) == [info] - assert registry.get_infos_server("ash-2.local.") == [info] - assert registry.get_types() == [type_.lower()] + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name.lower()) == info + assert registry.async_get_infos_type(type_.lower()) == [info] + assert registry.async_get_infos_server("ash-2.local.") == [info] + assert registry.async_get_types() == [type_.lower()] def test_lookups_lower_case_by_upper_case(self): type_ = "_test-srvc-type._tcp.local." @@ -100,10 +100,10 @@ def test_lookups_lower_case_by_upper_case(self): ) registry = r.ServiceRegistry() - registry.add(info) + registry.async_add(info) - assert registry.get_service_infos() == [info] - assert registry.get_info_name(registration_name.upper()) == info - assert registry.get_infos_type(type_.upper()) == [info] - assert registry.get_infos_server("ASH-2.local.") == [info] - assert registry.get_types() == [type_] + assert registry.async_get_service_infos() == [info] + assert registry.async_get_info_name(registration_name.upper()) == info + assert registry.async_get_infos_type(type_.upper()) == [info] + assert registry.async_get_infos_server("ASH-2.local.") == [info] + assert registry.async_get_types() == [type_] diff --git a/tests/services/test_types.py b/tests/services/test_types.py index d14a8b25..f4206cf4 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -50,7 +50,7 @@ def test_integration_with_listener(self): "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - zeroconf_registrar.registry.add(info) + zeroconf_registrar.registry.async_add(info) try: with patch.object( zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False @@ -87,7 +87,7 @@ def test_integration_with_listener_v6_records(self): "ash-2.local.", addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) - zeroconf_registrar.registry.add(info) + zeroconf_registrar.registry.async_add(info) try: with patch.object( zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False @@ -124,7 +124,7 @@ def test_integration_with_listener_ipv6(self): "ash-2.local.", addresses=[socket.inet_pton(socket.AF_INET6, addr)], ) - zeroconf_registrar.registry.add(info) + zeroconf_registrar.registry.async_add(info) try: with patch.object( zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False @@ -160,7 +160,7 @@ def test_integration_with_subtype_and_listener(self): "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")], ) - zeroconf_registrar.registry.add(info) + zeroconf_registrar.registry.async_add(info) try: with patch.object( zeroconf_registrar.engine.protocols[0], "suppress_duplicate_packet", return_value=False diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 34709d85..9ec5e496 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -812,7 +812,7 @@ async def test_info_asking_default_is_asking_qm_questions_after_the_first_qu(): type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) - zeroconf_info.registry.add(info) + zeroconf_info.registry.async_add(info) # we are going to patch the zeroconf send to check query transmission old_send = zeroconf_info.async_send diff --git a/tests/test_core.py b/tests/test_core.py index fee9c79d..9a1a14f4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -335,11 +335,11 @@ def test_goodbye_all_services(): info = r.ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) + zc.registry.async_add(info) out = zc.generate_unregister_all_services() assert out is not None first_packet = out.packets() - zc.registry.add(info) + zc.registry.async_add(info) out2 = zc.generate_unregister_all_services() assert out2 is not None second_packet = out.packets() @@ -348,7 +348,7 @@ def test_goodbye_all_services(): # Verify the registery is empty out3 = zc.generate_unregister_all_services() assert out3 is None - assert zc.registry.get_service_infos() == [] + assert zc.registry.async_get_service_infos() == [] zc.close() @@ -438,9 +438,9 @@ def test_tc_bit_defers(): info3 = r.ServiceInfo( type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) - zc.registry.add(info2) - zc.registry.add(info3) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) protocol = zc.engine.protocols[0] now = r.current_time_millis() @@ -517,9 +517,9 @@ def test_tc_bit_defers_last_response_missing(): info3 = r.ServiceInfo( type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) - zc.registry.add(info2) - zc.registry.add(info3) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) protocol = zc.engine.protocols[0] now = r.current_time_millis() @@ -581,7 +581,7 @@ def test_tc_bit_defers_last_response_missing(): assert source_ip not in protocol._timers # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() diff --git a/tests/test_handlers.py b/tests/test_handlers.py index e90a74bd..d4144721 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -87,7 +87,7 @@ def _process_outgoing_packet(out): expected_ttl = None for _ in range(3): _process_outgoing_packet(zc.generate_service_query(info)) - zc.registry.add(info) + zc.registry.async_add(info) for _ in range(3): _process_outgoing_packet(zc.generate_service_broadcast(info, None)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 3 @@ -112,7 +112,7 @@ def _process_outgoing_packet(out): # unregister expected_ttl = 0 - zc.registry.remove(info) + zc.registry.async_remove(info) for _ in range(3): _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 @@ -121,7 +121,7 @@ def _process_outgoing_packet(out): expected_ttl = None for _ in range(3): _process_outgoing_packet(zc.generate_service_query(info)) - zc.registry.add(info) + zc.registry.async_add(info) # register service with custom TTL expected_ttl = const._DNS_HOST_TTL * 2 assert expected_ttl != const._DNS_HOST_TTL @@ -147,7 +147,7 @@ def _process_outgoing_packet(out): # unregister expected_ttl = 0 - zc.registry.remove(info) + zc.registry.async_remove(info) for _ in range(3): _process_outgoing_packet(zc.generate_service_broadcast(info, 0)) assert nbr_answers == 12 and nbr_additionals == 0 and nbr_authorities == 0 @@ -284,7 +284,7 @@ def test_any_query_for_ptr(): server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) - zc.registry.add(info) + zc.registry.async_add(info) _clear_cache(zc) generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -297,7 +297,7 @@ def test_any_query_for_ptr(): assert multicast_out.answers[0][0].name == type_ assert multicast_out.answers[0][0].alias == registration_name # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() @@ -311,7 +311,7 @@ def test_aaaa_query(): server_name = "ash-2.local." ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1") info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address]) - zc.registry.add(info) + zc.registry.async_add(info) generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) @@ -322,7 +322,7 @@ def test_aaaa_query(): ) assert multicast_out.answers[0][0].address == ipv6_address # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() @@ -342,7 +342,7 @@ def test_a_and_aaaa_record_fate_sharing(): aaaa_record = info.dns_addresses(version=r.IPVersion.V6Only)[0] a_record = info.dns_addresses(version=r.IPVersion.V4Only)[0] - zc.registry.add(info) + zc.registry.async_add(info) # Test AAAA query generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -375,7 +375,7 @@ def test_a_and_aaaa_record_fate_sharing(): assert len(multicast_out.answers) == 1 assert len(multicast_out.additionals) == 1 # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() @@ -393,7 +393,7 @@ def test_unicast_response(): type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] ) # register - zc.registry.add(info) + zc.registry.async_add(info) _clear_cache(zc) # query @@ -420,7 +420,7 @@ def test_unicast_response(): assert has_srv and has_txt and has_a # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() @@ -535,7 +535,7 @@ def test_known_answer_supression(): info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) + zc.registry.async_add(info) now = current_time_millis() _clear_cache(zc) @@ -631,7 +631,7 @@ def test_known_answer_supression(): assert not multicast_out or not multicast_out.answers # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) zc.close() @@ -660,9 +660,9 @@ def test_multi_packet_known_answer_supression(): info3 = ServiceInfo( type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) - zc.registry.add(info2) - zc.registry.add(info3) + zc.registry.async_add(info) + zc.registry.async_add(info2) + zc.registry.async_add(info3) now = current_time_millis() _clear_cache(zc) @@ -683,9 +683,9 @@ def test_multi_packet_known_answer_supression(): assert unicast_out is None assert multicast_out is None # unregister - zc.registry.remove(info) - zc.registry.remove(info2) - zc.registry.remove(info3) + zc.registry.async_remove(info) + zc.registry.async_remove(info2) + zc.registry.async_remove(info3) zc.close() @@ -699,7 +699,7 @@ def test_known_answer_supression_service_type_enumeration_query(): info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) + zc.registry.async_add(info) type_2 = "_otherknown2._tcp.local." name = "knownname" @@ -709,7 +709,7 @@ def test_known_answer_supression_service_type_enumeration_query(): info2 = ServiceInfo( type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info2) + zc.registry.async_add(info2) now = current_time_millis() _clear_cache(zc) @@ -755,8 +755,8 @@ def test_known_answer_supression_service_type_enumeration_query(): assert not multicast_out or not multicast_out.answers # unregister - zc.registry.remove(info) - zc.registry.remove(info2) + zc.registry.async_remove(info) + zc.registry.async_remove(info2) zc.close() @@ -777,7 +777,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info) + zc.registry.async_add(info) type_2 = "_addtest2._tcp.local." name = "knownname" @@ -787,7 +787,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): info2 = ServiceInfo( type_2, registration_name2, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")] ) - zc.registry.add(info2) + zc.registry.async_add(info2) ptr_record = info.dns_pointer() @@ -888,7 +888,7 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): assert info2.dns_service() in unicast_out.additionals # unregister - zc.registry.remove(info) + zc.registry.async_remove(info) await aiozc.async_close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 1d3e6cba..5a7f9ff3 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -518,7 +518,7 @@ async def async_register_service( await self.async_wait_for_start() await self.async_check_service(info, allow_name_change, cooperating_responders) - self.registry.add(info) + self.registry.async_add(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) def update_service(self, info: ServiceInfo) -> None: @@ -534,7 +534,7 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: """Registers service information to the network with a default TTL. Zeroconf will then respond to requests for information for that service.""" - self.registry.update(info) + self.registry.async_update(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: @@ -587,18 +587,18 @@ def unregister_service(self, info: ServiceInfo) -> None: async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service.""" - self.registry.remove(info) + self.registry.async_remove(info) return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" - service_infos = self.registry.get_service_infos() + service_infos = self.registry.async_get_service_infos() if not service_infos: return None out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) for info in service_infos: self._add_broadcast_answer(out, info, 0) - self.registry.remove(service_infos) + self.registry.async_remove(service_infos) return out async def async_unregister_all_services(self) -> None: diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index f9bdab74..07b9ac83 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -185,7 +185,7 @@ def _add_service_type_enumeration_query_answers( https://datatracker.ietf.org/doc/html/rfc6763#section-9 """ - for stype in self.registry.get_types(): + for stype in self.registry.async_get_types(): dns_pointer = DNSPointer( _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_PTR, _CLASS_IN, _DNS_OTHER_TTL, stype, now ) @@ -196,7 +196,7 @@ def _add_pointer_answers( self, name: str, answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float ) -> None: """Answer PTR/ANY question.""" - for service in self.registry.get_infos_type(name): + for service in self.registry.async_get_infos_type(name): # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) @@ -216,7 +216,7 @@ def _add_address_answers( type_: int, ) -> None: """Answer A/AAAA/ANY question.""" - for service in self.registry.get_infos_server(name): + for service in self.registry.async_get_infos_server(name): answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() for dns_address in service.dns_addresses(created=now): @@ -247,7 +247,7 @@ def _answer_question( self._add_address_answers(question.name, answer_set, known_answers, now, type_) if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY): - service = self.registry.get_info_name(question.name) # type: ignore + service = self.registry.async_get_info_name(question.name) # type: ignore if service is not None: if type_ in (_TYPE_SRV, _TYPE_ANY): # Add recommended additional answers according to diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index ebf5abbb..4e64c8d7 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -20,7 +20,6 @@ USA """ -import threading from typing import Dict, List, Optional, Union @@ -31,9 +30,8 @@ class ServiceRegistry: """A registry to keep track of services. - This class exists to ensure services can - be safely added and removed with thread - safety. + The registry must only be accessed from + the event loop as it is not thread safe. """ def __init__( @@ -43,56 +41,43 @@ def __init__( self._services: Dict[str, ServiceInfo] = {} self.types: Dict[str, List] = {} self.servers: Dict[str, List] = {} - self._lock = threading.Lock() # add and remove services thread safe - def add(self, info: ServiceInfo) -> None: + def async_add(self, info: ServiceInfo) -> None: """Add a new service to the registry.""" - with self._lock: - self._add(info) + self._add(info) - def remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: + def async_remove(self, info: Union[List[ServiceInfo], ServiceInfo]) -> None: """Remove a new service from the registry.""" - infos = info if isinstance(info, list) else [info] + self._remove(info if isinstance(info, list) else [info]) - with self._lock: - self._remove(infos) - - def update(self, info: ServiceInfo) -> None: + def async_update(self, info: ServiceInfo) -> None: """Update new service in the registry.""" + self._remove([info]) + self._add(info) - with self._lock: - self._remove([info]) - self._add(info) - - def get_service_infos(self) -> List[ServiceInfo]: + def async_get_service_infos(self) -> List[ServiceInfo]: """Return all ServiceInfo.""" return list(self._services.values()) - def get_info_name(self, name: str) -> Optional[ServiceInfo]: + def async_get_info_name(self, name: str) -> Optional[ServiceInfo]: """Return all ServiceInfo for the name.""" return self._services.get(name.lower()) - def get_types(self) -> List[str]: + def async_get_types(self) -> List[str]: """Return all types.""" return list(self.types.keys()) - def get_infos_type(self, type_: str) -> List[ServiceInfo]: + def async_get_infos_type(self, type_: str) -> List[ServiceInfo]: """Return all ServiceInfo matching type.""" - return self._get_by_index("types", type_) + return self._async_get_by_index("types", type_) - def get_infos_server(self, server: str) -> List[ServiceInfo]: + def async_get_infos_server(self, server: str) -> List[ServiceInfo]: """Return all ServiceInfo matching server.""" - return self._get_by_index("servers", server) + return self._async_get_by_index("servers", server) - def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: + def _async_get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" - # Since we do not get under a lock since it would be - # a performance issue, its possible - # the service can be unregistered during the get - # so we must check if info is None - return list( - filter(None, [self._services.get(name) for name in getattr(self, attr).get(key.lower(), [])[:]]) - ) + return [self._services[name] for name in getattr(self, attr).get(key.lower(), [])] def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" From b2a7a00f82d401066166776cecf0857ebbdb56ad Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 20:32:09 -1000 Subject: [PATCH 518/608] Update changelog for 0.33.0 (#912) --- README.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.rst b/README.rst index 7cefb379..0d726d76 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,21 @@ See examples directory for more. Changelog ========= +0.33.0 (Unreleased) +=================== + +This release eliminates all threading locks as all non-threadsafe operations +now happen in the event loop. + +Technically backwards incompatible: + +* Remove duplicate unregister_all_services code (#910) @bdraco + + Calling Zeroconf.close from same asyncio event loop zeroconf is running in + will now skip unregister_all_services and log a warning as this a blocking + operation and is not async safe and never has been. + + Use AsyncZeroconf instead, or for legacy code call async_unregister_all_services before Zeroconf.close 0.32.1 ====== From 38eb271c952e89260ecac6fac3e723f4206c4648 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 21:03:01 -1000 Subject: [PATCH 519/608] Switch periodic cleanup task to call_later (#913) - Simplifies AsyncEngine to avoid the long running task --- tests/test_core.py | 19 +++++++++++++++++++ zeroconf/_core.py | 33 +++++++++++++++++---------------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 9a1a14f4..fd45b1ee 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -74,6 +74,7 @@ async def test_reaper(): entries_with_cache = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) await asyncio.sleep(1.2) entries = list(itertools.chain(*(cache.entries_with_name(name) for name in cache.names()))) + assert zeroconf.cache.get(record_with_1s_ttl) is None await aiozc.async_close() assert not zeroconf.question_history.suppresses(question, now, other_known_answers) assert entries != original_entries @@ -82,6 +83,24 @@ async def test_reaper(): assert record_with_1s_ttl not in entries +@pytest.mark.asyncio +async def test_reaper_aborts_when_done(): + """Ensure cache cleanup stops when zeroconf is done.""" + with patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10): + assert _core._CACHE_CLEANUP_INTERVAL == 10 + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zeroconf = aiozc.zeroconf + record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a') + record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b') + zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl]) + assert zeroconf.cache.get(record_with_10s_ttl) is not None + assert zeroconf.cache.get(record_with_1s_ttl) is not None + await aiozc.async_close() + await asyncio.sleep(1.2) + assert zeroconf.cache.get(record_with_10s_ttl) is not None + assert zeroconf.cache.get(record_with_1s_ttl) is not None + + class Framework(unittest.TestCase): def test_launch_and_close(self): rv = r.Zeroconf(interfaces=r.InterfaceChoice.All) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 5a7f9ff3..2909aa36 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -98,7 +98,7 @@ def __init__( self.senders: List[asyncio.DatagramTransport] = [] self._listen_socket = listen_socket self._respond_sockets = respond_sockets - self._cache_cleanup_task: Optional[asyncio.Task] = None + self._cleanup_timer: Optional[asyncio.TimerHandle] = None self._running_event: Optional[asyncio.Event] = None def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None: @@ -110,8 +110,10 @@ def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[thr async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: """Set up the instance.""" assert self.loop is not None + self._cleanup_timer = self.loop.call_later( + millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup + ) await self._async_create_endpoints() - self._cache_cleanup_task = self.loop.create_task(self._async_cache_cleanup()) assert self._running_event is not None self._running_event.set() if loop_thread_ready: @@ -142,26 +144,25 @@ async def _async_create_endpoints(self) -> None: if s in sender_sockets: self.senders.append(cast(asyncio.DatagramTransport, transport)) - async def _async_cache_cleanup(self) -> None: + def _async_cache_cleanup(self) -> None: """Periodic cache cleanup.""" - while not self.zc.done: - now = current_time_millis() - self.zc.question_history.async_expire(now) - self.zc.record_manager.async_updates( - now, [RecordUpdate(record, None) for record in self.zc.cache.async_expire(now)] - ) - self.zc.record_manager.async_updates_complete() - await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL)) + now = current_time_millis() + self.zc.question_history.async_expire(now) + self.zc.record_manager.async_updates( + now, [RecordUpdate(record, None) for record in self.zc.cache.async_expire(now)] + ) + self.zc.record_manager.async_updates_complete() + assert self.loop is not None + self._cleanup_timer = self.loop.call_later( + millis_to_seconds(_CACHE_CLEANUP_INTERVAL), self._async_cache_cleanup + ) async def _async_close(self) -> None: """Cancel and wait for the cleanup task to finish.""" self._async_shutdown() - if self._cache_cleanup_task: - self._cache_cleanup_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await self._cache_cleanup_task - self._cache_cleanup_task = None await asyncio.sleep(0) # flush out any call soons + assert self._cleanup_timer is not None + self._cleanup_timer.cancel() def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" From aa7108481235cc018600d096b093c785447d8769 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 16 Jul 2021 21:27:19 -1000 Subject: [PATCH 520/608] Remove Zeroconf.wait as its now unused in the codebase (#914) --- tests/services/test_browser.py | 4 ++-- tests/test_updates.py | 5 +++-- zeroconf/_core.py | 10 ---------- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 26684e09..292dee25 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -371,7 +371,7 @@ def mock_incoming_msg( zeroconf, mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1], 120), ) - zeroconf.wait(100) + time.sleep(0.1) called_with_refresh_time_check = False @@ -693,7 +693,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.register_service(info_service) - zc.wait(1) + time.sleep(0.001) browser.cancel() diff --git a/tests/test_updates.py b/tests/test_updates.py index b1d7f1b7..ecdf89d4 100644 --- a/tests/test_updates.py +++ b/tests/test_updates.py @@ -1,10 +1,11 @@ #!/usr/bin/env python -""" Unit tests for zeroconf._services. """ +""" Unit tests for zeroconf._updates. """ import logging import socket +import time from threading import Event import pytest @@ -77,7 +78,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): zc.register_service(info_service) - zc.wait(1) + time.sleep(0.001) browser.cancel() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 2909aa36..f605ab13 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -21,8 +21,6 @@ """ import asyncio -import concurrent.futures -import contextlib import itertools import random import socket @@ -423,14 +421,6 @@ def done(self) -> bool: def listeners(self) -> List[RecordUpdateListener]: return self.record_manager.listeners - def wait(self, timeout: float) -> None: - """Calling task waits for a given number of milliseconds or until notified.""" - assert self.loop is not None - with contextlib.suppress(concurrent.futures.TimeoutError): - asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result( - millis_to_seconds(timeout) - ) - async def async_wait(self, timeout: float) -> None: """Calling task waits for a given number of milliseconds or until notified.""" assert self.notify_event is not None From b6eaf7249f386f573b0876204ccfdfa02ee9ac5b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 17 Jul 2021 22:50:23 -1000 Subject: [PATCH 521/608] Reduce complexity of DNSRecord (#915) - Use constants for calculations in is_expired/is_stale/is_recent --- tests/services/test_info.py | 3 +-- zeroconf/_dns.py | 32 ++++++++++---------------------- zeroconf/const.py | 3 --- 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 8ac8beda..2060767f 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -187,8 +187,7 @@ def test_service_info_rejects_expired_records(self): ttl, b'\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==', ) - expired_record.created = 1000 - expired_record._expiration_time = 1000 + expired_record.set_created_ttl(1000, 1) info.update_record(zc, now, expired_record) assert info.properties[b"ci"] == b"2" zc.close() diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 7e06cff4..0f7a5e11 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -31,9 +31,6 @@ _CLASSES, _CLASS_MASK, _CLASS_UNIQUE, - _EXPIRE_FULL_TIME_PERCENT, - _EXPIRE_STALE_TIME_PERCENT, - _RECENT_TIME_PERCENT, _TYPES, _TYPE_ANY, ) @@ -45,6 +42,11 @@ _BASE_MAX_SIZE = _LEN_SHORT + _LEN_SHORT + _LEN_INT + _LEN_SHORT # type # class # ttl # length _NAME_COMPRESSION_MIN_SIZE = _LEN_BYTE * 2 +_EXPIRE_FULL_TIME_MS = 1000 +_EXPIRE_STALE_TIME_MS = 500 +_RECENT_TIME_MS = 250 + + if TYPE_CHECKING: # https://github.com/PyCQA/pylint/issues/3525 from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import @@ -154,7 +156,7 @@ class DNSRecord(DNSEntry): """A DNS record - like a DNS entry, but has a TTL""" - __slots__ = ('ttl', 'created', '_expiration_time', '_stale_time', '_recent_time') + __slots__ = ('ttl', 'created') # TODO: Switch to just int ttl def __init__( @@ -163,9 +165,6 @@ def __init__( super().__init__(name, type_, class_) self.ttl = ttl self.created = created or current_time_millis() - self._expiration_time: Optional[float] = None - self._stale_time: Optional[float] = None - self._recent_time: Optional[float] = None def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use """Abstract method""" @@ -189,27 +188,19 @@ def get_expiration_time(self, percent: int) -> float: # TODO: Switch to just int here def get_remaining_ttl(self, now: float) -> Union[int, float]: """Returns the remaining TTL in seconds.""" - if self._expiration_time is None: - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - return max(0, millis_to_seconds(self._expiration_time - now)) + return max(0, millis_to_seconds((self.created + (_EXPIRE_FULL_TIME_MS * self.ttl)) - now)) def is_expired(self, now: float) -> bool: """Returns true if this record has expired.""" - if self._expiration_time is None: - self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT) - return self._expiration_time <= now + return self.created + (_EXPIRE_FULL_TIME_MS * self.ttl) <= now def is_stale(self, now: float) -> bool: """Returns true if this record is at least half way expired.""" - if self._stale_time is None: - self._stale_time = self.get_expiration_time(_EXPIRE_STALE_TIME_PERCENT) - return self._stale_time <= now + return self.created + (_EXPIRE_STALE_TIME_MS * self.ttl) <= now def is_recent(self, now: float) -> bool: """Returns true if the record more than one quarter of its TTL remaining.""" - if self._recent_time is None: - self._recent_time = self.get_expiration_time(_RECENT_TIME_PERCENT) - return self._recent_time > now + return self.created + (_RECENT_TIME_MS * self.ttl) > now def reset_ttl(self, other: 'DNSRecord') -> None: """Sets this record's TTL and created time to that of @@ -220,9 +211,6 @@ def set_created_ttl(self, created: float, ttl: Union[float, int]) -> None: """Set the created and ttl of a record.""" self.created = created self.ttl = ttl - self._expiration_time = None - self._stale_time = None - self._recent_time = None def write(self, out: 'DNSOutgoing') -> None: # pylint: disable=no-self-use """Abstract method""" diff --git a/zeroconf/const.py b/zeroconf/const.py index 76a75dbd..27dc817f 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -145,10 +145,7 @@ _HAS_ONLY_A_TO_Z_NUM_HYPHEN_UNDERSCORE = re.compile(r'^[A-Za-z0-9\-\_]+$') _HAS_ASCII_CONTROL_CHARS = re.compile(r'[\x00-\x1f\x7f]') -_EXPIRE_FULL_TIME_PERCENT = 100 -_EXPIRE_STALE_TIME_PERCENT = 50 _EXPIRE_REFRESH_TIME_PERCENT = 75 -_RECENT_TIME_PERCENT = 25 _LOCAL_TRAILER = '.local.' _TCP_PROTOCOL_LOCAL_TRAILER = '._tcp.local.' From 919b096d6260a4f9f4306b9b4dddb5b026b49462 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 18:07:46 -1000 Subject: [PATCH 522/608] Let connection_lost close the underlying socket (#918) - The socket was closed during shutdown before asyncio's connection_lost handler had a chance to close it which resulted in a traceback on win32. - Fixes #917 --- zeroconf/_core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index f605ab13..37b72d59 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -166,8 +166,6 @@ def _async_shutdown(self) -> None: """Shutdown transports and sockets.""" for transport in itertools.chain(self.senders, self.readers): transport.close() - for s in self._respond_sockets: - s.close() def close(self) -> None: """Close from sync context.""" @@ -328,6 +326,9 @@ def error_received(self, exc: Exception) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle connection lost.""" + class Zeroconf(QuietLogger): From 96be9618ede3c941e23cb23398b9aed11bed1ffa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 18:11:24 -1000 Subject: [PATCH 523/608] Update changelog for 0.33.0 release (#919) --- README.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 0d726d76..649d3bd8 100644 --- a/README.rst +++ b/README.rst @@ -146,9 +146,17 @@ Changelog This release eliminates all threading locks as all non-threadsafe operations now happen in the event loop. +* Let connection_lost close the underlying socket (#918) @bdraco + + The socket was closed during shutdown before asyncio's connection_lost + handler had a chance to close it which resulted in a traceback on + windows. + + Fixed #917 + Technically backwards incompatible: -* Remove duplicate unregister_all_services code (#910) @bdraco +* Removed duplicate unregister_all_services code (#910) @bdraco Calling Zeroconf.close from same asyncio event loop zeroconf is running in will now skip unregister_all_services and log a warning as this a blocking From 2e0000252f0aecad8b62a649128326a6528b6824 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 18:29:18 -1000 Subject: [PATCH 524/608] Add support for bump2version (#920) --- requirements-dev.txt | 1 + setup.cfg | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index dc2f21de..3035d59d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ autopep8 black;implementation_name=="cpython" +bump2version coveralls coverage # Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 - is fixed diff --git a/setup.cfg b/setup.cfg index e208561b..cfcfb4b0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,13 @@ +[bumpversion] +current_version = 0.32.1 +commit = True +tag = True +tag_name = {new_version} + +[bumpversion:file:zeroconf/__init__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + [tool:pytest] testpaths = tests From b0b23f96d3b33a627a0d071557a36af97a65dae4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 18:50:47 -1000 Subject: [PATCH 525/608] Fix examples/async_registration.py attaching to the correct loop (#921) --- examples/async_registration.py | 49 ++++++++++++++++------------------ 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/examples/async_registration.py b/examples/async_registration.py index 53d14ce1..c3aab326 100644 --- a/examples/async_registration.py +++ b/examples/async_registration.py @@ -5,27 +5,32 @@ import asyncio import logging import socket -import time -from typing import List +from typing import List, Optional from zeroconf import IPVersion from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf -async def register_services(infos: List[AsyncServiceInfo]) -> None: - tasks = [aiozc.async_register_service(info) for info in infos] - background_tasks = await asyncio.gather(*tasks) - await asyncio.gather(*background_tasks) - - -async def unregister_services(infos: List[AsyncServiceInfo]) -> None: - tasks = [aiozc.async_unregister_service(info) for info in infos] - background_tasks = await asyncio.gather(*tasks) - await asyncio.gather(*background_tasks) +class AsyncRunner: + def __init__(self, ip_version: IPVersion) -> None: + self.ip_version = ip_version + self.aiozc: Optional[AsyncZeroconf] = None + async def register_services(self, infos: List[AsyncServiceInfo]) -> None: + self.aiozc = AsyncZeroconf(ip_version=self.ip_version) + tasks = [self.aiozc.async_register_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + print("Finished registration, press Ctrl-C to exit...") + while True: + await asyncio.sleep(1) -async def close_aiozc(aiozc: AsyncZeroconf) -> None: - await aiozc.async_close() + async def unregister_services(self, infos: List[AsyncServiceInfo]) -> None: + assert self.aiozc is not None + tasks = [self.aiozc.async_unregister_service(info) for info in infos] + background_tasks = await asyncio.gather(*tasks) + await asyncio.gather(*background_tasks) + await self.aiozc.async_close() if __name__ == '__main__': @@ -60,18 +65,10 @@ async def close_aiozc(aiozc: AsyncZeroconf) -> None: ) ) - print("Registration of 250 services, press Ctrl-C to exit...") - aiozc = AsyncZeroconf(ip_version=ip_version) + print("Registration of 250 services...") loop = asyncio.get_event_loop() - loop.run_until_complete(register_services(infos)) - print("Registration complete.") + runner = AsyncRunner(ip_version) try: - while True: - time.sleep(0.1) + loop.run_until_complete(runner.register_services(infos)) except KeyboardInterrupt: - pass - finally: - print("Unregistering...") - loop.run_until_complete(unregister_services(infos)) - print("Unregistration complete.") - loop.run_until_complete(close_aiozc(aiozc)) + loop.run_until_complete(runner.unregister_services(infos)) From e4a96550398c408c3e1e6944662cc3093db912a7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 19:04:32 -1000 Subject: [PATCH 526/608] Update changelog for 0.33.0 release (#922) --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 649d3bd8..b87cfb19 100644 --- a/README.rst +++ b/README.rst @@ -140,8 +140,8 @@ See examples directory for more. Changelog ========= -0.33.0 (Unreleased) -=================== +0.33.0 +====== This release eliminates all threading locks as all non-threadsafe operations now happen in the event loop. From cfb28aaf134e566d8a89b397967d1ad1ec66de35 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 19:06:30 -1000 Subject: [PATCH 527/608] =?UTF-8?q?Bump=20version:=200.32.1=20=E2=86=92=20?= =?UTF-8?q?0.33.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 9 ++++----- zeroconf/__init__.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/setup.cfg b/setup.cfg index cfcfb4b0..53810dcd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.32.1 +current_version = 0.33.0 commit = True tag = True tag_name = {new_version} @@ -13,9 +13,9 @@ testpaths = tests [flake8] show-source = 1 -application-import-names=zeroconf -max-line-length=110 -ignore=E203,W503,N818 +application-import-names = zeroconf +max-line-length = 110 +ignore = E203,W503,N818 [mypy] ignore_missing_imports = true @@ -28,7 +28,6 @@ warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true warn_return_any = true -# TODO: disallow untyped calls and defs once we have full type hint coverage disallow_untyped_calls = false disallow_untyped_defs = true diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 666914b2..a5ba5274 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.32.1' +__version__ = '0.33.0' __license__ = 'LGPL' From ed80333896c0710857cc46b5af4d7ba3a81e07c8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 22:47:25 -1000 Subject: [PATCH 528/608] Update changelog for 0.33.1 (#924) - Fixes overly restrictive directory permissions reported in #923 --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index b87cfb19..9c563dbb 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,13 @@ See examples directory for more. Changelog ========= +0.33.1 +====== + +* Version number change only with less restrictive directory permissions + + Fixed #923 + 0.33.0 ====== From 6774de3e7f8b461ccb83675bbb05d47949df487b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 18 Jul 2021 22:48:18 -1000 Subject: [PATCH 529/608] =?UTF-8?q?Bump=20version:=200.33.0=20=E2=86=92=20?= =?UTF-8?q?0.33.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 53810dcd..18312c8f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.33.0 +current_version = 0.33.1 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index a5ba5274..7b5d0259 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.33.0' +__version__ = '0.33.1' __license__ = 'LGPL' From 1247acd2e6f6154a4e5f2e27a820c55329391d8e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 21 Jul 2021 19:32:15 -1000 Subject: [PATCH 530/608] Remove some pylint workarounds (#925) --- zeroconf/_dns.py | 3 +-- zeroconf/_handlers.py | 3 +-- zeroconf/_logger.py | 2 +- zeroconf/_protocol.py | 3 +-- zeroconf/_services/__init__.py | 3 +-- zeroconf/_services/browser.py | 3 +-- zeroconf/_services/info.py | 3 +-- zeroconf/_updates.py | 3 +-- zeroconf/_utils/net.py | 2 +- 9 files changed, 9 insertions(+), 16 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 0f7a5e11..5b211060 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -48,8 +48,7 @@ if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from ._protocol import DNSIncoming, DNSOutgoing # pylint: disable=cyclic-import + from ._protocol import DNSIncoming, DNSOutgoing @enum.unique diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 07b9ac83..5e9e6e22 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -48,8 +48,7 @@ ) if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from ._core import Zeroconf # pylint: disable=cyclic-import + from ._core import Zeroconf _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py index 78c21148..e779a765 100644 --- a/zeroconf/_logger.py +++ b/zeroconf/_logger.py @@ -24,7 +24,7 @@ import sys from typing import Any, Dict, Union, cast -log = logging.getLogger(__name__.split('.')[0]) +log = logging.getLogger(__name__.split('.', maxsplit=1)[0]) log.addHandler(logging.NullHandler()) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index f987ce2e..7ca27799 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -51,8 +51,7 @@ if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from ._cache import DNSCache # pylint: disable=cyclic-import + from ._cache import DNSCache class DNSMessage: diff --git a/zeroconf/_services/__init__.py b/zeroconf/_services/__init__.py index 3759f1ec..5b9fbf01 100644 --- a/zeroconf/_services/__init__.py +++ b/zeroconf/_services/__init__.py @@ -25,8 +25,7 @@ if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from .._core import Zeroconf # pylint: disable=cyclic-import + from .._core import Zeroconf @enum.unique diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index fecf35c9..51f2c8d5 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -65,8 +65,7 @@ } if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from .._core import Zeroconf # pylint: disable=cyclic-import + from .._core import Zeroconf _QuestionWithKnownAnswers = Dict[DNSQuestion, Set[DNSPointer]] diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index d1bf17e9..cede3877 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -53,8 +53,7 @@ if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from .._core import Zeroconf # pylint: disable=cyclic-import + from .._core import Zeroconf def instance_name_from_service_info(info: "ServiceInfo") -> str: diff --git a/zeroconf/_updates.py b/zeroconf/_updates.py index d7ad56c1..bc7dcab5 100644 --- a/zeroconf/_updates.py +++ b/zeroconf/_updates.py @@ -27,8 +27,7 @@ if TYPE_CHECKING: - # https://github.com/PyCQA/pylint/issues/3525 - from ._core import Zeroconf # pylint: disable=cyclic-import + from ._core import Zeroconf class RecordUpdate(NamedTuple): diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 937dc116..d7f127ec 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -211,7 +211,7 @@ def set_mdns_port_socket_options_for_ip_version( s.setsockopt(_IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True) -def new_socket( # pylint: disable=too-many-branches +def new_socket( bind_addr: Union[Tuple[str], Tuple[str, int, int]], port: int = _MDNS_PORT, ip_version: IPVersion = IPVersion.V4Only, From 73e3d1865f4167e7c9f7c23ec4cc7ebfac40f512 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Jul 2021 09:56:13 -0500 Subject: [PATCH 531/608] Skip ipv6 interfaces that return ENODEV (#930) --- tests/utils/test_net.py | 8 ++++++++ zeroconf/_utils/net.py | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 399bd6ac..238e709c 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -182,6 +182,14 @@ def test_add_multicast_member(): with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENOPROTOOPT, None)): assert netutils.add_multicast_member(sock, interface) is False + # ENODEV should raise for ipv4 + with pytest.raises(OSError), patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)): + netutils.add_multicast_member(sock, interface) is False + + # ENODEV should return False for ipv6 + with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)): + assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False + # No error should return True with patch("socket.socket.setsockopt"): assert netutils.add_multicast_member(sock, interface) is True diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index d7f127ec..3aafe768 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -291,6 +291,13 @@ def add_multicast_member( interface, ) return False + if is_v6 and _errno == errno.ENODEV: + log.info( + 'Address in use when adding %s to multicast group, ' + 'it is expected to happen when the device does not have ipv6', + interface, + ) + return False raise return True From 97e0b669be60f716e45e963f1bcfcd35b7213626 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Jul 2021 10:07:59 -0500 Subject: [PATCH 532/608] Handle duplicate goodbye answers in the same packet (#928) - Solves an exception being thrown when we tried to remove the known answer from the cache when the second goodbye answer in the same packet was processed - We previously swallowed all exceptions on cache removal so this was not visible until 0.32.x which removed the broad exception catch Fixes #926 --- tests/test_handlers.py | 32 ++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 4 ++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index d4144721..ebe19f41 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1122,3 +1122,35 @@ async def test_guard_against_low_ptr_ttl(): assert incoming_answer_normal.ttl == const._DNS_OTHER_TTL assert zc.cache.async_get_unique(good_bye_answer) is None await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_duplicate_goodbye_answers_in_packet(): + """Ensure we do not throw an exception when there are duplicate goodbye records in a packet.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + zc = aiozc.zeroconf + answer_with_normal_ttl = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'host.local.', + ) + good_bye_answer = r.DNSPointer( + "myservicelow_tcp._tcp.local.", + const._TYPE_PTR, + const._CLASS_IN | const._CLASS_UNIQUE, + 0, + 'host.local.', + ) + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(answer_with_normal_ttl, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + + response = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + response.add_answer_at_time(good_bye_answer, 0) + response.add_answer_at_time(good_bye_answer, 0) + incoming = r.DNSIncoming(response.packets()[0]) + zc.record_manager.async_updates_from_response(incoming) + await aiozc.async_close() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 5e9e6e22..29ea0b6b 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -331,7 +331,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: updates: List[RecordUpdate] = [] address_adds: List[DNSAddress] = [] other_adds: List[DNSRecord] = [] - removes: List[DNSRecord] = [] + removes: Set[DNSRecord] = set() now = msg.now unique_types: Set[Tuple[str, int, int]] = set() @@ -355,7 +355,7 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None: # expired and exists in the cache elif maybe_entry is not None: updates.append(RecordUpdate(record, maybe_entry)) - removes.append(record) + removes.add(record) if unique_types: self._async_mark_unique_cached_records_older_than_1s_to_expire(unique_types, msg.answers, now) From c80b5f7253e521928d6f7e54681675be59371c6c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Jul 2021 10:11:24 -0500 Subject: [PATCH 533/608] Update changelog for 0.33.2 (#931) --- README.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.rst b/README.rst index 9c563dbb..b53656be 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,18 @@ See examples directory for more. Changelog ========= +0.33.2 +====== + +* Handle duplicate goodbye answers in the same packet (#928) @bdraco + + Solves an exception being thrown when we tried to remove the known answer + from the cache when the second goodbye answer in the same packet was processed + + Fixed #926 + +* Skip ipv6 interfaces that return ENODEV (#930) @bdraco + 0.33.1 ====== From 4d30c25fe57425bcae36a539006e44941ef46e2c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Jul 2021 05:22:17 -1000 Subject: [PATCH 534/608] =?UTF-8?q?Bump=20version:=200.33.1=20=E2=86=92=20?= =?UTF-8?q?0.33.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 18312c8f..65b769fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.33.1 +current_version = 0.33.2 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 7b5d0259..69166393 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.33.1' +__version__ = '0.33.2' __license__ = 'LGPL' From 319992bb093d9b965976bad724512d9bcd05aca7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 5 Aug 2021 16:17:56 -0500 Subject: [PATCH 535/608] Provide sockname when logging a protocol error (#935) --- tests/test_logger.py | 21 +++++++++++++++++++++ zeroconf/_core.py | 30 ++++++++++++++++++++---------- zeroconf/_logger.py | 11 +++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/tests/test_logger.py b/tests/test_logger.py index cedda7e9..2d8bbb08 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -61,3 +61,24 @@ def test_log_exception_warning(): assert not mock_log_warning.mock_calls assert mock_log_debug.mock_calls + + +def test_log_exception_once(): + """Test we only log with warning level once.""" + quiet_logger = QuietLogger() + exc = Exception() + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_once(exc, "the exceptional exception warning") + + assert mock_log_warning.mock_calls + assert not mock_log_debug.mock_calls + + with patch("zeroconf._logger.log.warning") as mock_log_warning, patch( + "zeroconf._logger.log.debug" + ) as mock_log_debug: + quiet_logger.log_exception_once(exc, "the exceptional exception warning") + + assert not mock_log_warning.mock_calls + assert mock_log_debug.mock_calls diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 37b72d59..b2320601 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -226,10 +226,10 @@ def datagram_received( if self.suppress_duplicate_packet(data, now): # Guard against duplicate packets log.debug( - 'Ignoring duplicate message received from %r:%r (socket %d) (%d bytes) as [%r]', + 'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]', addr, port, - self.transport.get_extra_info('socket').fileno(), + self._socket_description, len(data), data, ) @@ -249,20 +249,20 @@ def datagram_received( msg = DNSIncoming(data, scope, now) if msg.valid: log.debug( - 'Received from %r:%r (socket %d): %r (%d bytes) as [%r]', + 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', addr, port, - self.transport.get_extra_info('socket').fileno(), + self._socket_description, msg, len(data), data, ) else: log.debug( - 'Received from %r:%r (socket %d): (%d bytes) [%r]', + 'Received from %r:%r [socket %s]: (%d bytes) [%r]', addr, port, - self.transport.get_extra_info('socket').fileno(), + self._socket_description, len(data), data, ) @@ -316,12 +316,22 @@ def _respond_query( self.zc.handle_assembled_query(packets, addr, port, v6_flow_scope) + @property + def _socket_description(self) -> str: + """A human readable description of the socket.""" + assert self.transport is not None + fileno = self.transport.get_extra_info('socket').fileno() + sockname = self.transport.get_extra_info('sockname') + return f"{fileno} ({sockname})" + def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" - assert self.transport is not None - self.log_warning_once( - 'Error with socket %d: %s', self.transport.get_extra_info('socket').fileno(), exc - ) + # We preformat the message string with the socket as we want + # log_exception_once to log a warrning message once PER EACH + # different socket in case there are problems with multiple + # sockets + msg_str = f"Error with socket {self._socket_description}): %s" + self.log_exception_once(exc, msg_str, exc) def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) diff --git a/zeroconf/_logger.py b/zeroconf/_logger.py index e779a765..932d1a2f 100644 --- a/zeroconf/_logger.py +++ b/zeroconf/_logger.py @@ -61,3 +61,14 @@ def log_warning_once(cls, *args: Any) -> None: logger = log.debug cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 logger(*args) + + @classmethod + def log_exception_once(cls, exc: Exception, *args: Any) -> None: + msg_str = args[0] + if msg_str not in cls._seen_logs: + cls._seen_logs[msg_str] = 0 + logger = log.warning + else: + logger = log.debug + cls._seen_logs[msg_str] = cast(int, cls._seen_logs[msg_str]) + 1 + logger(*args, exc_info=exc) From 5682a4c3c89043bf8a10e79232933ada5ab71972 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 5 Aug 2021 16:18:14 -0500 Subject: [PATCH 536/608] Add support for forward dns compression pointers (#934) - nslookup supports these and some implementations (likely avahi) will generate them - Careful attention was given to make sure we detect loops and do not create anti-patterns described in https://github.com/Forescout/namewreck/blob/main/rfc/draft-dashevskyi-dnsrr-antipatterns-00.txt Fixes https://github.com/home-assistant/core/issues/53937 Fixes https://github.com/home-assistant/core/issues/46985 Fixes https://github.com/home-assistant/core/issues/53668 Fixes #308 --- tests/test_protocol.py | 226 +++++++++++++++++++++++++++++++++++++++++ zeroconf/_protocol.py | 202 +++++++++++++++++++++--------------- 2 files changed, 346 insertions(+), 82 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 75a69d5e..706afdd3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -632,6 +632,10 @@ def test_dns_compression_rollback_for_corruption(): # ensure there is no corruption with the dns compression incoming = r.DNSIncoming(packet) assert incoming.valid is True + assert ( + len(incoming.answers) + == incoming.num_answers + incoming.num_authorities + incoming.num_additionals + ) def test_tc_bit_in_query_packet(): @@ -761,3 +765,225 @@ def test_records_same_packet_share_fate(): first_time = dnsin.answers[0].created for answer in dnsin.answers: assert answer.created == first_time + + +def test_dns_compression_invalid_skips_bad_name_compress_in_question(): + """Test our wire parser can skip bad compression in questions.""" + packet = ( + b'\x00\x00\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x11homeassistant1128\x05l' + b'ocal\x00\x00\xff\x00\x014homeassistant1128 [534a4794e5ed41879ecf012252d3e02' + b'a]\x0c_workstation\x04_tcp\xc0\x1e\x00\xff\x00\x014homeassistant1127 [534a47' + b'94e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1123 [534a479' + b'4e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x014homeassistant1118 [534a4794' + b'e5ed41879ecf012252d3e02a]\xc0^\x00\xff\x00\x01\xc0\x0c\x00\x01\x80' + b'\x01\x00\x00\x00x\x00\x04\xc0\xa8<\xc3\xc0v\x00\x10\x80\x01\x00\x00\x00' + b'x\x00\x01\x00\xc0v\x00!\x80\x01\x00\x00\x00x\x00\x1f\x00\x00\x00\x00' + b'\x00\x00\x11homeassistant1127\x05local\x00\xc0\xb1\x00\x10\x80' + b'\x01\x00\x00\x00x\x00\x01\x00\xc0\xb1\x00!\x80\x01\x00\x00\x00x\x00\x1f' + b'\x00\x00\x00\x00\x00\x00\x11homeassistant1123\x05local\x00\xc0)\x00\x10\x80' + b'\x01\x00\x00\x00x\x00\x01\x00\xc0)\x00!\x80\x01\x00\x00\x00x\x00\x1f' + b'\x00\x00\x00\x00\x00\x00\x11homeassistant1128\x05local\x00' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.questions) == 4 + + +def test_dns_compression_all_invalid(): + """Test our wire parser can skip all invalid data.""" + packet = ( + b'\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00!roborock-vacuum-s5e_miio416' + b'112328\x00\x00/\x80\x01\x00\x00\x00x\x00\t\xc0P\x00\x05@\x00\x00\x00\x00' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.questions) == 0 + assert len(parsed.answers) == 0 + + +def test_invalid_next_name_ignored(): + """Test our wire parser does not throw an an invalid next name. + + The RFC states it should be ignored when used with mDNS. + """ + packet = ( + b'\x00\x00\x00\x00\x00\x01\x00\x02\x00\x00\x00\x00\x07Android\x05local\x00\x00' + b'\xff\x00\x01\xc0\x0c\x00/\x00\x01\x00\x00\x00x\x00\x08\xc02\x00\x04@' + b'\x00\x00\x08\xc0\x0c\x00\x01\x00\x01\x00\x00\x00x\x00\x04\xc0\xa8X<' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.questions) == 1 + assert len(parsed.answers) == 2 + + +def test_dns_compression_invalid_skips_record(): + """Test our wire parser can skip records we do not know how to parse.""" + packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x04_hap\x04_tcp\x05local\x00\x00\x0c" + b"\x00\x01\x00\x00\x11\x94\x00\x16\x13eufy HomeBase2-2464\xc0\x0c\x04Eufy\xc0\x16\x00/" + b"\x80\x01\x00\x00\x00x\x00\x08\xc0\xa6\x00\x04@\x00\x00\x08\xc0'\x00/\x80\x01\x00\x00" + b"\x11\x94\x00\t\xc0'\x00\x05\x00\x00\x80\x00@\xc0=\x00\x01\x80\x01\x00\x00\x00x\x00\x04" + b"\xc0\xa8Dp\xc0'\x00!\x80\x01\x00\x00\x00x\x00\x08\x00\x00\x00\x00\xd1_\xc0=\xc0'\x00" + b"\x10\x80\x01\x00\x00\x11\x94\x00K\x04c#=1\x04ff=2\x14id=38:71:4F:6B:76:00\x08md=T8010" + b"\x06pv=1.1\x05s#=75\x04sf=1\x04ci=2\x0bsh=xaQk4g==" + ) + parsed = r.DNSIncoming(packet) + answer = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + assert answer in parsed.answers + + +def test_dns_compression_points_forward(): + """Test our wire parser can unpack nsec records with compression.""" + packet = ( + b"\x00\x00\x84\x00\x00\x00\x00\x07\x00\x00\x00\x00\x0eTV Beneden (2)" + b"\x10_androidtvremote\x04_tcp\x05local\x00\x00\x10\x80\x01\x00\x00\x11" + b"\x94\x00\x15\x14bt=D8:13:99:AC:98:F1\xc0\x0c\x00/\x80\x01\x00\x00\x11" + b"\x94\x00\t\xc0\x0c\x00\x05\x00\x00\x80\x00@\tAndroid-3\xc01\x00/\x80" + b"\x01\x00\x00\x00x\x00\x08\xc0\x9c\x00\x04@\x00\x00\x08\xc0l\x00\x01\x80" + b"\x01\x00\x00\x00x\x00\x04\xc0\xa8X\x0f\xc0\x0c\x00!\x80\x01\x00\x00\x00" + b"x\x00\x08\x00\x00\x00\x00\x19B\xc0l\xc0\x1b\x00\x0c\x00\x01\x00\x00\x11" + b"\x94\x00\x02\xc0\x0c\t_services\x07_dns-sd\x04_udp\xc01\x00\x0c\x00\x01" + b"\x00\x00\x11\x94\x00\x02\xc0\x1b" + ) + parsed = r.DNSIncoming(packet) + answer = r.DNSNsec( + 'TV Beneden (2)._androidtvremote._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'TV Beneden (2)._androidtvremote._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + assert answer in parsed.answers + + +def test_dns_compression_points_to_itself(): + """Test our wire parser does not loop forever when a compression pointer points to itself.""" + packet = ( + b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01" + b"\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0(\x00\x01\x80\x01\x00\x00\x00" + b"\x01\x00\x04\xc0\xa8\xd0\x06" + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_points_beyond_packet(): + """Test our wire parser does not fail when the compression pointer points beyond the packet.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01' + b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xe7\x0f\x00\x01\x80\x01\x00\x00' + b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_generic_failure(): + """Test our wire parser does not loop forever when dns compression is corrupt.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01' + b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00' + b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_label_length_attack(): + """Test our wire parser does not loop forever when the name exceeds 253 chars.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d' + b'\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x01d\x00\x00\x01\x80' + b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\xc0\x0c\x00\x01\x80\x01\x00\x00\x00' + b'\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + + +def test_label_compression_attack(): + """Test our wire parser does not loop forever when exceeding the maximum number of labels.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03atk\x00\x00\x01\x80' + b'\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03' + b'atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\x03atk\xc0' + b'\x0c\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x06' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 1 + + +def test_dns_compression_loop_attack(): + """Test our wire parser does not loop forever when dns compression is in a loop.""" + packet = ( + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07\x03atk\x03dns\x05loc' + b'al\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05\x04a' + b'tk2\x04dns2\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05' + b'\x04atk3\xc0\x10\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0' + b'\x05\x04atk4\x04dns5\xc0\x14\x00\x01\x80\x01\x00\x00\x00\x01\x00\x04\xc0' + b'\xa8\xd0\x05\x04atk5\x04dns2\xc0^\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05\xc0s\x00\x01\x80\x01\x00\x00\x00\x01\x00' + b'\x04\xc0\xa8\xd0\x05' + ) + parsed = r.DNSIncoming(packet) + assert len(parsed.answers) == 0 + + +def test_txt_after_invalid_nsec_name_still_usable(): + """Test that we can see the txt record after the invalid nsec record.""" + packet = ( + b'\x00\x00\x84\x00\x00\x00\x00\x06\x00\x00\x00\x00\x06_sonos\x04_tcp\x05loc' + b'al\x00\x00\x0c\x00\x01\x00\x00\x11\x94\x00\x15\x12Sonos-542A1BC9220E' + b'\xc0\x0c\x12Sonos-542A1BC9220E\xc0\x18\x00/\x80\x01\x00\x00\x00x\x00' + b'\x08\xc1t\x00\x04@\x00\x00\x08\xc0)\x00/\x80\x01\x00\x00\x11\x94\x00' + b'\t\xc0)\x00\x05\x00\x00\x80\x00@\xc0)\x00!\x80\x01\x00\x00\x00x' + b'\x00\x08\x00\x00\x00\x00\x05\xa3\xc0>\xc0>\x00\x01\x80\x01\x00\x00\x00x' + b'\x00\x04\xc0\xa8\x02:\xc0)\x00\x10\x80\x01\x00\x00\x11\x94\x01*2info=/api' + b'/v1/players/RINCON_542A1BC9220E01400/info\x06vers=3\x10protovers=1.24.1\nbo' + b'otseq=11%hhid=Sonos_rYn9K9DLXJe0f3LP9747lbvFvh;mhhid=Sonos_rYn9K9DLXJe0f3LP9' + b'747lbvFvh.Q45RuMaeC07rfXh7OJGm str: @@ -168,60 +177,77 @@ def read_others(self) -> None: for _ in range(n): domain = self.read_name() type_, class_, ttl, length = self.unpack(b'!HHiH') - rec: Optional[DNSRecord] = None - if type_ == _TYPE_A: - rec = DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) - elif type_ in (_TYPE_CNAME, _TYPE_PTR): - rec = DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) - elif type_ == _TYPE_TXT: - rec = DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) - elif type_ == _TYPE_SRV: - rec = DNSService( - domain, - type_, - class_, - ttl, - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_name(), - self.now, - ) - elif type_ == _TYPE_HINFO: - rec = DNSHinfo( + end = self.offset + length + rec = None + try: + rec = self.read_record(domain, type_, class_, ttl, length) + except DECODE_EXCEPTIONS: + # Skip records that fail to decode if we know the length + # If the packet is really corrupt read_name and the unpack + # above would fail and hit the exception catch in read_others + self.offset = end + log.debug( + 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r', domain, - type_, - class_, - ttl, - self.read_character_string().decode('utf-8'), - self.read_character_string().decode('utf-8'), - self.now, - ) - elif type_ == _TYPE_AAAA: - rec = DNSAddress( - domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id + _TYPES.get(type_, type_), + self.offset, + self.data, + exc_info=True, ) - elif type_ == _TYPE_NSEC: - name_start = self.offset - name = self.read_name() - rec = DNSNsec( - domain, - type_, - class_, - ttl, - name, - self.read_bitmap(name_start + length), - self.now, - ) - else: - # Try to ignore types we don't know about - # Skip the payload for the resource record so the next - # records can be parsed correctly - self.offset += length - if rec is not None: self.answers.append(rec) + def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: + """Read known records types and skip unknown ones.""" + if type_ == _TYPE_A: + return DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) + if type_ in (_TYPE_CNAME, _TYPE_PTR): + return DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) + if type_ == _TYPE_TXT: + return DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) + if type_ == _TYPE_SRV: + return DNSService( + domain, + type_, + class_, + ttl, + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_name(), + self.now, + ) + if type_ == _TYPE_HINFO: + return DNSHinfo( + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), + self.now, + ) + if type_ == _TYPE_AAAA: + return DNSAddress( + domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id + ) + if type_ == _TYPE_NSEC: + name_start = self.offset + return DNSNsec( + domain, + type_, + class_, + ttl, + self.read_name(), + self.read_bitmap(name_start + length), + self.now, + ) + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + return None + def read_bitmap(self, end: int) -> List[int]: """Reads an NSEC bitmap from the packet.""" rdtypes = [] @@ -236,39 +262,51 @@ def read_bitmap(self, end: int) -> List[int]: return rdtypes def read_name(self) -> str: - """Reads a domain name from the packet""" - result = '' - off = self.offset - next_ = -1 - first = off - + """Reads a domain name from the packet.""" + labels: List[str] = [] + self.seen_pointers.clear() + self.offset = self._decode_labels_at_offset(self.offset, labels) + labels.append("") + name = ".".join(labels) + if len(name) > MAX_NAME_LENGTH: + raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") + return name + + def _decode_labels_at_offset(self, off: int, labels: List[str]) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. - while True: + while off < self.data_len: length = self.data[off] - off += 1 if length == 0: - break - t = length & 0xC0 - if t == 0x00: - # Convert to utf-8 - result += str(self.data[off : off + length], 'utf-8', 'replace') + '.' - off += length - elif t == 0xC0: - if next_ < 0: - next_ = off + 1 - off = ((length & 0x3F) << 8) | self.data[off] - if off >= first: - raise IncomingDecodeError(f"Bad domain name (circular) at {off}") - first = off - else: - raise IncomingDecodeError(f"Bad domain name at {off}") - - if next_ >= 0: - self.offset = next_ - else: - self.offset = off - - return result + return off + DNS_COMPRESSION_HEADER_LEN + + if length < 0x40: + label_idx = off + DNS_COMPRESSION_HEADER_LEN + labels.append(str(self.data[label_idx : label_idx + length], 'utf-8', 'replace')) + off += DNS_COMPRESSION_HEADER_LEN + length + continue + + if length < 0xC0: + raise IncomingDecodeError(f"DNS compression type {length} is unknown at {off}") + + # We have a DNS compression pointer + link = (length & 0x3F) * 256 + self.data[off + 1] + if link > self.data_len: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet") + if link == off: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself") + if link in self.seen_pointers: + raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again") + self.seen_pointers.add(link) + linked_labels = self.name_cache.get(link, []) + if not linked_labels: + self._decode_labels_at_offset(link, linked_labels) + self.name_cache[link] = linked_labels + labels.extend(linked_labels) + if len(labels) > MAX_DNS_LABELS: + raise IncomingDecodeError(f"Maximum dns labels reached while processing pointer at {off}") + return off + DNS_COMPRESSION_POINTER_LEN + + raise IncomingDecodeError("Corrupt packet received while decoding name") class DNSOutgoing(DNSMessage): From 6a140cc6b9c7e50e572456662d2f76f6fbc2ed25 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 5 Aug 2021 16:22:41 -0500 Subject: [PATCH 537/608] Update changelog for 0.33.3 (#936) --- README.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.rst b/README.rst index b53656be..6a1fd287 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,13 @@ See examples directory for more. Changelog ========= +0.33.3 +====== + +* Added support for forward dns compression pointers (#934) @bdraco + +* Provide sockname when logging a protocol error (#935) @bdraco + 0.33.2 ====== From 206671a1237ee8237d302b04c5a84158fed1d50b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 5 Aug 2021 16:54:55 -0500 Subject: [PATCH 538/608] =?UTF-8?q?Bump=20version:=200.33.2=20=E2=86=92=20?= =?UTF-8?q?0.33.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 65b769fd..81cf3077 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.33.2 +current_version = 0.33.3 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 69166393..4efdd056 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.33.2' +__version__ = '0.33.3' __license__ = 'LGPL' From 496ac44e99b56485cc9197490e71bb2dd7bec6f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=20Nov=C3=A1k?= Date: Fri, 6 Aug 2021 18:33:00 +0200 Subject: [PATCH 539/608] Ensure zeroconf can be loaded when the system disables IPv6 (#933) Co-authored-by: J. Nick Koston --- tests/services/test_info.py | 4 ++++ tests/test_asyncio.py | 6 +++++- tests/test_dns.py | 5 +++++ tests/test_handlers.py | 9 ++++++++- tests/test_protocol.py | 5 +++++ tests/utils/test_net.py | 4 ++++ zeroconf/_utils/net.py | 15 ++++++++++++--- zeroconf/const.py | 4 ---- 8 files changed, 43 insertions(+), 9 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 2060767f..0464ae7b 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -192,6 +192,8 @@ def test_service_info_rejects_expired_records(self): assert info.properties[b"ci"] == b"2" zc.close() + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_get_info_partial(self): zc = r.Zeroconf(interfaces=['127.0.0.1']) @@ -576,6 +578,8 @@ async def test_multiple_a_addresses(): await aiozc.async_close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_filter_address_by_type_from_service_info(): """Verify dns_addresses can filter by ipversion.""" desc = {'path': '/~paulsm/'} diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 9ec5e496..355b1b14 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -5,6 +5,7 @@ import asyncio import logging +import os import socket import time import threading @@ -32,7 +33,7 @@ from zeroconf._services.info import ServiceInfo from zeroconf._utils.time import current_time_millis -from . import _clear_cache +from . import _clear_cache, has_working_ipv6 log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -349,6 +350,9 @@ async def test_async_wait_unblocks_on_update() -> None: @pytest.mark.asyncio async def test_service_info_async_request() -> None: """Test registering services broadcasts and query with AsyncServceInfo.async_request.""" + if not has_working_ipv6() or os.environ.get('SKIP_IPV6'): + pytest.skip('Requires IPv6') + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) type_ = "_test1-srvc-type._tcp.local." name = "xxxyyy" diff --git a/tests/test_dns.py b/tests/test_dns.py index 071e1f65..a952b81e 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -4,6 +4,7 @@ """ Unit tests for zeroconf._dns. """ import logging +import os import socket import time import unittest @@ -18,6 +19,8 @@ ServiceInfo, ) +from . import has_working_ipv6 + log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -52,6 +55,8 @@ def test_dns_pointer_repr(self): pointer = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123') repr(pointer) + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_dns_address_repr(self): address = r.DNSAddress('irrelevant', const._TYPE_SOA, const._CLASS_IN, 1, b'a') assert repr(address).endswith("b'a'") diff --git a/tests/test_handlers.py b/tests/test_handlers.py index ebe19f41..3d05032b 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -5,6 +5,7 @@ import asyncio import logging +import os import pytest import socket import time @@ -19,7 +20,7 @@ from zeroconf.asyncio import AsyncZeroconf -from . import _clear_cache, _inject_response +from . import _clear_cache, _inject_response, has_working_ipv6 log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -274,6 +275,8 @@ def test_ptr_optimization(): zc.close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_any_query_for_ptr(): """Test that queries for ANY will return PTR records.""" zc = Zeroconf(interfaces=['127.0.0.1']) @@ -301,6 +304,8 @@ def test_any_query_for_ptr(): zc.close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_aaaa_query(): """Test that queries for AAAA records work.""" zc = Zeroconf(interfaces=['127.0.0.1']) @@ -326,6 +331,8 @@ def test_aaaa_query(): zc.close() +@unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') +@unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_a_and_aaaa_record_fate_sharing(): """Test that queries for AAAA always return A records in the additionals.""" zc = Zeroconf(interfaces=['127.0.0.1']) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 706afdd3..8c2f92c4 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -5,6 +5,7 @@ import copy import logging +import os import socket import struct import unittest @@ -18,6 +19,8 @@ DNSText, ) +from . import has_working_ipv6 + log = logging.getLogger('zeroconf') original_logging_level = logging.NOTSET @@ -468,6 +471,8 @@ def test_incoming_circular_reference(self): ) ).valid + @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') + @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_incoming_ipv6(self): addr = "2606:2800:220:1:248:1893:25c8:1946" # example.com packed = socket.inet_pton(socket.AF_INET6, addr) diff --git a/tests/utils/test_net.py b/tests/utils/test_net.py index 238e709c..41fdb7aa 100644 --- a/tests/utils/test_net.py +++ b/tests/utils/test_net.py @@ -190,6 +190,10 @@ def test_add_multicast_member(): with patch("socket.socket.setsockopt", side_effect=OSError(errno.ENODEV, None)): assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False + # No IPv6 support should return False for IPv6 + with patch("socket.inet_pton", side_effect=OSError()): + assert netutils.add_multicast_member(sock, ('2001:db8::', 1, 1)) is False + # No error should return True with patch("socket.socket.setsockopt"): assert netutils.add_multicast_member(sock, interface) is True diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index 3aafe768..bfae9db4 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -31,7 +31,7 @@ import ifaddr from .._logger import log -from ..const import _IPPROTO_IPV6, _MDNS_ADDR6_BYTES, _MDNS_ADDR_BYTES, _MDNS_PORT +from ..const import _IPPROTO_IPV6, _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT @enum.unique @@ -259,11 +259,20 @@ def add_multicast_member( log.debug('Adding %r (socket %d) to multicast group', interface, listen_socket.fileno()) try: if is_v6: + try: + mdns_addr6_bytes = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) + except OSError: + log.info( + 'Unable to translate IPv6 address when adding %s to multicast group, ' + 'this can happen if IPv6 is disabled on the system', + interface, + ) + return False iface_bin = struct.pack('@I', cast(int, interface[1])) - _value = _MDNS_ADDR6_BYTES + iface_bin + _value = mdns_addr6_bytes + iface_bin listen_socket.setsockopt(_IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, _value) else: - _value = _MDNS_ADDR_BYTES + socket.inet_aton(cast(str, interface)) + _value = socket.inet_aton(_MDNS_ADDR) + socket.inet_aton(cast(str, interface)) listen_socket.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, _value) except socket.error as e: _errno = get_errno(e) diff --git a/zeroconf/const.py b/zeroconf/const.py index 27dc817f..4c23310c 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -20,7 +20,6 @@ USA """ -import contextlib import re import socket @@ -44,10 +43,7 @@ # Some DNS constants _MDNS_ADDR = '224.0.0.251' -_MDNS_ADDR_BYTES = socket.inet_aton(_MDNS_ADDR) _MDNS_ADDR6 = 'ff02::fb' -with contextlib.suppress(OSError): # can't use AF_INET6, IPv6 is disabled - _MDNS_ADDR6_BYTES = socket.inet_pton(socket.AF_INET6, _MDNS_ADDR6) _MDNS_PORT = 5353 _DNS_PORT = 53 _DNS_HOST_TTL = 120 # two minute for host records (A, SRV etc) as-per RFC6762 From 858605db52f909d41198df76130597ff93f64cdd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 6 Aug 2021 11:48:43 -0500 Subject: [PATCH 540/608] Update changelog for 0.33.4 (#937) --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index 6a1fd287..ade5b8ee 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,11 @@ See examples directory for more. Changelog ========= +0.33.4 +====== + +* Ensure zeroconf can be loaded when the system disables IPv6 (#933) @che0 + 0.33.3 ====== From 7bbacd57a134c12ee1fb61d8318b312dfdae18f8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 6 Aug 2021 11:49:11 -0500 Subject: [PATCH 541/608] =?UTF-8?q?Bump=20version:=200.33.3=20=E2=86=92=20?= =?UTF-8?q?0.33.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 81cf3077..f982b806 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.33.3 +current_version = 0.33.4 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 4efdd056..f106525b 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.33.3' +__version__ = '0.33.4' __license__ = 'LGPL' From 55efb4169b588cef093f3065f3a894878ae8bd95 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 16:00:56 -0500 Subject: [PATCH 542/608] Implement Multicast Response Aggregation (#940) - Responses are now aggregated when possible per rules in RFC6762 section 6.4 - Responses that trigger the protection against against excessive packet flooding due to software bugs or malicious attack described in RFC6762 section 6 are delayed instead of discarding as it was causing responders that implement Passive Observation Of Failures (POOF) to evict the records. - Probe responses are now always sent immediately as there were cases where they would fail to be answered in time to defend a name. closes #939 --- tests/conftest.py | 13 + tests/services/test_types.py | 16 +- tests/test_asyncio.py | 40 ++- tests/test_core.py | 4 +- tests/test_handlers.py | 516 +++++++++++++++++++++++------------ zeroconf/_core.py | 41 ++- zeroconf/_handlers.py | 220 ++++++++++----- zeroconf/const.py | 2 + 8 files changed, 587 insertions(+), 265 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index d4ea1632..f900e094 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,10 @@ import pytest +import unittest + +from zeroconf import _core, const + @pytest.fixture(autouse=True) def verify_threads_ended(): @@ -15,3 +19,12 @@ def verify_threads_ended(): yield threads = frozenset(threading.enumerate()) - threads_before assert not threads + + +@pytest.fixture +def run_isolated(): + """Change the mDNS port to run the test in isolation.""" + with unittest.mock.patch.object(_core, "_MDNS_PORT", 5454), unittest.mock.patch.object( + const, "_MDNS_PORT", 5454 + ): + yield diff --git a/tests/services/test_types.py b/tests/services/test_types.py index f4206cf4..b1c312db 100644 --- a/tests/services/test_types.py +++ b/tests/services/test_types.py @@ -57,10 +57,10 @@ def test_integration_with_listener(self): ), patch.object( zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False ): - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) assert type_ in service_types _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) assert type_ in service_types finally: @@ -94,10 +94,10 @@ def test_integration_with_listener_v6_records(self): ), patch.object( zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False ): - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) assert type_ in service_types _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) assert type_ in service_types finally: @@ -131,10 +131,10 @@ def test_integration_with_listener_ipv6(self): ), patch.object( zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False ): - service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=0.5) + service_types = ZeroconfServiceTypes.find(ip_version=r.IPVersion.V6Only, timeout=2) assert type_ in service_types _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) assert type_ in service_types finally: @@ -167,10 +167,10 @@ def test_integration_with_subtype_and_listener(self): ), patch.object( zeroconf_registrar.engine.protocols[1], "suppress_duplicate_packet", return_value=False ): - service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=0.5) + service_types = ZeroconfServiceTypes.find(interfaces=['127.0.0.1'], timeout=2) assert discovery_type in service_types _clear_cache(zeroconf_registrar) - service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=0.5) + service_types = ZeroconfServiceTypes.find(zc=zeroconf_registrar, timeout=2) assert discovery_type in service_types finally: diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 355b1b14..39cad5b9 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -9,7 +9,7 @@ import socket import time import threading -from unittest.mock import patch +from unittest.mock import ANY, call, patch, MagicMock import pytest @@ -18,6 +18,7 @@ from zeroconf import ( DNSIncoming, DNSOutgoing, + DNSQuestion, DNSPointer, DNSService, DNSAddress, @@ -27,6 +28,7 @@ const, ) from zeroconf.const import _LISTENER_TIME +from zeroconf._core import AsyncListener from zeroconf._exceptions import BadTypeInNameException, NonUniqueNameException, ServiceNameAlreadyRegistered from zeroconf._services import ServiceListener import zeroconf._services.browser as _services_browser @@ -615,10 +617,10 @@ async def test_async_zeroconf_service_types(): await asyncio.sleep(0.2) _clear_cache(zeroconf_registrar.zeroconf) try: - service_types = await AsyncZeroconfServiceTypes.async_find(interfaces=['127.0.0.1'], timeout=0.5) + service_types = await AsyncZeroconfServiceTypes.async_find(interfaces=['127.0.0.1'], timeout=2) assert type_ in service_types _clear_cache(zeroconf_registrar.zeroconf) - service_types = await AsyncZeroconfServiceTypes.async_find(aiozc=zeroconf_registrar, timeout=0.5) + service_types = await AsyncZeroconfServiceTypes.async_find(aiozc=zeroconf_registrar, timeout=2) assert type_ in service_types finally: @@ -951,3 +953,35 @@ async def test_async_request_timeout(): # 3000ms for the default timeout # 1000ms for loaded systems + schedule overhead assert (end_time - start_time) < 3000 + 1000 + + +@pytest.mark.asyncio +async def test_legacy_unicast_response(run_isolated): + """Verify legacy unicast responses include questions and correct id.""" + type_ = "_mservice._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + + aiozc.zeroconf.registry.async_add(info) + query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888) + question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + + with patch.object(aiozc.zeroconf, "async_send") as send_mock: + aiozc.zeroconf.engine.protocols[0].datagram_received(query.packets()[0], ('127.0.0.1', 6503)) + + calls = send_mock.mock_calls + assert calls == [call(ANY, '127.0.0.1', 6503, ())] + outgoing = send_mock.call_args[0][0] + assert isinstance(outgoing, DNSOutgoing) + assert outgoing.questions == [question] + assert outgoing.id == query.id + await aiozc.async_close() diff --git a/tests/test_core.py b/tests/test_core.py index fd45b1ee..e2420a78 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -696,10 +696,10 @@ def test_guard_against_oversized_packets(): listener = _core.AsyncListener(zc) listener.transport = unittest.mock.MagicMock() - listener.datagram_received(ok_packet, ('127.0.0.1', 5353)) + listener.datagram_received(ok_packet, ('127.0.0.1', const._MDNS_PORT)) assert zc.cache.async_get_unique(okpacket_record) is not None - listener.datagram_received(over_sized_packet, ('127.0.0.1', 5353)) + listener.datagram_received(over_sized_packet, ('127.0.0.1', const._MDNS_PORT)) assert ( zc.cache.async_get_unique( r.DNSText( diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 3d05032b..9049408d 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -16,7 +16,7 @@ import zeroconf as r from zeroconf import ServiceInfo, Zeroconf, current_time_millis from zeroconf import const -from zeroconf._dns import DNSRRSet +from zeroconf._handlers import construct_outgoing_multicast_answers from zeroconf.asyncio import AsyncZeroconf @@ -101,10 +101,10 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT - )[1] - _process_outgoing_packet(multicast_out) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) # The additonals should all be suppresed since they are all in the answers section # @@ -138,11 +138,10 @@ def _process_outgoing_packet(out): query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN)) query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN)) - _process_outgoing_packet( - zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT - )[1] + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) + _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 @@ -208,7 +207,7 @@ def test_register_and_lookup_type_by_uppercase_name(self): out = r.DNSOutgoing(const._FLAGS_QR_QUERY) out.add_question(r.DNSQuestion(type_.upper(), const._TYPE_PTR, const._CLASS_IN)) zc.send(out) - time.sleep(0.5) + time.sleep(1) info = ServiceInfo(type_, registration_name) info.load_from_cache(zc) assert info.addresses == [socket.inet_pton(socket.AF_INET, "1.2.3.4")] @@ -237,11 +236,15 @@ def test_ptr_optimization(): # Verify we won't respond for 1s with the same multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert unicast_out is None - assert multicast_out is None + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + # Since we sent the PTR in the last second, they + # should end up in the delayed at least one second bucket + assert question_answers.mcast_aggregate_last_second # Clear the cache to allow responding again _clear_cache(zc) @@ -249,17 +252,17 @@ def test_ptr_optimization(): # Verify we will now respond query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out.id == query.id - assert unicast_out is None - assert multicast_out is not None + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate_last_second has_srv = has_txt = has_a = False nbr_additionals = 0 - nbr_answers = len(multicast_out.answers) - nbr_authorities = len(multicast_out.authorities) - for answer in multicast_out.additionals: + nbr_answers = len(question_answers.mcast_aggregate) + additionals = set().union(*question_answers.mcast_aggregate.values()) + for answer in additionals: nbr_additionals += 1 if answer.type == const._TYPE_SRV: has_srv = True @@ -267,7 +270,7 @@ def test_ptr_optimization(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert nbr_answers == 1 and nbr_additionals == 3 assert has_srv and has_txt and has_a # unregister @@ -278,7 +281,7 @@ def test_ptr_optimization(): @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_any_query_for_ptr(): - """Test that queries for ANY will return PTR records.""" + """Test that queries for ANY will return PTR records and the response is aggregated.""" zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_anyptr._tcp.local." name = "knownname" @@ -294,11 +297,10 @@ def test_any_query_for_ptr(): question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert multicast_out.answers[0][0].name == type_ - assert multicast_out.answers[0][0].alias == registration_name + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_aggregate) + assert mcast_answers[0].name == type_ + assert mcast_answers[0].alias == registration_name # unregister zc.registry.async_remove(info) zc.close() @@ -307,7 +309,7 @@ def test_any_query_for_ptr(): @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_aaaa_query(): - """Test that queries for AAAA records work.""" + """Test that queries for AAAA records work and should respond right away.""" zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_knownaaaservice._tcp.local." name = "knownname" @@ -322,10 +324,9 @@ def test_aaaa_query(): question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert multicast_out.answers[0][0].address == ipv6_address + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + mcast_answers = list(question_answers.mcast_now) + assert mcast_answers[0].address == ipv6_address # unregister zc.registry.async_remove(info) zc.close() @@ -334,7 +335,7 @@ def test_aaaa_query(): @unittest.skipIf(not has_working_ipv6(), 'Requires IPv6') @unittest.skipIf(os.environ.get('SKIP_IPV6'), 'IPv6 tests disabled') def test_a_and_aaaa_record_fate_sharing(): - """Test that queries for AAAA always return A records in the additionals.""" + """Test that queries for AAAA always return A records in the additionals and should respond right away.""" zc = Zeroconf(interfaces=['127.0.0.1']) type_ = "_a-and-aaaa-service._tcp.local." name = "knownname" @@ -356,31 +357,25 @@ def test_a_and_aaaa_record_fate_sharing(): question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - answers = DNSRRSet([answer[0] for answer in multicast_out.answers]) - additionals = DNSRRSet(multicast_out.additionals) - assert aaaa_record in answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + additionals = set().union(*question_answers.mcast_now.values()) + assert aaaa_record in question_answers.mcast_now assert a_record in additionals - assert len(multicast_out.answers) == 1 - assert len(multicast_out.additionals) == 1 + assert len(question_answers.mcast_now) == 1 + assert len(additionals) == 1 # Test A query generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - _, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - answers = DNSRRSet([answer[0] for answer in multicast_out.answers]) - additionals = DNSRRSet(multicast_out.additionals) - - assert a_record in answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in question_answers.mcast_now assert aaaa_record in additionals - assert len(multicast_out.answers) == 1 - assert len(multicast_out.additionals) == 1 + assert len(question_answers.mcast_now) == 1 + assert len(additionals) == 1 + # unregister zc.registry.async_remove(info) zc.close() @@ -406,16 +401,15 @@ def test_unicast_response(): # query query = r.DNSOutgoing(const._FLAGS_QR_QUERY) query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN)) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", 1234 + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], True ) - for out in (unicast_out, multicast_out): - assert out.id == query.id + for answers in (question_answers.ucast, question_answers.mcast_aggregate): has_srv = has_txt = has_a = False nbr_additionals = 0 - nbr_answers = len(out.answers) - nbr_authorities = len(out.authorities) - for answer in out.additionals: + nbr_answers = len(answers) + additionals = set().union(*answers.values()) + for answer in additionals: nbr_additionals += 1 if answer.type == const._TYPE_SRV: has_srv = True @@ -423,7 +417,7 @@ def test_unicast_response(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert nbr_answers == 1 and nbr_additionals == 3 assert has_srv and has_txt and has_a # unregister @@ -431,6 +425,48 @@ def test_unicast_response(): zc.close() +@pytest.mark.asyncio +async def test_probe_answered_immediately(): + """Verify probes are responded to immediately.""" + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # service definition + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + zc.registry.async_add(info) + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + assert question_answers.mcast_now + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + question.unicast = True + query.add_question(question) + query.add_authorative_answer(info.dns_pointer()) + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False + ) + assert question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + zc.close() + + def test_qu_response(): """Handle multicast incoming with the QU bit set.""" # instantiate a zeroconf instance @@ -459,21 +495,20 @@ def test_qu_response(): # register zc.register_service(info) - def _validate_complete_response(query, out): - assert out.id == query.id + def _validate_complete_response(answers): has_srv = has_txt = has_a = False - nbr_additionals = 0 - nbr_answers = len(out.answers) - nbr_authorities = len(out.authorities) - for answer in out.additionals: - nbr_additionals += 1 + nbr_answers = len(answers.keys()) + additionals = set().union(*answers.values()) + nbr_additionals = len(additionals) + + for answer in additionals: if answer.type == const._TYPE_SRV: has_srv = True elif answer.type == const._TYPE_TXT: has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 and nbr_authorities == 0 + assert nbr_answers == 1 and nbr_additionals == 3 assert has_srv and has_txt and has_a # With QU should respond to only unicast when the answer has been recently multicast @@ -483,11 +518,13 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out is None - _validate_complete_response(query, unicast_out) + _validate_complete_response(question_answers.ucast) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second _clear_cache(zc) # With QU should respond to only multicast since the response hasn't been seen since 75% of the ttl @@ -496,11 +533,13 @@ def _validate_complete_response(query, out): question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert unicast_out is None - _validate_complete_response(query, multicast_out) + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate + _validate_complete_response(question_answers.mcast_now) # With QU set and an authorative answer (probe) should respond to both unitcast and multicast since the response hasn't been seen since 75% of the ttl query = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -509,24 +548,28 @@ def _validate_complete_response(query, out): assert question.unicast is True query.add_question(question) query.add_authorative_answer(info2.dns_pointer()) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - _validate_complete_response(query, unicast_out) - _validate_complete_response(query, multicast_out) + _validate_complete_response(question_answers.ucast) + _validate_complete_response(question_answers.mcast_now) - _inject_response(zc, r.DNSIncoming(multicast_out.packets()[0])) + _inject_response( + zc, r.DNSIncoming(construct_outgoing_multicast_answers(question_answers.mcast_now).packets()[0]) + ) # With the cache repopulated; should respond to only unicast when the answer has been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) question.unicast = True # Set the QU bit assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out is None - _validate_complete_response(query, unicast_out) + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + _validate_complete_response(question_answers.ucast) # unregister zc.unregister_service(info) zc.close() @@ -551,34 +594,33 @@ def test_known_answer_supression(): question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is not None and multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN) generated.add_question(question) generated.add_answer_at_time(info.dns_pointer(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - # If the answer is suppressed, the additional should be suppresed as well - assert not multicast_out or not multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # Test A supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is not None and multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN) @@ -586,56 +628,55 @@ def test_known_answer_supression(): for dns_address in info.dns_addresses(): generated.add_answer_at_time(dns_address, now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert not multicast_out or not multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # Test SRV supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is not None and multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) generated.add_question(question) generated.add_answer_at_time(info.dns_service(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - # If the answer is suppressed, the additional should be suppresed as well - assert not multicast_out or not multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # Test TXT supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is not None and multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN) generated.add_question(question) generated.add_answer_at_time(info.dns_text(), now) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert not multicast_out or not multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # unregister zc.registry.async_remove(info) @@ -684,11 +725,11 @@ def test_multi_packet_known_answer_supression(): generated.add_answer_at_time(info3.dns_pointer(), now) packets = generated.packets() assert len(packets) > 1 - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is None + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # unregister zc.registry.async_remove(info) zc.registry.async_remove(info2) @@ -725,11 +766,11 @@ def test_known_answer_supression_service_type_enumeration_query(): question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) generated.add_question(question) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is not None and multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN) @@ -755,11 +796,11 @@ def test_known_answer_supression_service_type_enumeration_query(): now, ) packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert not multicast_out or not multicast_out.answers + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second # unregister zc.registry.async_remove(info) @@ -815,12 +856,16 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out is None - assert a_record in unicast_out.additionals - assert unicast_out.answers[0][0] == ptr_record + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + additionals = set().union(*question_answers.ucast.values()) + assert a_record in additionals + assert ptr_record in question_answers.ucast # Remove the 50% A record and add a 100% A record zc.cache.async_remove_records([a_record]) @@ -835,12 +880,15 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out is None - assert a_record in unicast_out.additionals - assert unicast_out.answers[0][0] == ptr_record + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + additionals = set().union(*question_answers.ucast.values()) + assert a_record in additionals + assert ptr_record in question_answers.ucast # Remove the 100% PTR record and add a 50% PTR record zc.cache.async_remove_records([ptr_record]) @@ -855,15 +903,17 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): assert question.unicast is True query.add_question(question) - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out.answers[0][0] == ptr_record - assert a_record in multicast_out.additionals - assert info.dns_text() in multicast_out.additionals - assert info.dns_service() in multicast_out.additionals - - assert unicast_out is None + assert not question_answers.ucast + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in additionals + assert info.dns_text() in additionals + assert info.dns_service() in additionals + assert ptr_record in question_answers.mcast_now # Ask 2 QU questions, with info the PTR is at 50%, with info2 the PTR is at 100% # We should get back a unicast reply for info2, but info should be multicasted since its within 75% of its TTL @@ -881,18 +931,23 @@ async def test_qu_response_only_sends_additionals_if_sends_answer(): query.add_question(question) zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT + question_answers = zc.query_handler.async_response( + [r.DNSIncoming(packet) for packet in query.packets()], False ) - assert multicast_out.answers[0][0] == info.dns_pointer() - assert info.dns_addresses()[0] in multicast_out.additionals - assert info.dns_text() in multicast_out.additionals - assert info.dns_service() in multicast_out.additionals + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + + mcast_now_additionals = set().union(*question_answers.mcast_now.values()) + assert a_record in mcast_now_additionals + assert info.dns_text() in mcast_now_additionals + assert info.dns_addresses()[0] in mcast_now_additionals + assert info.dns_pointer() in question_answers.mcast_now - assert unicast_out.answers[0][0] == info2.dns_pointer() - assert info2.dns_addresses()[0] in unicast_out.additionals - assert info2.dns_text() in unicast_out.additionals - assert info2.dns_service() in unicast_out.additionals + ucast_additionals = set().union(*question_answers.ucast.values()) + assert info2.dns_pointer() in question_answers.ucast + assert info2.dns_text() in ucast_additionals + assert info2.dns_service() in ucast_additionals + assert info2.dns_addresses()[0] in ucast_additionals # unregister zc.registry.async_remove(info) @@ -1045,11 +1100,11 @@ async def test_questions_query_handler_populates_the_question_history_from_qm_qu generated.add_answer_at_time(known_answer, 0) now = r.current_time_millis() packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is None + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second assert zc.question_history.suppresses(question, now, {known_answer}) await aiozc.async_close() @@ -1072,11 +1127,11 @@ async def test_questions_query_handler_does_not_put_qu_questions_in_history(): generated.add_answer_at_time(known_answer, 0) now = r.current_time_millis() packets = generated.packets() - unicast_out, multicast_out = zc.query_handler.async_response( - [r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT - ) - assert unicast_out is None - assert multicast_out is None + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + assert not question_answers.mcast_now + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second assert not zc.question_history.suppresses(question, now, {known_answer}) await aiozc.async_close() @@ -1161,3 +1216,104 @@ async def test_duplicate_goodbye_answers_in_packet(): incoming = r.DNSIncoming(response.packets()[0]) zc.record_manager.async_updates_from_response(incoming) await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_response_aggregation_timings(run_isolated): + """Verify multicast respones are aggregated.""" + type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." + type_3 = "_mservice3._tcp.local." + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" + registration_name3 = f"{name}.{type_3}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + info3 = ServiceInfo( + type_3, registration_name3, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + aiozc.zeroconf.registry.async_add(info) + aiozc.zeroconf.registry.async_add(info2) + aiozc.zeroconf.registry.async_add(info3) + + query = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question = r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) + query.add_question(question) + + query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + query2.add_question(question2) + + query3 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question3 = r.DNSQuestion(info3.type, const._TYPE_PTR, const._CLASS_IN) + query3.add_question(question3) + + query4 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + query4.add_question(question) + query4.add_question(question2) + + zc = aiozc.zeroconf + protocol = zc.engine.protocols[0] + + with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock: + protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.7) + + # Should aggregate into a single answer with up to a 500ms + 120ms delay + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info.dns_pointer() in incoming.answers + assert info2.dns_pointer() in incoming.answers + send_mock.reset_mock() + + protocol.datagram_received(query3.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.3) + + # Should send within 120ms since there are no other + # answers to aggregate with + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info3.dns_pointer() in incoming.answers + send_mock.reset_mock() + + # Because the response was sent in the last second we need to make + # sure the next answer is delayed at least a second + aiozc.zeroconf.engine.protocols[0].datagram_received( + query4.packets()[0], ('127.0.0.1', const._MDNS_PORT) + ) + await asyncio.sleep(0.5) + + # After 0.5 seconds it should not have been sent + # Protect the network against excessive packet flooding + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + calls = send_mock.mock_calls + assert len(calls) == 0 + send_mock.reset_mock() + + await asyncio.sleep(1.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + assert info.dns_pointer() in incoming.answers + + await aiozc.async_close() diff --git a/zeroconf/_core.py b/zeroconf/_core.py index b2320601..20251f0b 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -32,7 +32,13 @@ from ._cache import DNSCache from ._dns import DNSQuestion, DNSQuestionType from ._exceptions import NonUniqueNameException -from ._handlers import QueryHandler, RecordManager +from ._handlers import ( + MulticastOutgoingQueue, + QueryHandler, + RecordManager, + construct_outgoing_multicast_answers, + construct_outgoing_unicast_answers, +) from ._history import QuestionHistory from ._logger import QuietLogger, log from ._protocol import DNSIncoming, DNSOutgoing @@ -70,12 +76,15 @@ _MDNS_ADDR, _MDNS_ADDR6, _MDNS_PORT, + _ONE_SECOND, _REGISTER_TIME, _TYPE_PTR, _UNREGISTER_TIME, ) _TC_DELAY_RANDOM_INTERVAL = (400, 500) + + _CLOSE_TIMEOUT = 3000 # ms _REGISTER_BROADCASTS = 3 @@ -394,6 +403,8 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None + self._out_queue = MulticastOutgoingQueue(self) + self.start() def start(self) -> None: @@ -717,11 +728,24 @@ def handle_assembled_query( or the timer expires. If the TC bit is not set, a single packet will be in packets. """ - unicast_out, multicast_out = self.query_handler.async_response(packets, addr, port) - if unicast_out: - self.async_send(unicast_out, addr, port, v6_flow_scope) - if multicast_out: - self.async_send(multicast_out, None, _MDNS_PORT) + now = packets[0].now + ucast_source = port != _MDNS_PORT + question_answers = self.query_handler.async_response(packets, ucast_source) + if question_answers.ucast: + questions = packets[0].questions + id_ = packets[0].id + out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) + self.async_send(out, addr, port, v6_flow_scope) + if question_answers.mcast_now: + out = construct_outgoing_multicast_answers(question_answers.mcast_aggregate) + self.async_send(out) + if question_answers.mcast_aggregate: + self._out_queue.async_add(now, question_answers.mcast_aggregate, 0) + if question_answers.mcast_aggregate_last_second: + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 + # If we broadcast it in the last second, we have to delay + # at least a second before we send it again + self._out_queue.async_add(now, question_answers.mcast_aggregate_last_second, _ONE_SECOND) def send( self, @@ -742,6 +766,9 @@ def async_send( v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Sends an outgoing packet.""" + if self._GLOBAL_DONE: + return + for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) @@ -756,8 +783,6 @@ def async_send( packet, ) for transport in self.engine.senders: - if self._GLOBAL_DONE: - return s = transport.get_extra_info('socket') if addr is None: real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 29ea0b6b..76d5efcd 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -21,7 +21,9 @@ """ import itertools -from typing import Dict, Iterable, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +import random +from collections import deque +from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from ._cache import DNSCache, _UniqueRecordsType from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord @@ -30,14 +32,14 @@ from ._protocol import DNSIncoming, DNSOutgoing from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener -from ._utils.time import current_time_millis +from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CLASS_IN, _DNS_OTHER_TTL, _DNS_PTR_MIN_TTL, _FLAGS_AA, _FLAGS_QR_RESPONSE, - _MDNS_PORT, + _ONE_SECOND, _SERVICE_TYPE_ENUMERATION_NAME, _TYPE_A, _TYPE_AAAA, @@ -53,6 +55,59 @@ _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] +_MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120) +_MAX_MULTICAST_DELAY = 500 # ms +_RESPOND_IMMEDIATE_TYPES = {_TYPE_SRV, _TYPE_A, _TYPE_AAAA} + + +class QuestionAnswers(NamedTuple): + ucast: _AnswerWithAdditionalsType + mcast_now: _AnswerWithAdditionalsType + mcast_aggregate: _AnswerWithAdditionalsType + mcast_aggregate_last_second: _AnswerWithAdditionalsType + + +class AnswerGroup(NamedTuple): + """A group of answers scheduled to be sent at the same time.""" + + send_after: float # Must be sent after this time + send_before: float # Must be sent before this time + answers: _AnswerWithAdditionalsType + + +def _message_is_probe(msg: DNSIncoming) -> bool: + return msg.num_authorities > 0 + + +def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing: + """Add answers and additionals to a DNSOutgoing.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True) + _add_answers_additionals(out, answers) + return out + + +def construct_outgoing_unicast_answers( + answers: _AnswerWithAdditionalsType, ucast_source: bool, questions: List[DNSQuestion], id_: int +) -> DNSOutgoing: + """Add answers and additionals to a DNSOutgoing.""" + out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=False, id_=id_) + # Adding the questions back when the source is legacy unicast behavior + if ucast_source: + for question in questions: + out.add_question(question) + _add_answers_additionals(out, answers) + return out + + +def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None: + # Find additionals and suppress any additionals that are already in answers + additionals: Set[DNSRecord] = set().union(*answers.values()) # type: ignore + additionals -= answers.keys() + for answer in answers: + out.add_answer_at_time(answer, 0) + for additional in additionals: + out.add_additional_answer(additional) + def sanitize_incoming_record(record: DNSRecord) -> None: """Protect zeroconf from records that can cause denial of service. @@ -74,16 +129,17 @@ def sanitize_incoming_record(record: DNSRecord) -> None: class _QueryResponse: """A pair for unicast and multicast DNSOutgoing responses.""" - def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> None: + def __init__(self, cache: DNSCache, msgs: List[DNSIncoming]) -> None: """Build a query response.""" - self._msg = msg - self._is_probe = msg.num_authorities > 0 - self._ucast_source = ucast_source - self._now = current_time_millis() + self._is_probe = any(_message_is_probe(msg) for msg in msgs) + self._msg = msgs[0] + self._now = self._msg.now self._cache = cache self._additionals: _AnswerWithAdditionalsType = {} self._ucast: Set[DNSRecord] = set() - self._mcast: Set[DNSRecord] = set() + self._mcast_now: Set[DNSRecord] = set() + self._mcast_aggregate: Set[DNSRecord] = set() + self._mcast_aggregate_last_second: Set[DNSRecord] = set() def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast QU query.""" @@ -92,7 +148,7 @@ def add_qu_question_response(self, answers: _AnswerWithAdditionalsType) -> None: if self._is_probe: self._ucast.add(record) if not self._has_mcast_within_one_quarter_ttl(record): - self._mcast.add(record) + self._mcast_now.add(record) elif not self._is_probe: self._ucast.add(record) @@ -104,46 +160,32 @@ def add_ucast_question_response(self, answers: _AnswerWithAdditionalsType) -> No def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> None: """Generate a response to a multicast query.""" self._additionals.update(answers) - self._mcast.update(answers.keys()) - - def outgoing_unicast(self) -> Optional[DNSOutgoing]: - """Build the outgoing unicast response.""" - ucastout = self._construct_outgoing_from_record_set(self._ucast, False) - # Adding the questions back when the source is legacy unicast behavior - if ucastout and self._ucast_source: - for question in self._msg.questions: - ucastout.add_question(question) - return ucastout - - def outgoing_multicast(self) -> Optional[DNSOutgoing]: - """Build the outgoing multicast response.""" - if not self._is_probe: - self._suppress_mcasts_from_last_second(self._mcast) - return self._construct_outgoing_from_record_set(self._mcast, True) - - def _construct_outgoing_from_record_set( - self, answers_rrset: Set[DNSRecord], multicast: bool - ) -> Optional[DNSOutgoing]: - """Add answers and additionals to a DNSOutgoing.""" - # Find additionals and suppress any additionals that are already in answers - additionals_rrset = self._additionals_from_answers_rrset(answers_rrset) - answers_rrset - if not answers_rrset: - return None - - out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=multicast, id_=self._msg.id) - for answer in answers_rrset: - out.add_answer_at_time(answer, 0) - for additional in additionals_rrset: - out.add_additional_answer(additional) - return out - - def _additionals_from_answers_rrset(self, rrset: Set[DNSRecord]) -> Set[DNSRecord]: - additionals: Set[DNSRecord] = set() - return additionals.union(*(self._additionals[record] for record in rrset)) - - def _suppress_mcasts_from_last_second(self, rrset: Set[DNSRecord]) -> None: - """Remove any records that were already sent in the last second.""" - rrset -= {record for record in rrset if self._has_mcast_record_in_last_second(record)} + for answer in answers: + if self._is_probe: + self._mcast_now.add(answer) + continue + + if self._has_mcast_record_in_last_second(answer): + self._mcast_aggregate_last_second.add(answer) + elif len(self._msg.questions) == 1 and self._msg.questions[0].type in _RESPOND_IMMEDIATE_TYPES: + self._mcast_now.add(answer) + else: + self._mcast_aggregate.add(answer) + + def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType: + """Create answers with additionals from an rrset.""" + return {record: self._additionals[record] for record in rrset} + + def answers( + self, + ) -> QuestionAnswers: + """Return answer sets that will be queued.""" + return QuestionAnswers( + self._generate_answers_with_additionals(self._ucast), + self._generate_answers_with_additionals(self._mcast_now), + self._generate_answers_with_additionals(self._mcast_aggregate), + self._generate_answers_with_additionals(self._mcast_aggregate_last_second), + ) def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: """Check to see if a record has been mcasted recently. @@ -160,12 +202,12 @@ def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool: return bool(maybe_entry and maybe_entry.is_recent(self._now)) def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: - """Remove answers that were just broadcast + """Check if an answer was seen in the last second. Protect the network against excessive packet flooding https://datatracker.ietf.org/doc/html/rfc6762#section-14 """ maybe_entry = self._cache.async_get_unique(cast(_UniqueRecordsType, record)) - return bool(maybe_entry and self._now - maybe_entry.created < 1000) + return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND) class QueryHandler: @@ -229,13 +271,14 @@ def _add_address_answers( def _answer_question( self, question: DNSQuestion, - answer_set: _AnswerWithAdditionalsType, known_answers: DNSRRSet, now: float, - ) -> None: + ) -> _AnswerWithAdditionalsType: + answer_set: _AnswerWithAdditionalsType = {} + if question.type == _TYPE_PTR and question.name.lower() == _SERVICE_TYPE_ENUMERATION_NAME: self._add_service_type_enumeration_query_answers(answer_set, known_answers, now) - return + return answer_set type_ = question.type @@ -259,24 +302,26 @@ def _answer_question( if not known_answers.suppresses(dns_text): answer_set[dns_text] = set() + return answer_set + def async_response( # pylint: disable=unused-argument - self, msgs: List[DNSIncoming], addr: Optional[str], port: int - ) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]: + self, msgs: List[DNSIncoming], ucast_source: bool + ) -> QuestionAnswers: """Deal with incoming query packets. Provides a response if possible. This function must be run in the event loop as it is not threadsafe. """ - ucast_source = port != _MDNS_PORT - known_answers = DNSRRSet(itertools.chain(*(msg.answers for msg in msgs))) - query_res = _QueryResponse(self.cache, msgs[0], ucast_source) + known_answers = DNSRRSet( + itertools.chain(*(msg.answers for msg in msgs if not _message_is_probe(msg))) + ) + query_res = _QueryResponse(self.cache, msgs) for msg in msgs: for question in msg.questions: if not question.unicast: self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup)) - answer_set: _AnswerWithAdditionalsType = {} - self._answer_question(question, answer_set, known_answers, msg.now) + answer_set = self._answer_question(question, known_answers, msg.now) if not ucast_source and question.unicast: query_res.add_qu_question_response(answer_set) continue @@ -286,7 +331,7 @@ def async_response( # pylint: disable=unused-argument # source as long as we haven't done it recently (75% of ttl) query_res.add_mcast_question_response(answer_set) - return query_res.outgoing_unicast(), query_res.outgoing_multicast() + return query_res.answers() class RecordManager: @@ -397,7 +442,7 @@ def _async_mark_unique_cached_records_older_than_1s_to_expire( answers_rrset = DNSRRSet(answers) for name, type_, class_ in unique_types: for entry in self.cache.async_all_by_details(name, type_, class_): - if (now - entry.created > 1000) and entry not in answers_rrset: + if (now - entry.created > _ONE_SECOND) and entry not in answers_rrset: # Expire in 1s entry.set_created_ttl(now, 1) @@ -449,3 +494,50 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: self.zc.async_notify_all() except ValueError as e: log.exception('Failed to remove listener: %r', e) + + +class MulticastOutgoingQueue: + """An outgoing queue used to aggregate multicast responses.""" + + def __init__(self, zeroconf: 'Zeroconf') -> None: + self.zc = zeroconf + self.queue: deque = deque() + + def async_add(self, now: float, answers: _AnswerWithAdditionalsType, additional_delay: int) -> None: + """Add a group of answers with additionals to the outgoing queue.""" + assert self.zc.loop is not None + random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + additional_delay + send_after = now + random_delay + send_before = now + _MAX_MULTICAST_DELAY + additional_delay + if not len(self.queue): + self.zc.loop.call_later(millis_to_seconds(random_delay), self._async_ready) + self.queue.append(AnswerGroup(send_after, send_before, answers)) + + def _async_ready(self) -> None: + """Process anything in the queue that is ready.""" + assert self.zc.loop is not None + now = current_time_millis() + + if len(self.queue) > 1 and self.queue[0].send_before > now: + # There is more than one answer in the queue, + # delay until we have to send it (first answer group reaches send_before) + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self._async_ready) + return + + answers: _AnswerWithAdditionalsType = {} + # Add all groups that can be sent now + while len(self.queue) and self.queue[0].send_after <= now: + answers.update(self.queue.popleft().answers) + + if len(self.queue): + # If there are still groups in the queue that are not ready to send + # be sure we schedule them to go out later + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self._async_ready) + + if answers: + # If we have the same answer scheduled to go out, remove it + for pending in self.queue: + for record in answers: + pending.answers.pop(record, None) + + self.zc.async_send(construct_outgoing_multicast_answers(answers)) diff --git a/zeroconf/const.py b/zeroconf/const.py index 4c23310c..ff5cc3a2 100644 --- a/zeroconf/const.py +++ b/zeroconf/const.py @@ -34,6 +34,8 @@ _BROWSER_BACKOFF_LIMIT = 3600 # s _CACHE_CLEANUP_INTERVAL = 10000 # ms _LOADED_SYSTEM_TIMEOUT = 10 # s +_ONE_SECOND = 1000 # ms + # If the system is loaded or the event # loop was blocked by another task that was doing I/O in the loop # (shouldn't happen but it does in practice) we need to give From 342532e1d13ac24673735dc467a79edebdfb9362 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 16:19:07 -0500 Subject: [PATCH 543/608] Update changelog for 0.34.0 (#941) --- README.rst | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index ade5b8ee..36bcc7b3 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,23 @@ See examples directory for more. Changelog ========= +0.34.0 +====== + +* Implemented Multicast Response Aggregation (#940) @bdraco + + Responses are now aggregated when possible per rules in RFC6762 + section 6.4 + + Responses that trigger the protection against against excessive + packet flooding due to software bugs or malicious attack described + in RFC6762 section 6 are delayed instead of discarding as it was + causing responders that implement Passive Observation Of Failures + (POOF) to evict the records. + + Probe responses are now always sent immediately as there were cases + where they would fail to be answered in time to defend a name. + 0.33.4 ====== @@ -149,7 +166,6 @@ Changelog ====== * Added support for forward dns compression pointers (#934) @bdraco - * Provide sockname when logging a protocol error (#935) @bdraco 0.33.2 @@ -161,7 +177,6 @@ Changelog from the cache when the second goodbye answer in the same packet was processed Fixed #926 - * Skip ipv6 interfaces that return ENODEV (#930) @bdraco 0.33.1 From 549ac3de27eb3924cc7967088c3d316184722b9d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 16:28:47 -0500 Subject: [PATCH 544/608] =?UTF-8?q?Bump=20version:=200.33.4=20=E2=86=92=20?= =?UTF-8?q?0.34.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index f982b806..f1d01ab4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.33.4 +current_version = 0.34.0 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f106525b..bec24192 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.33.4' +__version__ = '0.34.0' __license__ = 'LGPL' From de96e2bf01af68d754bb7c71da949e30de88a77b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 20:56:46 -0500 Subject: [PATCH 545/608] Ensure multicast aggregation sends responses within 620ms (#942) --- tests/test_handlers.py | 66 ++++++++++++++++++++++++++++++++++++++++++ zeroconf/_core.py | 7 +++-- zeroconf/_handlers.py | 9 +++--- 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 9049408d..c573c411 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1317,3 +1317,69 @@ async def test_response_aggregation_timings(run_isolated): assert info.dns_pointer() in incoming.answers await aiozc.async_close() + + +@pytest.mark.asyncio +async def test_response_aggregation_timings_multiple(run_isolated): + """Verify multicast responses that are aggregated do not take longer than 620ms to send. + + 620ms is the maximum random delay of 120ms and 500ms additional for aggregation.""" + type_2 = "_mservice2._tcp.local." + + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + await aiozc.zeroconf.async_wait_for_start() + + name = "xxxyyy" + registration_name2 = f"{name}.{type_2}" + + desc = {'path': '/~paulsm/'} + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + aiozc.zeroconf.registry.async_add(info2) + + query2 = r.DNSOutgoing(const._FLAGS_QR_QUERY, multicast=True) + question2 = r.DNSQuestion(info2.type, const._TYPE_PTR, const._CLASS_IN) + query2.add_question(question2) + + zc = aiozc.zeroconf + protocol = zc.engine.protocols[0] + + with unittest.mock.patch.object(aiozc.zeroconf, "async_send") as send_mock, unittest.mock.patch.object( + protocol, "suppress_duplicate_packet", return_value=False + ): + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(0.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers + + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + await asyncio.sleep(1.2) + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers + + send_mock.reset_mock() + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) + # The delay should increase with two packets + await asyncio.sleep(1.2) + calls = send_mock.mock_calls + assert len(calls) == 0 + + await asyncio.sleep(0.63) # 620ms + 10ms for execution time + calls = send_mock.mock_calls + assert len(calls) == 1 + outgoing = send_mock.call_args[0][0] + incoming = r.DNSIncoming(outgoing.packets()[0]) + zc.handle_response(incoming) + assert info2.dns_pointer() in incoming.answers diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 20251f0b..c0e8a885 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -403,7 +403,8 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None - self._out_queue = MulticastOutgoingQueue(self) + self._out_queue = MulticastOutgoingQueue(self, 0) + self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND) self.start() @@ -740,12 +741,12 @@ def handle_assembled_query( out = construct_outgoing_multicast_answers(question_answers.mcast_aggregate) self.async_send(out) if question_answers.mcast_aggregate: - self._out_queue.async_add(now, question_answers.mcast_aggregate, 0) + self._out_queue.async_add(now, question_answers.mcast_aggregate) if question_answers.mcast_aggregate_last_second: # https://datatracker.ietf.org/doc/html/rfc6762#section-14 # If we broadcast it in the last second, we have to delay # at least a second before we send it again - self._out_queue.async_add(now, question_answers.mcast_aggregate_last_second, _ONE_SECOND) + self._out_delay_queue.async_add(now, question_answers.mcast_aggregate_last_second) def send( self, diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 76d5efcd..d2160c1a 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -499,16 +499,17 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: class MulticastOutgoingQueue: """An outgoing queue used to aggregate multicast responses.""" - def __init__(self, zeroconf: 'Zeroconf') -> None: + def __init__(self, zeroconf: 'Zeroconf', additional_delay: int) -> None: self.zc = zeroconf self.queue: deque = deque() + self.additional_delay = additional_delay - def async_add(self, now: float, answers: _AnswerWithAdditionalsType, additional_delay: int) -> None: + def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None: """Add a group of answers with additionals to the outgoing queue.""" assert self.zc.loop is not None - random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + additional_delay + random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay send_after = now + random_delay - send_before = now + _MAX_MULTICAST_DELAY + additional_delay + send_before = now + _MAX_MULTICAST_DELAY + self.additional_delay if not len(self.queue): self.zc.loop.call_later(millis_to_seconds(random_delay), self._async_ready) self.queue.append(AnswerGroup(send_after, send_before, answers)) From 9942484172d7a79fe84c47924538c2c02fde7264 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 21:02:44 -0500 Subject: [PATCH 546/608] Update changelog for 0.34.1 (#943) --- README.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.rst b/README.rst index 36bcc7b3..6bc34d45 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,19 @@ See examples directory for more. Changelog ========= +0.34.1 +====== + +* Ensure multicast aggregation sends responses within 620ms (#942) @bdraco + + Responses that trigger the protection against against excessive + packet flooding due to software bugs or malicious attack described + in RFC6762 section 6 could cause the multicast aggregation response + to be delayed longer than 620ms (The maximum random delay of 120ms + and 500ms additional for aggregation). + + Only responses that trigger the protection are delayed longer than 620ms + 0.34.0 ====== From 7878a9eed93a8ec2396d8450389a08bf54bd5693 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 21:03:12 -0500 Subject: [PATCH 547/608] =?UTF-8?q?Bump=20version:=200.34.0=20=E2=86=92=20?= =?UTF-8?q?0.34.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index f1d01ab4..ff452464 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.34.0 +current_version = 0.34.1 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index bec24192..1d1fca04 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.34.0' +__version__ = '0.34.1' __license__ = 'LGPL' From 9a5164a7a3231903537231bfb56479e617355f92 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 23:23:07 -0500 Subject: [PATCH 548/608] Coalesce aggregated multicast answers when the random delay is shorter than the last scheduled response (#945) - Reduces traffic when we already know we will be sending a group of answers inside the random delay window described in https://datatracker.ietf.org/doc/html/rfc6762#section-6.3 closes #944 --- tests/test_handlers.py | 87 +++++++++++++++++++++++++++++++++++++++--- zeroconf/_handlers.py | 11 +++++- 2 files changed, 92 insertions(+), 6 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index c573c411..86a190b7 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,9 +14,10 @@ from typing import List import zeroconf as r -from zeroconf import ServiceInfo, Zeroconf, current_time_millis +from zeroconf import _handlers, ServiceInfo, Zeroconf, current_time_millis from zeroconf import const -from zeroconf._handlers import construct_outgoing_multicast_answers +from zeroconf._handlers import construct_outgoing_multicast_answers, MulticastOutgoingQueue +from zeroconf._utils.time import millis_to_seconds from zeroconf.asyncio import AsyncZeroconf @@ -1371,15 +1372,91 @@ async def test_response_aggregation_timings_multiple(run_isolated): send_mock.reset_mock() protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) protocol.datagram_received(query2.packets()[0], ('127.0.0.1', const._MDNS_PORT)) - # The delay should increase with two packets - await asyncio.sleep(1.2) + # The delay should increase with two packets and + # 900ms is beyond the maximum aggregation delay + # when there is no network protection delay + await asyncio.sleep(0.9) calls = send_mock.mock_calls assert len(calls) == 0 - await asyncio.sleep(0.63) # 620ms + 10ms for execution time + # 1000ms (1s network protection delays) + # - 900ms (already slept) + # + 120ms (maximum random delay) + # + 500ms (maximum aggregation delay) + # + 20ms (execution time) + await asyncio.sleep(millis_to_seconds(1000 - 900 + 120 + 500 + 20)) calls = send_mock.mock_calls assert len(calls) == 1 outgoing = send_mock.call_args[0][0] incoming = r.DNSIncoming(outgoing.packets()[0]) zc.handle_response(incoming) assert info2.dns_pointer() in incoming.answers + + +@pytest.mark.asyncio +async def test_response_aggregation_random_delay(): + """Verify the random delay for outgoing multicast will coalesce into a single group + + When the random delay is shorter than the last outgoing group, + the groups should be combined. + """ + type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." + type_3 = "_mservice3._tcp.local." + type_4 = "_mservice4._tcp.local." + type_5 = "_mservice5._tcp.local." + + name = "xxxyyy" + registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" + registration_name3 = f"{name}.{type_3}" + registration_name4 = f"{name}.{type_4}" + registration_name5 = f"{name}.{type_5}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) + info3 = ServiceInfo( + type_3, registration_name3, 80, 0, 0, desc, "ash-3.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info4 = ServiceInfo( + type_4, registration_name4, 80, 0, 0, desc, "ash-4.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info5 = ServiceInfo( + type_5, registration_name5, 80, 0, 0, desc, "ash-5.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + mocked_zc = unittest.mock.MagicMock() + outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0) + + now = current_time_millis() + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (500, 600)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + # The second group should always be coalesced into first group since it will always come before + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (300, 400)): + outgoing_queue.async_add(now, {info2.dns_pointer(): set()}) + + # The third group should always be coalesced into first group since it will always come before + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (100, 200)): + outgoing_queue.async_add(now, {info3.dns_pointer(): set(), info4.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 1 + assert info.dns_pointer() in outgoing_queue.queue[0].answers + assert info2.dns_pointer() in outgoing_queue.queue[0].answers + assert info3.dns_pointer() in outgoing_queue.queue[0].answers + assert info4.dns_pointer() in outgoing_queue.queue[0].answers + + # The forth group should not be coalesced because its scheduled after the last group in the queue + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (700, 800)): + outgoing_queue.async_add(now, {info5.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 2 + assert info.dns_pointer() not in outgoing_queue.queue[1].answers + assert info2.dns_pointer() not in outgoing_queue.queue[1].answers + assert info3.dns_pointer() not in outgoing_queue.queue[1].answers + assert info4.dns_pointer() not in outgoing_queue.queue[1].answers + assert info5.dns_pointer() in outgoing_queue.queue[1].answers diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index d2160c1a..73812c6f 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -510,7 +510,16 @@ def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None: random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay send_after = now + random_delay send_before = now + _MAX_MULTICAST_DELAY + self.additional_delay - if not len(self.queue): + if len(self.queue): + # If we calculate a random delay for the send after time + # that is less than the last group scheduled to go out, + # we instead add the answers to the last group as this + # allows aggregating additonal responses + last_group = self.queue[-1] + if send_after <= last_group.send_after: + last_group.answers.update(answers) + return + else: self.zc.loop.call_later(millis_to_seconds(random_delay), self._async_ready) self.queue.append(AnswerGroup(send_after, send_before, answers)) From 6d7266d0e1e6dcb950456da0354b4c43fd5c0ecb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 8 Aug 2021 23:54:13 -0500 Subject: [PATCH 549/608] Ensure ServiceInfo requests can be answered with the default timeout with network protection (#946) - Adjust the time windows to ensure responses that have triggered the protection against against excessive packet flooding due to software bugs or malicious attack described in RFC6762 section 6 can respond in under 1350ms to ensure ServiceInfo can ask two questions within the default timeout of 3000ms --- tests/test_handlers.py | 6 +++--- zeroconf/_core.py | 19 ++++++++++++++++--- zeroconf/_handlers.py | 9 ++++++--- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 86a190b7..1c180508 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1382,9 +1382,9 @@ async def test_response_aggregation_timings_multiple(run_isolated): # 1000ms (1s network protection delays) # - 900ms (already slept) # + 120ms (maximum random delay) - # + 500ms (maximum aggregation delay) + # + 200ms (maximum protected aggregation delay) # + 20ms (execution time) - await asyncio.sleep(millis_to_seconds(1000 - 900 + 120 + 500 + 20)) + await asyncio.sleep(millis_to_seconds(1000 - 900 + 120 + 200 + 20)) calls = send_mock.mock_calls assert len(calls) == 1 outgoing = send_mock.call_args[0][0] @@ -1430,7 +1430,7 @@ async def test_response_aggregation_random_delay(): type_5, registration_name5, 80, 0, 0, desc, "ash-5.local.", addresses=[socket.inet_aton("10.0.1.2")] ) mocked_zc = unittest.mock.MagicMock() - outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0) + outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 500) now = current_time_millis() with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (500, 600)): diff --git a/zeroconf/_core.py b/zeroconf/_core.py index c0e8a885..400e15a5 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -83,7 +83,20 @@ ) _TC_DELAY_RANDOM_INTERVAL = (400, 500) - +# The maximum amont of time to delay a multicast +# response in order to aggregate answers +_AGGREGATION_DELAY = 500 # ms +# The maximum amont of time to delay a multicast +# response in order to aggregate answers after +# it has already been delayed to protect the network +# from excessive traffic. We use a shorter time +# window here as we want to _try_ to answer all +# queries in under 1350ms while protecting +# the network from excessive traffic to ensure +# a service info request with two questions +# can be answered in the default timeout of +# 3000ms +_PROTECTED_AGGREGATION_DELAY = 200 # ms _CLOSE_TIMEOUT = 3000 # ms _REGISTER_BROADCASTS = 3 @@ -403,8 +416,8 @@ def __init__( self.loop: Optional[asyncio.AbstractEventLoop] = None self._loop_thread: Optional[threading.Thread] = None - self._out_queue = MulticastOutgoingQueue(self, 0) - self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND) + self._out_queue = MulticastOutgoingQueue(self, 0, _AGGREGATION_DELAY) + self._out_delay_queue = MulticastOutgoingQueue(self, _ONE_SECOND, _PROTECTED_AGGREGATION_DELAY) self.start() diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 73812c6f..2310b824 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -56,7 +56,6 @@ _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] _MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120) -_MAX_MULTICAST_DELAY = 500 # ms _RESPOND_IMMEDIATE_TYPES = {_TYPE_SRV, _TYPE_A, _TYPE_AAAA} @@ -499,17 +498,21 @@ def async_remove_listener(self, listener: RecordUpdateListener) -> None: class MulticastOutgoingQueue: """An outgoing queue used to aggregate multicast responses.""" - def __init__(self, zeroconf: 'Zeroconf', additional_delay: int) -> None: + def __init__(self, zeroconf: 'Zeroconf', additional_delay: int, max_aggregation_delay: int) -> None: self.zc = zeroconf self.queue: deque = deque() + # Additional delay is used to implement + # Protect the network against excessive packet flooding + # https://datatracker.ietf.org/doc/html/rfc6762#section-14 self.additional_delay = additional_delay + self.aggregation_delay = max_aggregation_delay def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None: """Add a group of answers with additionals to the outgoing queue.""" assert self.zc.loop is not None random_delay = random.randint(*_MULTICAST_DELAY_RANDOM_INTERVAL) + self.additional_delay send_after = now + random_delay - send_before = now + _MAX_MULTICAST_DELAY + self.additional_delay + send_before = now + self.aggregation_delay + self.additional_delay if len(self.queue): # If we calculate a random delay for the send after time # that is less than the last group scheduled to go out, From b87f4934b39af02f26bbbfd6f372c7154fe95906 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 Aug 2021 00:05:51 -0500 Subject: [PATCH 550/608] Update changelog for 0.34.2 (#947) --- README.rst | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.rst b/README.rst index 6bc34d45..721a7e87 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,25 @@ See examples directory for more. Changelog ========= +0.34.2 +====== + +* Coalesce aggregated multicast answers (#945) @bdraco + + When the random delay is shorter than the last scheduled response, + answers are now added to the same outgoing time group. + + This reduces traffic when we already know we will be sending a group of answers + inside the random delay window described in + datatracker.ietf.org/doc/html/rfc6762#section-6.3 +* Ensure ServiceInfo requests can be answered inside the default timeout with network protection (#946) @bdraco + + Adjust the time windows to ensure responses that have triggered the + protection against against excessive packet flooding due to + software bugs or malicious attack described in RFC6762 section 6 + can respond in under 1350ms to ensure ServiceInfo can ask two + questions within the default timeout of 3000ms + 0.34.1 ====== From 6c21f6802b58d949038e9c8501ea204eeda57a16 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 Aug 2021 00:14:39 -0500 Subject: [PATCH 551/608] =?UTF-8?q?Bump=20version:=200.34.1=20=E2=86=92=20?= =?UTF-8?q?0.34.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index ff452464..ed8435b7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.34.1 +current_version = 0.34.2 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 1d1fca04..3b953047 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.34.1' +__version__ = '0.34.2' __license__ = 'LGPL' From 02af7f78d2e5eabcc5cce8238546ee5170951b28 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 Aug 2021 00:55:48 -0500 Subject: [PATCH 552/608] Fix sending immediate multicast responses (#949) - Fixes a typo in handle_assembled_query that prevented immediate responses from being sent. --- zeroconf/_core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 400e15a5..dc3a060f 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -751,8 +751,7 @@ def handle_assembled_query( out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) self.async_send(out, addr, port, v6_flow_scope) if question_answers.mcast_now: - out = construct_outgoing_multicast_answers(question_answers.mcast_aggregate) - self.async_send(out) + self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) if question_answers.mcast_aggregate: self._out_queue.async_add(now, question_answers.mcast_aggregate) if question_answers.mcast_aggregate_last_second: From 23b00e983b2e8335431dcc074935f379fd399d46 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 Aug 2021 00:57:54 -0500 Subject: [PATCH 553/608] Update changelog for 0.34.3 (#950) --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index 721a7e87..277d790c 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,11 @@ See examples directory for more. Changelog ========= +0.34.3 +====== + +* Fix sending immediate multicast responses (#949) @bdraco + 0.34.2 ====== From 9d69d18713bdfab53762a6b8c3aff7fd72ebd025 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 9 Aug 2021 00:58:09 -0500 Subject: [PATCH 554/608] =?UTF-8?q?Bump=20version:=200.34.2=20=E2=86=92=20?= =?UTF-8?q?0.34.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index ed8435b7..14216383 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.34.2 +current_version = 0.34.3 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 3b953047..2f17b31f 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.34.2' +__version__ = '0.34.3' __license__ = 'LGPL' From ebc23ee5e9592dd7f0235cd57f9b3ad727ec8bff Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 12 Aug 2021 21:42:50 -0500 Subject: [PATCH 555/608] Sort responses to increase chance of name compression (#954) - When building an outgoing response, sort the names together to increase the likelihood of name compression. In testing this reduced the number of packets for large responses (from 7 packets to 6) --- zeroconf/_handlers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 2310b824..cff12870 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -100,12 +100,16 @@ def construct_outgoing_unicast_answers( def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsType) -> None: # Find additionals and suppress any additionals that are already in answers - additionals: Set[DNSRecord] = set().union(*answers.values()) # type: ignore - additionals -= answers.keys() - for answer in answers: + sending: Set[DNSRecord] = set(answers.keys()) + # Answers are sorted to group names together to increase the chance + # that similar names will end up in the same packet and can reduce the + # overall size of the outgoing response via name compression + for answer, additionals in sorted(answers.items(), key=lambda kv: kv[0].name): out.add_answer_at_time(answer, 0) - for additional in additionals: - out.add_additional_answer(additional) + for additional in additionals: + if additional not in sending: + out.add_additional_answer(additional) + sending.add(additional) def sanitize_incoming_record(record: DNSRecord) -> None: From 5fb3e202c06e3a0d30e3c7824397d8e8a9f52555 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 13 Aug 2021 08:25:59 -0500 Subject: [PATCH 556/608] Send unicast replies on the same socket the query was received (#952) When replying to a QU question, we do not know if the sending host is reachable from all of the sending sockets. We now avoid this problem by replying via the receiving socket. This was the existing behavior when `InterfaceChoice.Default` is set. This change extends the unicast relay behavior to used with `InterfaceChoice.Default` to apply when `InterfaceChoice.All` or interfaces are explicitly passed when instantiating a `Zeroconf` instance. Fixes #951 --- tests/test_asyncio.py | 6 ++- tests/test_core.py | 18 ++++----- zeroconf/_core.py | 91 +++++++++++++++++++++++++++++-------------- 3 files changed, 74 insertions(+), 41 deletions(-) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 39cad5b9..7b895386 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -974,12 +974,14 @@ async def test_legacy_unicast_response(run_isolated): query = DNSOutgoing(const._FLAGS_QR_QUERY, multicast=False, id_=888) question = DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN) query.add_question(question) + protocol = aiozc.zeroconf.engine.protocols[0] with patch.object(aiozc.zeroconf, "async_send") as send_mock: - aiozc.zeroconf.engine.protocols[0].datagram_received(query.packets()[0], ('127.0.0.1', 6503)) + protocol.datagram_received(query.packets()[0], ('127.0.0.1', 6503)) calls = send_mock.mock_calls - assert calls == [call(ANY, '127.0.0.1', 6503, ())] + # Verify the response is sent back on the socket it was recieved from + assert calls == [call(ANY, '127.0.0.1', 6503, (), protocol.transport)] outgoing = send_mock.call_args[0][0] assert isinstance(outgoing, DNSOutgoing) assert outgoing.questions == [question] diff --git a/tests/test_core.py b/tests/test_core.py index e2420a78..ba1effac 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -480,28 +480,28 @@ def test_tc_bit_defers(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert source_ip not in protocol._deferred assert source_ip not in protocol._timers @@ -559,13 +559,13 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred timer1 = protocol._timers[source_ip] next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred timer2 = protocol._timers[source_ip] if sys.version_info >= (3, 7): @@ -573,7 +573,7 @@ def test_tc_bit_defers_last_response_missing(): assert timer2 != timer1 # Send the same packet again to similar multi interfaces - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer3 = protocol._timers[source_ip] @@ -583,7 +583,7 @@ def test_tc_bit_defers_last_response_missing(): next_packet = r.DNSIncoming(packets.pop(0)) expected_deferred.append(next_packet) - threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT) + threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, None) assert protocol._deferred[source_ip] == expected_deferred assert source_ip in protocol._timers timer4 = protocol._timers[source_ip] diff --git a/zeroconf/_core.py b/zeroconf/_core.py index dc3a060f..72c6e4ce 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -215,7 +215,8 @@ def __init__(self, zc: 'Zeroconf') -> None: self.data: Optional[bytes] = None self.last_time: float = 0 self.transport: Optional[asyncio.DatagramTransport] = None - + self.sock_name: Optional[str] = None + self.sock_fileno: Optional[int] = None self._deferred: Dict[str, List[DNSIncoming]] = {} self._timers: Dict[str, asyncio.TimerHandle] = {} @@ -294,15 +295,20 @@ def datagram_received( self.zc.handle_response(msg) return - self.handle_query_or_defer(msg, addr, port, v6_flow_scope) + self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) def handle_query_or_defer( - self, msg: DNSIncoming, addr: str, port: int, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + self, + msg: DNSIncoming, + addr: str, + port: int, + transport: asyncio.DatagramTransport, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Deal with incoming query packets. Provides a response if possible.""" if not msg.truncated: - self._respond_query(msg, addr, port, v6_flow_scope) + self._respond_query(msg, addr, port, transport, v6_flow_scope) return deferred = self._deferred.setdefault(addr, []) @@ -315,7 +321,7 @@ def handle_query_or_defer( assert self.zc.loop is not None self._cancel_any_timers_for_addr(addr) self._timers[addr] = self.zc.loop.call_later( - delay, self._respond_query, None, addr, port, v6_flow_scope + delay, self._respond_query, None, addr, port, transport, v6_flow_scope ) def _cancel_any_timers_for_addr(self, addr: str) -> None: @@ -328,6 +334,7 @@ def _respond_query( msg: Optional[DNSIncoming], addr: str, port: int, + transport: asyncio.DatagramTransport, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Respond to a query and reassemble any truncated deferred packets.""" @@ -336,15 +343,12 @@ def _respond_query( if msg: packets.append(msg) - self.zc.handle_assembled_query(packets, addr, port, v6_flow_scope) + self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) @property def _socket_description(self) -> str: """A human readable description of the socket.""" - assert self.transport is not None - fileno = self.transport.get_extra_info('socket').fileno() - sockname = self.transport.get_extra_info('sockname') - return f"{fileno} ({sockname})" + return f"{self.sock_fileno} ({self.sock_name})" def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" @@ -357,6 +361,8 @@ def error_received(self, exc: Exception) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) + self.sock_name = self.transport.get_extra_info('sockname') + self.sock_fileno = self.transport.get_extra_info('socket').fileno() def connection_lost(self, exc: Optional[Exception]) -> None: """Handle connection lost.""" @@ -400,6 +406,7 @@ def __init__( if apple_p2p and sys.platform != 'darwin': raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') + self.unicast = unicast listen_socket, respond_sockets = create_sockets(interfaces, unicast, ip_version, apple_p2p=apple_p2p) log.debug('Listen socket %s, respond sockets %s', listen_socket, respond_sockets) @@ -732,6 +739,7 @@ def handle_assembled_query( packets: List[DNSIncoming], addr: str, port: int, + transport: asyncio.DatagramTransport, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: """Respond to a (re)assembled query. @@ -749,7 +757,10 @@ def handle_assembled_query( questions = packets[0].questions id_ = packets[0].id out = construct_outgoing_unicast_answers(question_answers.ucast, ucast_source, questions, id_) - self.async_send(out, addr, port, v6_flow_scope) + # When sending unicast, only send back the reply + # via the same socket that it was recieved from + # as we know its reachable from that socket + self.async_send(out, addr, port, v6_flow_scope, transport) if question_answers.mcast_now: self.async_send(construct_outgoing_multicast_answers(question_answers.mcast_now)) if question_answers.mcast_aggregate: @@ -766,10 +777,11 @@ def send( addr: Optional[str] = None, port: int = _MDNS_PORT, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[asyncio.DatagramTransport] = None, ) -> None: """Sends an outgoing packet threadsafe.""" assert self.loop is not None - self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope) + self.loop.call_soon_threadsafe(self.async_send, out, addr, port, v6_flow_scope, transport) def async_send( self, @@ -777,33 +789,52 @@ def async_send( addr: Optional[str] = None, port: int = _MDNS_PORT, v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + transport: Optional[asyncio.DatagramTransport] = None, ) -> None: """Sends an outgoing packet.""" if self._GLOBAL_DONE: return + # If no transport is specified, we send to all the ones + # with the same address family + transports = [transport] if transport else self.engine.senders + for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return - log.debug( - 'Sending to (%s, %d) (%d bytes #%d) %r as %r...', - addr, - port, - len(packet), - packet_num + 1, - out, - packet, - ) - for transport in self.engine.senders: - s = transport.get_extra_info('socket') - if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR - elif not can_send_to(s, addr): - continue - else: - real_addr = addr - transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) + for send_transport in transports: + self._async_send_transport(send_transport, packet, packet_num, out, addr, port, v6_flow_scope) + + def _async_send_transport( + self, + transport: asyncio.DatagramTransport, + packet: bytes, + packet_num: int, + out: DNSOutgoing, + addr: Optional[str], + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), + ) -> None: + s = transport.get_extra_info('socket') + if addr is None: + real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR + else: + real_addr = addr + if not can_send_to(s, real_addr): + return + log.debug( + 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', + real_addr, + port or _MDNS_PORT, + s.fileno(), + transport.get_extra_info('sockname'), + len(packet), + packet_num + 1, + out, + packet, + ) + transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) def _close(self) -> None: """Set global done and remove all service listeners.""" From c77293692062ea701037e06c1cf5497f019ae2f2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 13 Aug 2021 09:11:51 -0500 Subject: [PATCH 557/608] Reduce chance of accidental synchronization of ServiceInfo requests (#955) --- tests/services/test_info.py | 22 ++++++++++++++++++++++ zeroconf/_services/info.py | 12 ++++++++++++ 2 files changed, 34 insertions(+) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 0464ae7b..2143b5fe 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -754,3 +754,25 @@ def test_request_timeout(): # 3000ms for the default timeout # 1000ms for loaded systems + schedule overhead assert (end_time - start_time) < 3000 + 1000 + + +@pytest.mark.asyncio +async def test_we_try_four_times_with_random_delay(): + """Verify we try four times even with the random delay.""" + type_ = "_typethatisnothere._tcp.local." + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + + # we are going to patch the zeroconf send to check query transmission + request_count = 0 + def async_send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): + """Sends an outgoing packet.""" + nonlocal request_count + request_count += 1 + + # patch the zeroconf send + with patch.object(aiozc.zeroconf, "async_send", async_send): + await aiozc.async_get_service_info(f"willnotbefound.{type_}", type_) + + await aiozc.async_close() + + assert request_count == 4 diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index cede3877..33c0488a 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -21,6 +21,7 @@ """ import ipaddress +import random import socket from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, cast @@ -52,6 +53,16 @@ ) +# https://datatracker.ietf.org/doc/html/rfc6762#section-5.2 +# The most common case for calling ServiceInfo is from a +# ServiceBrowser. After the first request we add a few random +# milliseconds to the delay between requests to reduce the chance +# that there are multiple ServiceBrowser callbacks running on +# the network that are firing at the same time when they +# see the same multicast response and decide to refresh +# the A/AAAA/SRV records for a host. +_AVOID_SYNC_DELAY_RANDOM_INTERVAL = (20, 120) + if TYPE_CHECKING: from .._core import Zeroconf @@ -455,6 +466,7 @@ async def async_request( zc.async_send(out) next_ = now + delay delay *= 2 + next_ += random.randint(*_AVOID_SYNC_DELAY_RANDOM_INTERVAL) await zc.async_wait(min(next_, last) - now) now = current_time_millis() From dd40437f4328f4ee36c43239ecf5f484b6ac261e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 13 Aug 2021 13:58:11 -0500 Subject: [PATCH 558/608] Update changelog for 0.35.0 (#957) --- README.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/README.rst b/README.rst index 277d790c..e3f99f0d 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,27 @@ See examples directory for more. Changelog ========= +0.35.0 +====== + +* Reduced chance of accidental synchronization of ServiceInfo requests (#955) @bdraco +* Sort aggregated responses to increase chance of name compression (#954) @bdraco + +Technically backwards incompatible: + +* Send unicast replies on the same socket the query was received (#952) @bdraco + + When replying to a QU question, we do not know if the sending host is reachable + from all of the sending sockets. We now avoid this problem by replying via + the receiving socket. This was the existing behavior when `InterfaceChoice.Default` + is set. + + This change extends the unicast relay behavior to used with `InterfaceChoice.Default` + to apply when `InterfaceChoice.All` or interfaces are explicitly passed when + instantiating a `Zeroconf` instance. + + Fixes #951 + 0.34.3 ====== From 1e60e13ae15a5b533a48cc955b98951eedd04dbb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 13 Aug 2021 14:11:22 -0500 Subject: [PATCH 559/608] =?UTF-8?q?Bump=20version:=200.34.3=20=E2=86=92=20?= =?UTF-8?q?0.35.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 14216383..cb6377b9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.34.3 +current_version = 0.35.0 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 2f17b31f..d71537f6 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.34.3' +__version__ = '0.35.0' __license__ = 'LGPL' From 7b125a1a0a109ef29d0a4e736a27645a7e9b4207 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 09:27:48 -0500 Subject: [PATCH 560/608] Only reschedule types if the send next time changes (#958) - When the PTR response was seen again, the timer was being canceled and rescheduled even if the timer was for the same time. While this did not cause any breakage, it is quite inefficient. --- tests/services/test_browser.py | 2 +- zeroconf/_services/browser.py | 30 ++++++++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/services/test_browser.py b/tests/services/test_browser.py index 292dee25..e22ebfe3 100644 --- a/tests/services/test_browser.py +++ b/tests/services/test_browser.py @@ -496,7 +496,7 @@ def on_service_state_change(zeroconf, service_type, state_change, name): else: assert not got_query.is_set() time_offset += initial_query_interval - zeroconf_browser.loop.call_soon_threadsafe(browser.schedule_changed) + zeroconf_browser.loop.call_soon_threadsafe(browser._async_send_ready_queries_schedule_next) finally: browser.cancel() diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index 51f2c8d5..aadbd7ac 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -224,11 +224,12 @@ def millis_to_wait(self, now: float) -> float: next_time = min(self._next_time.values()) return 0 if next_time <= now else next_time - now - def reschedule_type(self, type_: str, next_time: float) -> None: + def reschedule_type(self, type_: str, next_time: float) -> bool: """Reschedule the query for a type to happen sooner.""" if next_time >= self._next_time[type_]: - return + return False self._next_time[type_] = next_time + return True def process_ready_types(self, now: float) -> List[str]: """Generate a list of ready types that is due and schedule the next time.""" @@ -449,7 +450,8 @@ def _generate_ready_queries(self, first_request: bool) -> List[DNSOutgoing]: async def _async_start_query_sender(self) -> None: """Start scheduling queries.""" await self.zc.async_wait_for_start() - self._async_send_ready_queries_schedule_next() + self._async_send_ready_queries() + self._async_schedule_next() def _cancel_send_timer(self) -> None: """Cancel the next send.""" @@ -458,16 +460,13 @@ def _cancel_send_timer(self) -> None: def reschedule_type(self, type_: str, next_time: float) -> None: """Reschedule a type to be refreshed in the future.""" - self.query_scheduler.reschedule_type(type_, next_time) - self.schedule_changed() - - def schedule_changed(self) -> None: - """Called when the schedule has changed.""" - self._cancel_send_timer() - self._async_send_ready_queries_schedule_next() + if self.query_scheduler.reschedule_type(type_, next_time): + self._cancel_send_timer() + self._async_schedule_next() + self._async_send_ready_queries() - def _async_send_ready_queries_schedule_next(self) -> None: - """Send any ready queries and scheule the next time.""" + def _async_send_ready_queries(self) -> None: + """Send any ready queries.""" if self.done or self.zc.done: return @@ -477,6 +476,13 @@ def _async_send_ready_queries_schedule_next(self) -> None: for out in outs: self.zc.async_send(out, addr=self.addr, port=self.port) + def _async_send_ready_queries_schedule_next(self) -> None: + """Send ready queries and schedule next one.""" + self._async_send_ready_queries() + self._async_schedule_next() + + def _async_schedule_next(self) -> None: + """Scheule the next time.""" assert self.zc.loop is not None delay = millis_to_seconds(self.query_scheduler.millis_to_wait(current_time_millis())) self._next_send_timer = self.zc.loop.call_later(delay, self._async_send_ready_queries_schedule_next) From 2d1b8329ad39b94f9f4aa5f53caf3bb2813879ca Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 10:15:54 -0500 Subject: [PATCH 561/608] Add coverage for sending answers removes future queued answers (#961) - If we send an answer that is queued to be sent out in the future we should remove it from the queue as the question has already been answered and we do not want to generate additional traffic. --- tests/services/test_info.py | 1 + tests/test_handlers.py | 39 +++++++++++++++++++++++++++++++++++++ zeroconf/_handlers.py | 21 +++++++++++--------- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/tests/services/test_info.py b/tests/services/test_info.py index 2143b5fe..a72d82f9 100644 --- a/tests/services/test_info.py +++ b/tests/services/test_info.py @@ -764,6 +764,7 @@ async def test_we_try_four_times_with_random_delay(): # we are going to patch the zeroconf send to check query transmission request_count = 0 + def async_send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT): """Sends an outgoing packet.""" nonlocal request_count diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 1c180508..11ea03f9 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1460,3 +1460,42 @@ async def test_response_aggregation_random_delay(): assert info3.dns_pointer() not in outgoing_queue.queue[1].answers assert info4.dns_pointer() not in outgoing_queue.queue[1].answers assert info5.dns_pointer() in outgoing_queue.queue[1].answers + + +@pytest.mark.asyncio +async def test_future_answers_are_removed_on_send(): + """Verify any future answers scheduled to be sent are removed when we send.""" + type_ = "_mservice._tcp.local." + name = "xxxyyy" + registration_name = f"{name}.{type_}" + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + mocked_zc = unittest.mock.MagicMock() + outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 0) + + now = current_time_millis() + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (10, 10)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 1 + + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (20, 20)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 2 + + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (200, 200)): + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + + assert len(outgoing_queue.queue) == 3 + + await asyncio.sleep(0.1) + outgoing_queue.async_ready() + + assert len(outgoing_queue.queue) == 1 + # The answers should all get removed because we just sent them + assert len(outgoing_queue.queue[0].answers) == 0 diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index cff12870..06ed54cd 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -527,10 +527,16 @@ def async_add(self, now: float, answers: _AnswerWithAdditionalsType) -> None: last_group.answers.update(answers) return else: - self.zc.loop.call_later(millis_to_seconds(random_delay), self._async_ready) + self.zc.loop.call_later(millis_to_seconds(random_delay), self.async_ready) self.queue.append(AnswerGroup(send_after, send_before, answers)) - def _async_ready(self) -> None: + def _remove_answers_from_queue(self, answers: _AnswerWithAdditionalsType) -> None: + """Remove a set of answers from the outgoing queue.""" + for pending in self.queue: + for record in answers: + pending.answers.pop(record, None) + + def async_ready(self) -> None: """Process anything in the queue that is ready.""" assert self.zc.loop is not None now = current_time_millis() @@ -538,7 +544,7 @@ def _async_ready(self) -> None: if len(self.queue) > 1 and self.queue[0].send_before > now: # There is more than one answer in the queue, # delay until we have to send it (first answer group reaches send_before) - self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self._async_ready) + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_before - now), self.async_ready) return answers: _AnswerWithAdditionalsType = {} @@ -549,12 +555,9 @@ def _async_ready(self) -> None: if len(self.queue): # If there are still groups in the queue that are not ready to send # be sure we schedule them to go out later - self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self._async_ready) + self.zc.loop.call_later(millis_to_seconds(self.queue[0].send_after - now), self.async_ready) if answers: - # If we have the same answer scheduled to go out, remove it - for pending in self.queue: - for record in answers: - pending.answers.pop(record, None) - + # If we have the same answer scheduled to go out, remove them + self._remove_answers_from_queue(answers) self.zc.async_send(construct_outgoing_multicast_answers(answers)) From 3b482e229d37b85e59765e023ddbca77aa513731 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 10:42:14 -0500 Subject: [PATCH 562/608] Fix flakey test: test_future_answers_are_removed_on_send (#962) --- tests/test_handlers.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 11ea03f9..a621f037 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -1466,29 +1466,34 @@ async def test_response_aggregation_random_delay(): async def test_future_answers_are_removed_on_send(): """Verify any future answers scheduled to be sent are removed when we send.""" type_ = "_mservice._tcp.local." + type_2 = "_mservice2._tcp.local." name = "xxxyyy" registration_name = f"{name}.{type_}" + registration_name2 = f"{name}.{type_2}" desc = {'path': '/~paulsm/'} info = ServiceInfo( type_, registration_name, 80, 0, 0, desc, "ash-1.local.", addresses=[socket.inet_aton("10.0.1.2")] ) + info2 = ServiceInfo( + type_2, registration_name2, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.3")] + ) mocked_zc = unittest.mock.MagicMock() outgoing_queue = MulticastOutgoingQueue(mocked_zc, 0, 0) now = current_time_millis() - with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (10, 10)): + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1, 1)): outgoing_queue.async_add(now, {info.dns_pointer(): set()}) assert len(outgoing_queue.queue) == 1 - with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (20, 20)): + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (2, 2)): outgoing_queue.async_add(now, {info.dns_pointer(): set()}) assert len(outgoing_queue.queue) == 2 - with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (200, 200)): - outgoing_queue.async_add(now, {info.dns_pointer(): set()}) + with unittest.mock.patch.object(_handlers, "_MULTICAST_DELAY_RANDOM_INTERVAL", (1000, 1000)): + outgoing_queue.async_add(now, {info2.dns_pointer(): set()}) outgoing_queue.async_add(now, {info.dns_pointer(): set()}) assert len(outgoing_queue.queue) == 3 @@ -1497,5 +1502,8 @@ async def test_future_answers_are_removed_on_send(): outgoing_queue.async_ready() assert len(outgoing_queue.queue) == 1 - # The answers should all get removed because we just sent them - assert len(outgoing_queue.queue[0].answers) == 0 + # The answer should get removed because we just sent it + assert info.dns_pointer() not in outgoing_queue.queue[0].answers + + # But the one we have not sent yet shoudl still go out later + assert info2.dns_pointer() in outgoing_queue.queue[0].answers From d4c109c3abffcba2331a7f9e7bf45c6477a8d4e8 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 10:42:22 -0500 Subject: [PATCH 563/608] Cache DNS record and question hashes (#960) --- tests/test_asyncio.py | 1 + tests/test_dns.py | 27 ++++++++++++++++++++++++ zeroconf/_dns.py | 48 +++++++++++++++++++++++++++---------------- 3 files changed, 58 insertions(+), 18 deletions(-) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 7b895386..e6da20a6 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -579,6 +579,7 @@ async def test_async_unregister_all_services() -> None: assert results[1] is not None await aiozc.async_unregister_all_services() + _clear_cache(aiozc.zeroconf) tasks = [] tasks.append(aiozc.async_get_service_info(type_, registration_name)) diff --git a/tests/test_dns.py b/tests/test_dns.py index a952b81e..fe3efda8 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -163,6 +163,33 @@ def test_dns_record_is_recent(self): assert record.is_recent(now + (8 * 1000)) is False +def test_dns_question_hashablity(): + """Test DNSQuestions are hashable.""" + + record1 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + record2 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + + record_set = {record1, record2} + assert len(record_set) == 1 + + record_set.add(record1) + assert len(record_set) == 1 + + record3_dupe = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN) + assert record2 == record3_dupe + assert record2.__hash__() == record3_dupe.__hash__() + + record_set.add(record3_dupe) + assert len(record_set) == 1 + + record4_dupe = r.DNSQuestion('notsame', const._TYPE_A, const._CLASS_IN) + assert record2 != record4_dupe + assert record2.__hash__() != record4_dupe.__hash__() + + record_set.add(record4_dupe) + assert len(record_set) == 2 + + def test_dns_record_hashablity_does_not_consider_ttl(): """Test DNSRecord are hashable.""" diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 5b211060..a9bc7d77 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -22,7 +22,7 @@ import enum import socket -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union, cast from ._exceptions import AbstractMethodException from ._utils.net import _is_v6_address @@ -81,10 +81,6 @@ def __init__(self, name: str, type_: int, class_: int) -> None: self.class_ = class_ & _CLASS_MASK self.unique = (class_ & _CLASS_UNIQUE) != 0 - def _entry_tuple(self) -> Tuple[str, int, int]: - """Entry Tuple for DNSEntry.""" - return (self.key, self.type, self.class_) - def __eq__(self, other: Any) -> bool: """Equality test on key (lowercase name), type, and class""" return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry) @@ -115,12 +111,22 @@ class DNSQuestion(DNSEntry): """A DNS question entry""" + __slots__ = ('_hash',) + + def __init__(self, name: str, type_: int, class_: int) -> None: + super().__init__(name, type_, class_) + self._hash = hash((self.key, type_, class_)) + def answered_by(self, rec: 'DNSRecord') -> bool: """Returns true if the question is answered by the record""" return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name def __hash__(self) -> int: - return hash((self.name, self.class_, self.type)) + return self._hash + + def __eq__(self, other: Any) -> bool: + """Tests equality on dns question.""" + return isinstance(other, DNSQuestion) and DNSEntry.__eq__(self, other) @property def max_size(self) -> int: @@ -225,7 +231,7 @@ class DNSAddress(DNSRecord): """A DNS address record""" - __slots__ = ('address', 'scope_id') + __slots__ = ('_hash', 'address', 'scope_id') def __init__( self, @@ -241,6 +247,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.address = address self.scope_id = scope_id + self._hash = hash((self.key, type_, class_, address, scope_id)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -257,7 +264,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: """Hash to compare like DNSAddresses.""" - return hash((*self._entry_tuple(), self.address, self.scope_id)) + return self._hash def __repr__(self) -> str: """String representation""" @@ -275,7 +282,7 @@ class DNSHinfo(DNSRecord): """A DNS host information record""" - __slots__ = ('cpu', 'os') + __slots__ = ('_hash', 'cpu', 'os') def __init__( self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None @@ -283,6 +290,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.cpu = cpu self.os = os + self._hash = hash((self.key, type_, class_, cpu, os)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -300,7 +308,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: """Hash to compare like DNSHinfo.""" - return hash((*self._entry_tuple(), self.cpu, self.os)) + return self._hash def __repr__(self) -> str: """String representation""" @@ -311,13 +319,14 @@ class DNSPointer(DNSRecord): """A DNS pointer record""" - __slots__ = ('alias',) + __slots__ = ('_hash', 'alias') def __init__( self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None ) -> None: super().__init__(name, type_, class_, ttl, created) self.alias = alias + self._hash = hash((self.key, type_, class_, alias)) @property def max_size_compressed(self) -> int: @@ -339,7 +348,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: """Hash to compare like DNSPointer.""" - return hash((*self._entry_tuple(), self.alias)) + return self._hash def __repr__(self) -> str: """String representation""" @@ -350,7 +359,7 @@ class DNSText(DNSRecord): """A DNS text record""" - __slots__ = ('text',) + __slots__ = ('_hash', 'text') def __init__( self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None @@ -358,6 +367,7 @@ def __init__( assert isinstance(text, (bytes, type(None))) super().__init__(name, type_, class_, ttl, created) self.text = text + self._hash = hash((self.key, type_, class_, text)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -365,7 +375,7 @@ def write(self, out: 'DNSOutgoing') -> None: def __hash__(self) -> int: """Hash to compare like DNSText.""" - return hash((*self._entry_tuple(), self.text)) + return self._hash def __eq__(self, other: Any) -> bool: """Tests equality on text""" @@ -382,7 +392,7 @@ class DNSService(DNSRecord): """A DNS service record""" - __slots__ = ('priority', 'weight', 'port', 'server') + __slots__ = ('_hash', 'priority', 'weight', 'port', 'server') def __init__( self, @@ -401,6 +411,7 @@ def __init__( self.weight = weight self.port = port self.server = server + self._hash = hash((self.key, type_, class_, priority, weight, port, server)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -422,7 +433,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: """Hash to compare like DNSService.""" - return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server)) + return self._hash def __repr__(self) -> str: """String representation""" @@ -433,7 +444,7 @@ class DNSNsec(DNSRecord): """A DNS NSEC record""" - __slots__ = ('next_name', 'rdtypes') + __slots__ = ('_hash', 'next_name', 'rdtypes') def __init__( self, @@ -448,6 +459,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.next_name = next_name self.rdtypes = rdtypes + self._hash = hash((self.key, type_, class_, next_name, *self.rdtypes)) def __eq__(self, other: Any) -> bool: """Tests equality on cpu and os""" @@ -460,7 +472,7 @@ def __eq__(self, other: Any) -> bool: def __hash__(self) -> int: """Hash to compare like DNSNSec.""" - return hash((*self._entry_tuple(), self.next_name, *self.rdtypes)) + return self._hash def __repr__(self) -> str: """String representation""" From f7bebfe09aeb9bb973dbe6ba147b682472b64246 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 14:11:54 -0500 Subject: [PATCH 564/608] Update changelog for 0.35.1 (#963) --- README.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.rst b/README.rst index e3f99f0d..fbf2e4fe 100644 --- a/README.rst +++ b/README.rst @@ -140,6 +140,18 @@ See examples directory for more. Changelog ========= +0.35.1 +====== + +* Only reschedule types if the send next time changes (#958) @bdraco + When the PTR response was seen again, the timer was being canceled and + rescheduled even if the timer was for the same time. While this did + not cause any breakage, it is quite inefficient. +* Cache DNS record and question hashes (#960) @bdraco + The hash was being recalculated every time the object + was being used in a set or dict. Since the hashes are + effectively immutable, we only calculate them once now. + 0.35.0 ====== From c7c7d4778e9962af5180616af73977d8503e4762 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 14:13:41 -0500 Subject: [PATCH 565/608] Fix formatting in 0.35.1 changelog entry (#964) --- README.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.rst b/README.rst index fbf2e4fe..4b14550d 100644 --- a/README.rst +++ b/README.rst @@ -144,10 +144,12 @@ Changelog ====== * Only reschedule types if the send next time changes (#958) @bdraco + When the PTR response was seen again, the timer was being canceled and rescheduled even if the timer was for the same time. While this did not cause any breakage, it is quite inefficient. * Cache DNS record and question hashes (#960) @bdraco + The hash was being recalculated every time the object was being used in a set or dict. Since the hashes are effectively immutable, we only calculate them once now. From 4281221b668123b770c6d6b0835dd876d1d2f22d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 Aug 2021 14:14:10 -0500 Subject: [PATCH 566/608] =?UTF-8?q?Bump=20version:=200.35.0=20=E2=86=92=20?= =?UTF-8?q?0.35.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index cb6377b9..95f97b88 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.35.0 +current_version = 0.35.1 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index d71537f6..b90d2a99 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.35.0' +__version__ = '0.35.1' __license__ = 'LGPL' From 733eb3a31ed40c976f5fa4b7b3baf055589ef36b Mon Sep 17 00:00:00 2001 From: Lokesh Date: Mon, 16 Aug 2021 20:26:26 +0100 Subject: [PATCH 567/608] Create full IPv6 address tuple to enable service discovery on Windows (#965) --- README.rst | 2 -- zeroconf/_core.py | 5 +++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 4b14550d..2b176e29 100644 --- a/README.rst +++ b/README.rst @@ -75,8 +75,6 @@ IPv6 support is relatively new and currently limited, specifically: * `InterfaceChoice.All` is an alias for `InterfaceChoice.Default` on non-POSIX systems. -* On Windows specific interfaces can only be requested as interface indexes, - not as IP addresses. * Dual-stack IPv6 sockets are used, which may not be supported everywhere (some BSD variants do not have them). * Listening on localhost (`::1`) does not work. Help with understanding why is diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 72c6e4ce..96b1a790 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -834,6 +834,11 @@ def _async_send_transport( out, packet, ) + # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 + # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families + if s.family == socket.AF_INET6 and not v6_flow_scope: + _, _, sock_flowinfo, sock_scopeid = s.getsockname() + v6_flow_scope = (sock_flowinfo, sock_scopeid) transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) def _close(self) -> None: From bc50bce04b650756fef3f8b1cce6defbc5dccee5 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 16 Aug 2021 14:44:00 -0500 Subject: [PATCH 568/608] Update changelog for 0.36.0 (#966) --- README.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.rst b/README.rst index 2b176e29..a766e58a 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,15 @@ See examples directory for more. Changelog ========= +0.36.0 +====== + +Technically backwards incompatible: + +* Fill incomplete IPv6 tuples to avoid WinError on windows (#965) @lokesh2019 + + Fixed #932 + 0.35.1 ====== From e4985c7dd2088d4da9fc2be25f67beb65f548e95 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 16 Aug 2021 14:44:39 -0500 Subject: [PATCH 569/608] =?UTF-8?q?Bump=20version:=200.35.1=20=E2=86=92=20?= =?UTF-8?q?0.36.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 95f97b88..716f4661 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.35.1 +current_version = 0.36.0 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index b90d2a99..22e0af99 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.35.1' +__version__ = '0.36.0' __license__ = 'LGPL' From 574e24125a536dc4fb9a1784797efd495ceb1fdf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 29 Aug 2021 13:08:58 -0500 Subject: [PATCH 570/608] Fix equality and hash for dns records with the unique bit (#969) --- tests/test_dns.py | 15 +++++++++++++++ zeroconf/_dns.py | 14 +++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/test_dns.py b/tests/test_dns.py index fe3efda8..c2669205 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -211,6 +211,21 @@ def test_dns_record_hashablity_does_not_consider_ttl(): assert len(record_set) == 1 +def test_dns_record_hashablity_does_not_consider_unique(): + """Test DNSRecord are hashable and unique is ignored.""" + + # Verify the unique value is not considered in the hash + record1 = r.DNSAddress( + 'irrelevant', const._TYPE_A, const._CLASS_IN | const._CLASS_UNIQUE, const._DNS_OTHER_TTL, b'same' + ) + record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same') + + assert record1.class_ == record2.class_ + assert record1.__hash__() == record2.__hash__() + record_set = {record1, record2} + assert len(record_set) == 1 + + def test_dns_address_record_hashablity(): """Test DNSAddress are hashable.""" address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a') diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index a9bc7d77..0d09a421 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -115,7 +115,7 @@ class DNSQuestion(DNSEntry): def __init__(self, name: str, type_: int, class_: int) -> None: super().__init__(name, type_, class_) - self._hash = hash((self.key, type_, class_)) + self._hash = hash((self.key, type_, self.class_)) def answered_by(self, rec: 'DNSRecord') -> bool: """Returns true if the question is answered by the record""" @@ -247,7 +247,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.address = address self.scope_id = scope_id - self._hash = hash((self.key, type_, class_, address, scope_id)) + self._hash = hash((self.key, type_, self.class_, address, scope_id)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -290,7 +290,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.cpu = cpu self.os = os - self._hash = hash((self.key, type_, class_, cpu, os)) + self._hash = hash((self.key, type_, self.class_, cpu, os)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -326,7 +326,7 @@ def __init__( ) -> None: super().__init__(name, type_, class_, ttl, created) self.alias = alias - self._hash = hash((self.key, type_, class_, alias)) + self._hash = hash((self.key, type_, self.class_, alias)) @property def max_size_compressed(self) -> int: @@ -367,7 +367,7 @@ def __init__( assert isinstance(text, (bytes, type(None))) super().__init__(name, type_, class_, ttl, created) self.text = text - self._hash = hash((self.key, type_, class_, text)) + self._hash = hash((self.key, type_, self.class_, text)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -411,7 +411,7 @@ def __init__( self.weight = weight self.port = port self.server = server - self._hash = hash((self.key, type_, class_, priority, weight, port, server)) + self._hash = hash((self.key, type_, self.class_, priority, weight, port, server)) def write(self, out: 'DNSOutgoing') -> None: """Used in constructing an outgoing packet""" @@ -459,7 +459,7 @@ def __init__( super().__init__(name, type_, class_, ttl, created) self.next_name = next_name self.rdtypes = rdtypes - self._hash = hash((self.key, type_, class_, next_name, *self.rdtypes)) + self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) def __eq__(self, other: Any) -> bool: """Tests equality on cpu and os""" From d9d3208eed84b71b61c458f2992b08b5db259da1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 29 Aug 2021 13:19:13 -0500 Subject: [PATCH 571/608] Skip goodbye packets for addresses when there is another service registered with the same name (#968) --- tests/services/test_registry.py | 23 ++++++ tests/test_asyncio.py | 134 ++++++++++++++++++++++++++++++++ zeroconf/_core.py | 39 ++++++++-- 3 files changed, 188 insertions(+), 8 deletions(-) diff --git a/tests/services/test_registry.py b/tests/services/test_registry.py index 87c048d5..3c105cbb 100644 --- a/tests/services/test_registry.py +++ b/tests/services/test_registry.py @@ -28,6 +28,29 @@ def test_only_register_once(self): registry.async_remove(info) registry.async_add(info) + def test_register_same_server(self): + type_ = "_test-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + registration_name = "%s.%s" % (name, type_) + registration_name2 = "%s.%s" % (name2, type_) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, registration_name, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + info2 = ServiceInfo( + type_, registration_name2, 80, 0, 0, desc, "same.local.", addresses=[socket.inet_aton("10.0.1.2")] + ) + registry = r.ServiceRegistry() + registry.async_add(info) + registry.async_add(info2) + assert registry.async_get_infos_server("same.local.") == [info, info2] + registry.async_remove(info) + assert registry.async_get_infos_server("same.local.") == [info2] + registry.async_remove(info2) + assert registry.async_get_infos_server("same.local.") == [] + def test_unregister_multiple_times(self): """Verify we can unregister a service multiple times. diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index e6da20a6..ea80d6f5 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -173,6 +173,140 @@ def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: ] +@pytest.mark.asyncio +async def test_async_service_registration_same_server_different_ports() -> None: + """Test registering services with the same server with different srv records.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 81, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + task = await aiozc.async_unregister_service(info) + await task + entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.") + assert len(entries) == 1 + assert info2.dns_service() in entries + await aiozc.async_close() + assert calls == [ + ('add', type_, registration_name), + ('add', type_, registration_name2), + ('remove', type_, registration_name), + ('remove', type_, registration_name2), + ] + + +@pytest.mark.asyncio +async def test_async_service_registration_same_server_same_ports() -> None: + """Test registering services with the same server with the exact same srv record.""" + aiozc = AsyncZeroconf(interfaces=['127.0.0.1']) + type_ = "_test1-srvc-type._tcp.local." + name = "xxxyyy" + name2 = "xxxyyy2" + + registration_name = f"{name}.{type_}" + registration_name2 = f"{name2}.{type_}" + + calls = [] + + class MyListener(ServiceListener): + def add_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("add", type, name)) + + def remove_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("remove", type, name)) + + def update_service(self, zeroconf: Zeroconf, type: str, name: str) -> None: + calls.append(("update", type, name)) + + listener = MyListener() + + aiozc.zeroconf.add_service_listener(type_, listener) + + desc = {'path': '/~paulsm/'} + info = ServiceInfo( + type_, + registration_name, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + info2 = ServiceInfo( + type_, + registration_name2, + 80, + 0, + 0, + desc, + "ash-2.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + tasks = [] + tasks.append(await aiozc.async_register_service(info)) + tasks.append(await aiozc.async_register_service(info2)) + await asyncio.gather(*tasks) + + task = await aiozc.async_unregister_service(info) + await task + entries = aiozc.zeroconf.cache.async_entries_with_server("ash-2.local.") + assert len(entries) == 1 + assert info2.dns_service() in entries + await aiozc.async_close() + assert calls == [ + ('add', type_, registration_name), + ('add', type_, registration_name2), + ('remove', type_, registration_name), + ('remove', type_, registration_name2), + ] + + @pytest.mark.asyncio async def test_async_service_registration_name_conflict() -> None: """Test registering services throws on name conflict.""" diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 96b1a790..c9c4c5ec 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -571,17 +571,28 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: self.registry.async_update(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) - async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None: + async def _async_broadcast_service( + self, + info: ServiceInfo, + interval: int, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> None: """Send a broadcasts to announce a service at intervals.""" for i in range(_REGISTER_BROADCASTS): if i != 0: await asyncio.sleep(millis_to_seconds(interval)) - self.async_send(self.generate_service_broadcast(info, ttl)) + self.async_send(self.generate_service_broadcast(info, ttl, broadcast_addresses)) - def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing: + def generate_service_broadcast( + self, + info: ServiceInfo, + ttl: Optional[int], + broadcast_addresses: bool = True, + ) -> DNSOutgoing: """Generate a broadcast to announce a service.""" out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA) - self._add_broadcast_answer(out, info, ttl) + self._add_broadcast_answer(out, info, ttl, broadcast_addresses) return out def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use @@ -600,7 +611,11 @@ def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: d return out def _add_broadcast_answer( # pylint: disable=no-self-use - self, out: DNSOutgoing, info: ServiceInfo, override_ttl: Optional[int] + self, + out: DNSOutgoing, + info: ServiceInfo, + override_ttl: Optional[int], + broadcast_addresses: bool = True, ) -> None: """Add answers to broadcast a service.""" now = current_time_millis() @@ -609,8 +624,9 @@ def _add_broadcast_answer( # pylint: disable=no-self-use out.add_answer_at_time(info.dns_pointer(override_ttl=other_ttl, created=now), 0) out.add_answer_at_time(info.dns_service(override_ttl=host_ttl, created=now), 0) out.add_answer_at_time(info.dns_text(override_ttl=other_ttl, created=now), 0) - for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): - out.add_answer_at_time(dns_address, 0) + if broadcast_addresses: + for dns_address in info.dns_addresses(override_ttl=host_ttl, created=now): + out.add_answer_at_time(dns_address, 0) def unregister_service(self, info: ServiceInfo) -> None: """Unregister a service.""" @@ -622,7 +638,14 @@ def unregister_service(self, info: ServiceInfo) -> None: async def async_unregister_service(self, info: ServiceInfo) -> Awaitable: """Unregister a service.""" self.registry.async_remove(info) - return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0)) + # If another server uses the same addresses, we do not want to send + # goodbye packets for the address records + + entries = self.registry.async_get_infos_server(info.server) + broadcast_addresses = not bool(entries) + return asyncio.ensure_future( + self._async_broadcast_service(info, _UNREGISTER_TIME, 0, broadcast_addresses) + ) def generate_unregister_all_services(self) -> Optional[DNSOutgoing]: """Generate a DNSOutgoing goodbye for all services and remove them from the registry.""" From d5043337de39a11b2b241e9247a34c41c0c7c2bc Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 29 Aug 2021 13:29:19 -0500 Subject: [PATCH 572/608] Update changelog for 0.36.1 (#970) --- README.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.rst b/README.rst index a766e58a..f2c156a2 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,20 @@ See examples directory for more. Changelog ========= +0.36.1 +====== + +* Skip goodbye packets for addresses when there is another service registered with the same name (#968) @bdraco + + If a ServiceInfo that used the same server name as another ServiceInfo + was unregistered, goodbye packets would be sent for the addresses and + would cause the other service to be seen as offline. +* Fixed equality and hash for dns records with the unique bit (#969) @bdraco + + These records should have the same hash and equality since + the unique bit (cache flush bit) is not considered when adding or removing + the records from the cache. + 0.36.0 ====== From e8d84017b750ab5f159abc7225f9922d84a8f9fd Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 29 Aug 2021 13:41:54 -0500 Subject: [PATCH 573/608] =?UTF-8?q?Bump=20version:=200.36.0=20=E2=86=92=20?= =?UTF-8?q?0.36.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 716f4661..92cc1627 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.0 +current_version = 0.36.1 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 22e0af99..5fe885f7 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.0' +__version__ = '0.36.1' __license__ = 'LGPL' From 768a23c656e3f091ecbecbb6b380b5becbbf9674 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 29 Aug 2021 17:40:12 -0500 Subject: [PATCH 574/608] Add support for writing NSEC records (#971) --- tests/test_protocol.py | 29 +++++++++++++++++++++++++++++ zeroconf/_dns.py | 18 +++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8c2f92c4..6ad3303b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -54,6 +54,35 @@ def test_parse_own_packet_question(self): generated.add_question(r.DNSQuestion("testname.local.", const._TYPE_SRV, const._CLASS_IN)) r.DNSIncoming(generated.packets()[0]) + def test_parse_own_packet_nsec(self): + answer = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV], + ) + + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time(answer, 0) + parsed = r.DNSIncoming(generated.packets()[0]) + assert answer in parsed.answers + + # Types > 255 should be ignored + answer_invalid_types = r.DNSNsec( + 'eufy HomeBase2-2464._hap._tcp.local.', + const._TYPE_NSEC, + const._CLASS_IN | const._CLASS_UNIQUE, + const._DNS_OTHER_TTL, + 'eufy HomeBase2-2464._hap._tcp.local.', + [const._TYPE_TXT, const._TYPE_SRV, 1000], + ) + generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) + generated.add_answer_at_time(answer_invalid_types, 0) + parsed = r.DNSIncoming(generated.packets()[0]) + assert answer in parsed.answers + def test_parse_own_packet_response(self): generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE) generated.add_answer_at_time( diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 0d09a421..bb447b2f 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -458,9 +458,25 @@ def __init__( ) -> None: super().__init__(name, type_, class_, ttl, created) self.next_name = next_name - self.rdtypes = rdtypes + self.rdtypes = sorted(rdtypes) self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes)) + def write(self, out: 'DNSOutgoing') -> None: + """Used in constructing an outgoing packet.""" + bitmap = bytearray(b'\0' * 32) + for rdtype in self.rdtypes: + if rdtype > 255: # mDNS only supports window 0 + continue + offset = rdtype % 256 + byte = offset // 8 + total_octets = byte + 1 + bitmap[byte] |= 0x80 >> (offset % 8) + out_bytes = bytes(bitmap[0:total_octets]) + out.write_name(self.next_name) + out.write_short(0) + out.write_short(len(out_bytes)) + out.write_string(out_bytes) + def __eq__(self, other: Any) -> bool: """Tests equality on cpu and os""" return ( From 7a20fd3bc8dc0a703619ca9413faf674b3d7a111 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Aug 2021 09:50:16 -0500 Subject: [PATCH 575/608] Include NSEC records for non-existant types when responding with addresses (#972) Implements datatracker.ietf.org/doc/html/rfc6762#section-6.2 --- tests/test_handlers.py | 49 +++++++++++++++++++++++++++++++------- zeroconf/_handlers.py | 54 +++++++++++++++++++++++++++++++++--------- 2 files changed, 83 insertions(+), 20 deletions(-) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index a621f037..44ee1d5a 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -108,8 +108,9 @@ def _process_outgoing_packet(out): _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) # The additonals should all be suppresed since they are all in the answers section + # There will be one NSEC additional to indicate the lack of AAAA record # - assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -143,7 +144,9 @@ def _process_outgoing_packet(out): [r.DNSIncoming(packet) for packet in query.packets()], False ) _process_outgoing_packet(construct_outgoing_multicast_answers(question_answers.mcast_aggregate)) - assert nbr_answers == 4 and nbr_additionals == 0 and nbr_authorities == 0 + + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 4 and nbr_additionals == 1 and nbr_authorities == 0 nbr_answers = nbr_additionals = nbr_authorities = 0 # unregister @@ -271,7 +274,9 @@ def test_ptr_optimization(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 + assert nbr_answers == 1 and nbr_additionals == 4 + # There will be one NSEC additional to indicate the lack of AAAA record + assert has_srv and has_txt and has_a # unregister @@ -406,7 +411,7 @@ def test_unicast_response(): [r.DNSIncoming(packet) for packet in query.packets()], True ) for answers in (question_answers.ucast, question_answers.mcast_aggregate): - has_srv = has_txt = has_a = False + has_srv = has_txt = has_a = has_aaaa = has_nsec = False nbr_additionals = 0 nbr_answers = len(answers) additionals = set().union(*answers.values()) @@ -418,8 +423,14 @@ def test_unicast_response(): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 - assert has_srv and has_txt and has_a + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + # There will be one NSEC additional to indicate the lack of AAAA record + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa # unregister zc.registry.async_remove(info) @@ -497,7 +508,7 @@ def test_qu_response(): zc.register_service(info) def _validate_complete_response(answers): - has_srv = has_txt = has_a = False + has_srv = has_txt = has_a = has_aaaa = has_nsec = False nbr_answers = len(answers.keys()) additionals = set().union(*answers.values()) nbr_additionals = len(additionals) @@ -509,8 +520,13 @@ def _validate_complete_response(answers): has_txt = True elif answer.type == const._TYPE_A: has_a = True - assert nbr_answers == 1 and nbr_additionals == 3 - assert has_srv and has_txt and has_a + elif answer.type == const._TYPE_AAAA: + has_aaaa = True + elif answer.type == const._TYPE_NSEC: + has_nsec = True + assert nbr_answers == 1 and nbr_additionals == 4 + assert has_srv and has_txt and has_a and has_nsec + assert not has_aaaa # With QU should respond to only unicast when the answer has been recently multicast query = r.DNSOutgoing(const._FLAGS_QR_QUERY) @@ -635,6 +651,21 @@ def test_known_answer_supression(): assert not question_answers.mcast_aggregate assert not question_answers.mcast_aggregate_last_second + # Test NSEC record returned when there is no AAAA record and we expectly ask + generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) + question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN) + generated.add_question(question) + for dns_address in info.dns_addresses(): + generated.add_answer_at_time(dns_address, now) + packets = generated.packets() + question_answers = zc.query_handler.async_response([r.DNSIncoming(packet) for packet in packets], False) + assert not question_answers.ucast + expected_nsec_record: r.DNSNsec = list(question_answers.mcast_now)[0] + assert const._TYPE_A not in expected_nsec_record.rdtypes + assert const._TYPE_AAAA in expected_nsec_record.rdtypes + assert not question_answers.mcast_aggregate + assert not question_answers.mcast_aggregate_last_second + # Test SRV supression generated = r.DNSOutgoing(const._FLAGS_QR_QUERY) question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 06ed54cd..76ba6cc3 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -26,15 +26,17 @@ from typing import Dict, Iterable, List, NamedTuple, Optional, Set, TYPE_CHECKING, Tuple, Union, cast from ._cache import DNSCache, _UniqueRecordsType -from ._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord +from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._history import QuestionHistory from ._logger import log from ._protocol import DNSIncoming, DNSOutgoing +from ._services.info import ServiceInfo from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener from ._utils.time import current_time_millis, millis_to_seconds from .const import ( _CLASS_IN, + _CLASS_UNIQUE, _DNS_OTHER_TTL, _DNS_PTR_MIN_TTL, _FLAGS_AA, @@ -44,6 +46,7 @@ _TYPE_A, _TYPE_AAAA, _TYPE_ANY, + _TYPE_NSEC, _TYPE_PTR, _TYPE_SRV, _TYPE_TXT, @@ -56,7 +59,8 @@ _AnswerWithAdditionalsType = Dict[DNSRecord, Set[DNSRecord]] _MULTICAST_DELAY_RANDOM_INTERVAL = (20, 120) -_RESPOND_IMMEDIATE_TYPES = {_TYPE_SRV, _TYPE_A, _TYPE_AAAA} +_ADDRESS_RECORD_TYPES = {_TYPE_A, _TYPE_AAAA} +_RESPOND_IMMEDIATE_TYPES = {_TYPE_NSEC, _TYPE_SRV, *_ADDRESS_RECORD_TYPES} class QuestionAnswers(NamedTuple): @@ -78,6 +82,15 @@ def _message_is_probe(msg: DNSIncoming) -> bool: return msg.num_authorities > 0 +def construct_nsec_record(name: str, types: List[int], now: float) -> DNSNsec: + """Construct an NSEC record for name and a list of dns types. + + This function should only be used for SRV/A/AAAA records + which have a TTL of _DNS_OTHER_TTL + """ + return DNSNsec(name, _TYPE_NSEC, _CLASS_IN | _CLASS_UNIQUE, _DNS_OTHER_TTL, name, types, created=now) + + def construct_outgoing_multicast_answers(answers: _AnswerWithAdditionalsType) -> DNSOutgoing: """Add answers and additionals to a DNSOutgoing.""" out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA, multicast=True) @@ -244,12 +257,23 @@ def _add_pointer_answers( # Add recommended additional answers according to # https://tools.ietf.org/html/rfc6763#section-12.1. dns_pointer = service.dns_pointer(created=now) - if not known_answers.suppresses(dns_pointer): - answer_set[dns_pointer] = { - service.dns_service(created=now), - service.dns_text(created=now), - *service.dns_addresses(created=now), - } + if known_answers.suppresses(dns_pointer): + continue + additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)} + additionals |= self._get_address_and_nsec_records(service, now) + answer_set[dns_pointer] = additionals + + def _get_address_and_nsec_records(self, service: ServiceInfo, now: float) -> Set[DNSRecord]: + """Build a set of address records and NSEC records for non-present record types.""" + seen_types: Set[int] = set() + records: Set[DNSRecord] = set() + for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) + records.add(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if missing_types: + records.add(construct_nsec_record(service.server, list(missing_types), now)) + return records def _add_address_answers( self, @@ -263,13 +287,21 @@ def _add_address_answers( for service in self.registry.async_get_infos_server(name): answers: List[DNSAddress] = [] additionals: Set[DNSRecord] = set() + seen_types: Set[int] = set() for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) if dns_address.type != type_: additionals.add(dns_address) elif not known_answers.suppresses(dns_address): answers.append(dns_address) - for answer in answers: - answer_set[answer] = additionals + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if answers: + if missing_types: + additionals.add(construct_nsec_record(service.server, list(missing_types), now)) + for answer in answers: + answer_set[answer] = additionals + elif type_ in missing_types: + answer_set[construct_nsec_record(service.server, list(missing_types), now)] = set() def _answer_question( self, @@ -299,7 +331,7 @@ def _answer_question( # https://tools.ietf.org/html/rfc6763#section-12.2. dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = set(service.dns_addresses(created=now)) + answer_set[dns_service] = self._get_address_and_nsec_records(service, now) if type_ in (_TYPE_TXT, _TYPE_ANY): dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text): From b4efa33b4ef6d5292d8d477da4258d99d22c4e84 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Aug 2021 10:04:09 -0500 Subject: [PATCH 576/608] Update changelog for 0.36.2 (#973) --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index f2c156a2..2b7cbd5e 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,12 @@ See examples directory for more. Changelog ========= +0.36.2 +====== + +* Include NSEC records for non-existent types when responding with addresses (#972) (#971) @bdraco + Implements RFC6762 sec 6.2 (http://datatracker.ietf.org/doc/html/rfc6762#section-6.2) + 0.36.1 ====== From 5f52438f4c0851bb1a3b78575c0c28e0b6ce561d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 30 Aug 2021 10:04:19 -0500 Subject: [PATCH 577/608] =?UTF-8?q?Bump=20version:=200.36.1=20=E2=86=92=20?= =?UTF-8?q?0.36.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 92cc1627..2d2433b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.1 +current_version = 0.36.2 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5fe885f7..f0fce54a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.1' +__version__ = '0.36.2' __license__ = 'LGPL' From 78f9cd5123d0e3c582aba05bd61388419d4dc01e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 14 Sep 2021 10:11:49 -1000 Subject: [PATCH 578/608] Reduce DNSIncoming parsing overhead (#975) - Parsing incoming packets is the most expensive operation zeroconf performs on networks with high mDNS volume --- zeroconf/__init__.py | 1 - zeroconf/_protocol.py | 26 +++++++++++--------------- zeroconf/_services/browser.py | 6 ++---- zeroconf/_services/info.py | 3 +-- zeroconf/_utils/struct.py | 25 ------------------------- 5 files changed, 14 insertions(+), 47 deletions(-) delete mode 100644 zeroconf/_utils/struct.py diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index f0fce54a..347bbdd3 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -75,7 +75,6 @@ IPVersion, get_all_addresses, ) -from ._utils.struct import int2byte # noqa # import needed for backwards compat from ._utils.time import current_time_millis, millis_to_seconds # noqa # import needed for backwards compat __author__ = 'Paul Scott-Murphy, William McBrine' diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index b87e67d8..15c7533f 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -28,7 +28,6 @@ from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText from ._exceptions import IncomingDecodeError, NamePartTooLongException from ._logger import QuietLogger, log -from ._utils.struct import int2byte from ._utils.time import current_time_millis from .const import ( _CLASS_UNIQUE, @@ -64,6 +63,8 @@ class DNSMessage: """A base class for DNS messages.""" + __slots__ = ('flags',) + def __init__(self, flags: int) -> None: """Construct a DNS message.""" self.flags = flags @@ -128,11 +129,9 @@ def __repr__(self) -> str: ] ) - def unpack(self, format_: bytes) -> tuple: - length = struct.calcsize(format_) - info = struct.unpack(format_, self.data[self.offset : self.offset + length]) + def unpack(self, format_: bytes, length: int) -> tuple: self.offset += length - return info + return struct.unpack(format_, self.data[self.offset - length : self.offset]) def read_header(self) -> None: """Reads header portion of packet""" @@ -143,16 +142,14 @@ def read_header(self) -> None: self.num_answers, self.num_authorities, self.num_additionals, - ) = self.unpack(b'!6H') + ) = self.unpack(b'!6H', 12) def read_questions(self) -> None: """Reads questions section of packet""" for _ in range(self.num_questions): name = self.read_name() - type_, class_ = self.unpack(b'!HH') - - question = DNSQuestion(name, type_, class_) - self.questions.append(question) + type_, class_ = self.unpack(b'!HH', 4) + self.questions.append(DNSQuestion(name, type_, class_)) def read_character_string(self) -> bytes: """Reads a character string from the packet""" @@ -168,7 +165,7 @@ def read_string(self, length: int) -> bytes: def read_unsigned_short(self) -> int: """Reads an unsigned short from the packet""" - return cast(int, self.unpack(b'!H')[0]) + return cast(int, self.unpack(b'!H', 2)[0]) def read_others(self) -> None: """Reads the answers, authorities and additionals section of the @@ -176,7 +173,7 @@ def read_others(self) -> None: n = self.num_answers + self.num_authorities + self.num_additionals for _ in range(n): domain = self.read_name() - type_, class_, ttl, length = self.unpack(b'!HHiH') + type_, class_, ttl, length = self.unpack(b'!HHiH', 10) end = self.offset + length rec = None try: @@ -266,8 +263,7 @@ def read_name(self) -> str: labels: List[str] = [] self.seen_pointers.clear() self.offset = self._decode_labels_at_offset(self.offset, labels) - labels.append("") - name = ".".join(labels) + name = ".".join(labels) + "." if len(name) > MAX_NAME_LENGTH: raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") return name @@ -440,7 +436,7 @@ def _pack(self, format_: Union[bytes, str], value: Any) -> None: def _write_byte(self, value: int) -> None: """Writes a single byte to the packet""" - self._pack(b'!c', int2byte(value)) + self._pack(b'!c', bytes((value,))) def _insert_short_at_start(self, value: int) -> None: """Inserts an unsigned short at the start of the packet""" diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index aadbd7ac..d47e42e9 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -352,21 +352,19 @@ def _async_process_record_update( self, now: float, record: DNSRecord, old_record: Optional[DNSRecord] ) -> None: """Process a single record update from a batch of updates.""" - expired = record.is_expired(now) - if isinstance(record, DNSPointer): if record.name not in self.types: return if old_record is None: self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias) - elif expired: + elif record.is_expired(now): self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias) else: self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)) return # If its expired or already exists in the cache it cannot be updated. - if expired or old_record: + if old_record or record.is_expired(now): return if isinstance(record, DNSAddress): diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 33c0488a..7aaea1b6 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -36,7 +36,6 @@ _encode_address, _is_v6_address, ) -from .._utils.struct import int2byte from .._utils.time import current_time_millis from ..const import ( _CLASS_IN, @@ -239,7 +238,7 @@ def _set_properties(self, properties: Dict) -> None: record += b'=' + value list_.append(record) for item in list_: - result = b''.join((result, int2byte(len(item)), item)) + result = b''.join((result, bytes((len(item),)), item)) self.text = result def _set_text(self, text: bytes) -> None: diff --git a/zeroconf/_utils/struct.py b/zeroconf/_utils/struct.py deleted file mode 100644 index 6ec99988..00000000 --- a/zeroconf/_utils/struct.py +++ /dev/null @@ -1,25 +0,0 @@ -""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine - Copyright 2003 Paul Scott-Murphy, 2014 William McBrine - - This module provides a framework for the use of DNS Service Discovery - using IP multicast. - - This library is free software; you can redistribute it and/or - modify it under the terms of the GNU Lesser General Public - License as published by the Free Software Foundation; either - version 2.1 of the License, or (at your option) any later version. - - This library is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - Lesser General Public License for more details. - - You should have received a copy of the GNU Lesser General Public - License along with this library; if not, write to the Free Software - Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 - USA -""" - -import struct - -int2byte = struct.Struct(">B").pack From 84f16bff6df41f1907e060e7bd4ce24d173d51c4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 14 Sep 2021 10:19:44 -1000 Subject: [PATCH 579/608] Update changelog for 0.36.3 (#977) --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index 2b7cbd5e..7c1c2863 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,11 @@ See examples directory for more. Changelog ========= +0.36.3 +====== + +* Improved performance of parsing incoming packets (#975) @bdraco + 0.36.2 ====== From 769b3973835ebc6f5a34e236a01cb2cd935e81de Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 14 Sep 2021 15:20:09 -0500 Subject: [PATCH 580/608] =?UTF-8?q?Bump=20version:=200.36.2=20=E2=86=92=20?= =?UTF-8?q?0.36.3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 2d2433b8..7727170e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.2 +current_version = 0.36.3 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 347bbdd3..c037c420 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.2' +__version__ = '0.36.3' __license__ = 'LGPL' From f1d6fc3f60e685ff63b1a1cb820cfc3ca5268fcb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 10:25:54 -1000 Subject: [PATCH 581/608] Reduce name compression overhead and complexity (#978) --- zeroconf/_protocol.py | 65 +++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 15c7533f..713d5d91 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -430,13 +430,13 @@ def add_question_or_all_cache( for cached_entry in cached_entries: self.add_answer_at_time(cached_entry, now) - def _pack(self, format_: Union[bytes, str], value: Any) -> None: + def _pack(self, format_: Union[bytes, str], size: int, value: Any) -> None: self.data.append(struct.pack(format_, value)) - self.size += struct.calcsize(format_) + self.size += size def _write_byte(self, value: int) -> None: """Writes a single byte to the packet""" - self._pack(b'!c', bytes((value,))) + self._pack(b'!c', 1, bytes((value,))) def _insert_short_at_start(self, value: int) -> None: """Inserts an unsigned short at the start of the packet""" @@ -448,11 +448,11 @@ def _replace_short(self, index: int, value: int) -> None: def write_short(self, value: int) -> None: """Writes an unsigned short to the packet""" - self._pack(b'!H', value) + self._pack(b'!H', 2, value) def _write_int(self, value: Union[float, int]) -> None: """Writes an unsigned integer to the packet""" - self._pack(b'!I', int(value)) + self._pack(b'!I', 4, int(value)) def write_string(self, value: bytes) -> None: """Writes a string to the packet""" @@ -491,38 +491,29 @@ def write_name(self, name: str) -> None: """ # split name into each label - parts = name.split('.') - if not parts[-1]: - parts.pop() - - # construct each suffix - name_suffices = ['.'.join(parts[i:]) for i in range(len(parts))] - - # look for an existing name or suffix - for count, sub_name in enumerate(name_suffices): - if sub_name in self.names: - break - else: - count = len(name_suffices) - - # note the new names we are saving into the packet - name_length = len(name.encode('utf-8')) - for suffix in name_suffices[:count]: - self.names[suffix] = self.size + name_length - len(suffix.encode('utf-8')) - 1 - - # write the new names out. - for part in parts[:count]: - self._write_utf(part) - - # if we wrote part of the name, create a pointer to the rest - if count != len(name_suffices): - # Found substring in packet, create pointer - index = self.names[name_suffices[count]] - self._write_byte((index >> 8) | 0xC0) - self._write_byte(index & 0xFF) - else: - # this is the end of a name - self._write_byte(0) + name_length = None + if name.endswith('.'): + name = name[: len(name) - 1] + labels = name.split('.') + # Write each new label or a pointer to the existing + # on in the packet + start_size = self.size + for count in range(len(labels)): + label = name if count == 0 else '.'.join(labels[count:]) + index = self.names.get(label) + if index: + # If part of the name already exists in the packet, + # create a pointer to it + self._write_byte((index >> 8) | 0xC0) + self._write_byte(index & 0xFF) + return + if name_length is None: + name_length = len(name.encode('utf-8')) + self.names[label] = start_size + name_length - len(label.encode('utf-8')) + self._write_utf(labels[count]) + + # this is the end of a name + self._write_byte(0) def _write_question(self, question: DNSQuestion) -> bool: """Writes a question to the packet""" From d9ea9189def07531d126e01c7397b2596d9a8695 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 10:36:18 -1000 Subject: [PATCH 582/608] Force CI cache clear (#982) --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 378970c4..3686d617 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +# version: 1.0 + .PHONY: all virtualenv MAX_LINE_LENGTH=110 PYTHON_IMPLEMENTATION:=$(shell python -c "import sys;import platform;sys.stdout.write(platform.python_implementation())") From acf6457b3c6742c92e9112b0a39a387b33cea4db Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 10:50:46 -1000 Subject: [PATCH 583/608] Reduce duplicate code to write records (#979) --- zeroconf/_protocol.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 713d5d91..43af3991 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -22,7 +22,7 @@ import enum import struct -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union, cast from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText @@ -586,21 +586,13 @@ def _write_answers_from_offset(self, answer_offset: int) -> int: answers_written += 1 return answers_written - def _write_authorities_from_offset(self, authority_offset: int) -> int: - authorities_written = 0 - for authority in self.authorities[authority_offset:]: - if not self._write_record(authority, 0): + def _write_records_from_offset(self, records: Sequence[DNSRecord], offset: int) -> int: + records_written = 0 + for record in records[offset:]: + if not self._write_record(record, 0): break - authorities_written += 1 - return authorities_written - - def _write_additionals_from_offset(self, additional_offset: int) -> int: - additionals_written = 0 - for additional in self.additionals[additional_offset:]: - if not self._write_record(additional, 0): - break - additionals_written += 1 - return additionals_written + records_written += 1 + return records_written def _has_more_to_add( self, questions_offset: int, answer_offset: int, authority_offset: int, additional_offset: int @@ -654,8 +646,8 @@ def packets(self) -> List[bytes]: questions_written = self._write_questions_from_offset(questions_offset) answers_written = self._write_answers_from_offset(answer_offset) - authorities_written = self._write_authorities_from_offset(authority_offset) - additionals_written = self._write_additionals_from_offset(additional_offset) + authorities_written = self._write_records_from_offset(self.authorities, authority_offset) + additionals_written = self._write_records_from_offset(self.additionals, additional_offset) self._insert_short_at_start(additionals_written) self._insert_short_at_start(authorities_written) From bc64d63ef73e643e71634957fd79e6f6597373d4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 11:02:58 -1000 Subject: [PATCH 584/608] Remove flake8 requirement restriction as its no longer needed (#981) --- requirements-dev.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 3035d59d..e7483666 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,8 +3,7 @@ black;implementation_name=="cpython" bump2version coveralls coverage -# Version restricted because of https://github.com/PyCQA/pycodestyle/issues/741 - is fixed -flake8>=3.6.0 +flake8 flake8-import-order ifaddr mypy;implementation_name=="cpython" From 05c4329d7647c381783ead086c2ed4f3b6b44262 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 13:22:13 -1000 Subject: [PATCH 585/608] Collapse _GLOBAL_DONE into done (#984) --- zeroconf/_core.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index c9c4c5ec..f4161877 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -400,8 +400,7 @@ def __init__( if ip_version is None: ip_version = autodetect_ip_version(interfaces) - # hook for threads - self._GLOBAL_DONE = False + self.done = False if apple_p2p and sys.platform != 'darwin': raise RuntimeError('Option `apple_p2p` is not supported on non-Apple platforms.') @@ -456,10 +455,6 @@ async def async_wait_for_start(self) -> None: """Wait for start up.""" await self.engine.async_wait_for_start() - @property - def done(self) -> bool: - return self._GLOBAL_DONE - @property def listeners(self) -> List[RecordUpdateListener]: return self.record_manager.listeners @@ -815,7 +810,7 @@ def async_send( transport: Optional[asyncio.DatagramTransport] = None, ) -> None: """Sends an outgoing packet.""" - if self._GLOBAL_DONE: + if self.done: return # If no transport is specified, we send to all the ones @@ -866,10 +861,10 @@ def _async_send_transport( def _close(self) -> None: """Set global done and remove all service listeners.""" - if self._GLOBAL_DONE: + if self.done: return self.remove_all_service_listeners() - self._GLOBAL_DONE = True + self.done = True def _shutdown_threads(self) -> None: """Shutdown any threads.""" From 88b987551cb98757c2df2540ba390f320d46fa7b Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 13:23:43 -1000 Subject: [PATCH 586/608] Defer decoding known answers until needed (#983) --- zeroconf/_handlers.py | 2 +- zeroconf/_protocol.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 76ba6cc3..22848409 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -348,7 +348,7 @@ def async_response( # pylint: disable=unused-argument threadsafe. """ known_answers = DNSRRSet( - itertools.chain(*(msg.answers for msg in msgs if not _message_is_probe(msg))) + itertools.chain.from_iterable(msg.answers for msg in msgs if not _message_is_probe(msg)) ) query_res = _QueryResponse(self.cache, msgs) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index 43af3991..daff1ca0 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -22,7 +22,7 @@ import enum import struct -from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union, cast from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText @@ -96,23 +96,39 @@ def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[fl self.name_cache: Dict[int, List[str]] = {} self.seen_pointers: Set[int] = set() self.questions: List[DNSQuestion] = [] - self.answers: List[DNSRecord] = [] + self._answers: List[DNSRecord] = [] self.id = 0 self.num_questions = 0 self.num_answers = 0 self.num_authorities = 0 self.num_additionals = 0 self.valid = False + self._read_others = False self.now = now or current_time_millis() self.scope_id = scope_id + self._parse_data(self._initial_parse) - try: - self.read_header() - self.read_questions() + def _initial_parse(self) -> None: + """Parse the data needed to initalize the packet object.""" + self.read_header() + self.read_questions() + if not self.num_questions: self.read_others() - self.valid = True + self.valid = True + + def _parse_data(self, parser_call: Callable) -> None: + """Parse part of the packet and catch exceptions.""" + try: + parser_call() except DECODE_EXCEPTIONS: - self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, data) + self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, self.data) + + @property + def answers(self) -> List[DNSRecord]: + """Answers in the packet.""" + if not self._read_others: + self._parse_data(self.read_others) + return self._answers def __repr__(self) -> str: return '' % ', '.join( @@ -170,6 +186,7 @@ def read_unsigned_short(self) -> int: def read_others(self) -> None: """Reads the answers, authorities and additionals section of the packet""" + self._read_others = True n = self.num_answers + self.num_authorities + self.num_additionals for _ in range(n): domain = self.read_name() @@ -192,7 +209,7 @@ def read_others(self) -> None: exc_info=True, ) if rec is not None: - self.answers.append(rec) + self._answers.append(rec) def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: """Read known records types and skip unknown ones.""" From f4d4164989931adbac0e5907b7bf276da1d0d7d7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 13:50:07 -1000 Subject: [PATCH 587/608] Update changelog for 0.36.4 (#985) --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index 7c1c2863..91192344 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,12 @@ See examples directory for more. Changelog ========= +0.36.4 +====== + +* Improved performance of constructing outgoing packets (#978) (#979) @bdraco +* Defered parsing of incoming packets when it can be avoided (#983) @bdraco + 0.36.3 ====== From a23f6d2cc40ea696410c3c31b73760065c36f0bf Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 18:50:24 -0500 Subject: [PATCH 588/608] =?UTF-8?q?Bump=20version:=200.36.3=20=E2=86=92=20?= =?UTF-8?q?0.36.4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7727170e..b467a89c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.3 +current_version = 0.36.4 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index c037c420..e8c03c64 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.3' +__version__ = '0.36.4' __license__ = 'LGPL' From 43985380b9e995d9790d71486aed258326ad86e4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 16 Sep 2021 13:53:13 -1000 Subject: [PATCH 589/608] Fix typo in changelog (#986) --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 91192344..3c6f36c2 100644 --- a/README.rst +++ b/README.rst @@ -142,7 +142,7 @@ Changelog ====== * Improved performance of constructing outgoing packets (#978) (#979) @bdraco -* Defered parsing of incoming packets when it can be avoided (#983) @bdraco +* Deferred parsing of incoming packets when it can be avoided (#983) @bdraco 0.36.3 ====== From f4665fc67cd762c4ab66271a550d75640d3bffca Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 18 Sep 2021 16:02:52 -1000 Subject: [PATCH 590/608] Reduce dns protocol attributes and add slots (#987) --- zeroconf/_protocol.py | 53 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol.py index daff1ca0..b6ef7bcb 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol.py @@ -87,6 +87,23 @@ class DNSIncoming(DNSMessage, QuietLogger): """Object representation of an incoming DNS packet""" + __slots__ = ( + 'offset', + 'data', + 'data_len', + 'name_cache', + 'questions', + '_answers', + 'id', + 'num_questions', + 'num_answers', + 'num_authorities', + 'num_additionals', + 'valid', + 'now', + 'scope_id', + ) + def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None: """Constructor from string holding bytes of packet""" super().__init__(0) @@ -94,7 +111,6 @@ def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[fl self.data = data self.data_len = len(data) self.name_cache: Dict[int, List[str]] = {} - self.seen_pointers: Set[int] = set() self.questions: List[DNSQuestion] = [] self._answers: List[DNSRecord] = [] self.id = 0 @@ -162,10 +178,9 @@ def read_header(self) -> None: def read_questions(self) -> None: """Reads questions section of packet""" - for _ in range(self.num_questions): - name = self.read_name() - type_, class_ = self.unpack(b'!HH', 4) - self.questions.append(DNSQuestion(name, type_, class_)) + self.questions = [ + DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions) + ] def read_character_string(self) -> bytes: """Reads a character string from the packet""" @@ -278,14 +293,14 @@ def read_bitmap(self, end: int) -> List[int]: def read_name(self) -> str: """Reads a domain name from the packet.""" labels: List[str] = [] - self.seen_pointers.clear() - self.offset = self._decode_labels_at_offset(self.offset, labels) + seen_pointers: Set[int] = set() + self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers) name = ".".join(labels) + "." if len(name) > MAX_NAME_LENGTH: raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") return name - def _decode_labels_at_offset(self, off: int, labels: List[str]) -> int: + def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. while off < self.data_len: length = self.data[off] @@ -307,12 +322,12 @@ def _decode_labels_at_offset(self, off: int, labels: List[str]) -> int: raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet") if link == off: raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself") - if link in self.seen_pointers: + if link in seen_pointers: raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again") - self.seen_pointers.add(link) + seen_pointers.add(link) linked_labels = self.name_cache.get(link, []) if not linked_labels: - self._decode_labels_at_offset(link, linked_labels) + self._decode_labels_at_offset(link, linked_labels, seen_pointers) self.name_cache[link] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: @@ -326,6 +341,22 @@ class DNSOutgoing(DNSMessage): """Object representation of an outgoing packet""" + __slots__ = ( + 'finished', + 'id', + 'multicast', + 'packets_data', + 'names', + 'data', + 'size', + 'allow_long', + 'state', + 'questions', + 'answers', + 'authorities', + 'additionals', + ) + def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None: super().__init__(flags) self.finished = False From 87b6a32fb77d9bdcea9d2d7ffba189abc5371b50 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 18 Sep 2021 16:39:30 -1000 Subject: [PATCH 591/608] Seperate zeroconf._protocol into an incoming and outgoing modules (#988) --- setup.py | 2 +- tests/test_core.py | 7 +- zeroconf/__init__.py | 3 +- zeroconf/_core.py | 3 +- zeroconf/_dns.py | 3 +- zeroconf/_handlers.py | 3 +- zeroconf/_protocol/__init__.py | 51 +++ zeroconf/_protocol/incoming.py | 302 +++++++++++++++++ .../{_protocol.py => _protocol/outgoing.py} | 316 +----------------- zeroconf/_services/browser.py | 2 +- zeroconf/_services/info.py | 2 +- 11 files changed, 377 insertions(+), 317 deletions(-) create mode 100644 zeroconf/_protocol/__init__.py create mode 100644 zeroconf/_protocol/incoming.py rename zeroconf/{_protocol.py => _protocol/outgoing.py} (60%) diff --git a/setup.py b/setup.py index 0ad299fb..41e74842 100755 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ author='Paul Scott-Murphy, William McBrine, Jakub Stasiak', url='https://github.com/jstasiak/python-zeroconf', package_data={"zeroconf": ["py.typed"]}, - packages=["zeroconf", "zeroconf._services", "zeroconf._utils"], + packages=["zeroconf", "zeroconf._protocol", "zeroconf._services", "zeroconf._utils"], platforms=['unix', 'linux', 'osx'], license='LGPL', zip_safe=False, diff --git a/tests/test_core.py b/tests/test_core.py index ba1effac..a5c22065 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -18,8 +18,9 @@ from unittest.mock import patch import zeroconf as r -from zeroconf import _core, _protocol, const, Zeroconf, current_time_millis +from zeroconf import _core, const, Zeroconf, current_time_millis from zeroconf.asyncio import AsyncZeroconf +from zeroconf._protocol import outgoing from . import has_working_ipv6, _clear_cache, _inject_response, _wait_for_start @@ -670,8 +671,8 @@ def test_guard_against_oversized_packets(): ) # We are patching to generate an oversized packet - with patch.object(_protocol, "_MAX_MSG_ABSOLUTE", 100000), patch.object( - _protocol, "_MAX_MSG_TYPICAL", 100000 + with patch.object(outgoing, "_MAX_MSG_ABSOLUTE", 100000), patch.object( + outgoing, "_MAX_MSG_TYPICAL", 100000 ): over_sized_packet = generated.packets()[0] assert len(over_sized_packet) > const._MAX_MSG_ABSOLUTE diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e8c03c64..e6f92d2a 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -46,7 +46,8 @@ NonUniqueNameException, ServiceNameAlreadyRegistered, ) -from ._protocol import DNSIncoming, DNSOutgoing # noqa # import needed for backwards compat +from ._protocol.incoming import DNSIncoming # noqa # import needed for backwards compat +from ._protocol.outgoing import DNSOutgoing # noqa # import needed for backwards compat from ._services import ( # noqa # import needed for backwards compat Signal, SignalRegistrationInterface, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index f4161877..15852ffb 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -41,7 +41,8 @@ ) from ._history import QuestionHistory from ._logger import QuietLogger, log -from ._protocol import DNSIncoming, DNSOutgoing +from ._protocol.incoming import DNSIncoming +from ._protocol.outgoing import DNSOutgoing from ._services import ServiceListener from ._services.browser import ServiceBrowser from ._services.info import ServiceInfo, instance_name_from_service_info diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index bb447b2f..7ef6b6a9 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -48,7 +48,8 @@ if TYPE_CHECKING: - from ._protocol import DNSIncoming, DNSOutgoing + from ._protocol.incoming import DNSIncoming + from ._protocol.outgoing import DNSOutgoing @enum.unique diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 22848409..5215f202 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -29,7 +29,8 @@ from ._dns import DNSAddress, DNSNsec, DNSPointer, DNSQuestion, DNSRRSet, DNSRecord from ._history import QuestionHistory from ._logger import log -from ._protocol import DNSIncoming, DNSOutgoing +from ._protocol.incoming import DNSIncoming +from ._protocol.outgoing import DNSOutgoing from ._services.info import ServiceInfo from ._services.registry import ServiceRegistry from ._updates import RecordUpdate, RecordUpdateListener diff --git a/zeroconf/_protocol/__init__.py b/zeroconf/_protocol/__init__.py new file mode 100644 index 00000000..360b599d --- /dev/null +++ b/zeroconf/_protocol/__init__.py @@ -0,0 +1,51 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +from ..const import ( + _FLAGS_QR_MASK, + _FLAGS_QR_QUERY, + _FLAGS_QR_RESPONSE, + _FLAGS_TC, +) + + +class DNSMessage: + """A base class for DNS messages.""" + + __slots__ = ('flags',) + + def __init__(self, flags: int) -> None: + """Construct a DNS message.""" + self.flags = flags + + def is_query(self) -> bool: + """Returns true if this is a query.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY + + def is_response(self) -> bool: + """Returns true if this is a response.""" + return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE + + @property + def truncated(self) -> bool: + """Returns true if this is a truncated.""" + return (self.flags & _FLAGS_TC) == _FLAGS_TC diff --git a/zeroconf/_protocol/incoming.py b/zeroconf/_protocol/incoming.py new file mode 100644 index 00000000..bffbb4bc --- /dev/null +++ b/zeroconf/_protocol/incoming.py @@ -0,0 +1,302 @@ +""" Multicast DNS Service Discovery for Python, v0.14-wmcbrine + Copyright 2003 Paul Scott-Murphy, 2014 William McBrine + + This module provides a framework for the use of DNS Service Discovery + using IP multicast. + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 + USA +""" + +import struct +from typing import Callable, Dict, List, Optional, Set, cast + +from . import DNSMessage +from .._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText +from .._exceptions import IncomingDecodeError +from .._logger import QuietLogger, log +from .._utils.time import current_time_millis +from ..const import ( + _TYPES, + _TYPE_A, + _TYPE_AAAA, + _TYPE_CNAME, + _TYPE_HINFO, + _TYPE_NSEC, + _TYPE_PTR, + _TYPE_SRV, + _TYPE_TXT, +) + +DNS_COMPRESSION_HEADER_LEN = 1 +DNS_COMPRESSION_POINTER_LEN = 2 +MAX_DNS_LABELS = 128 +MAX_NAME_LENGTH = 253 + +DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) + + +class DNSIncoming(DNSMessage, QuietLogger): + + """Object representation of an incoming DNS packet""" + + __slots__ = ( + 'offset', + 'data', + 'data_len', + 'name_cache', + 'questions', + '_answers', + 'id', + 'num_questions', + 'num_answers', + 'num_authorities', + 'num_additionals', + 'valid', + 'now', + 'scope_id', + ) + + def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None: + """Constructor from string holding bytes of packet""" + super().__init__(0) + self.offset = 0 + self.data = data + self.data_len = len(data) + self.name_cache: Dict[int, List[str]] = {} + self.questions: List[DNSQuestion] = [] + self._answers: List[DNSRecord] = [] + self.id = 0 + self.num_questions = 0 + self.num_answers = 0 + self.num_authorities = 0 + self.num_additionals = 0 + self.valid = False + self._read_others = False + self.now = now or current_time_millis() + self.scope_id = scope_id + self._parse_data(self._initial_parse) + + def _initial_parse(self) -> None: + """Parse the data needed to initalize the packet object.""" + self.read_header() + self.read_questions() + if not self.num_questions: + self.read_others() + self.valid = True + + def _parse_data(self, parser_call: Callable) -> None: + """Parse part of the packet and catch exceptions.""" + try: + parser_call() + except DECODE_EXCEPTIONS: + self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, self.data) + + @property + def answers(self) -> List[DNSRecord]: + """Answers in the packet.""" + if not self._read_others: + self._parse_data(self.read_others) + return self._answers + + def __repr__(self) -> str: + return '' % ', '.join( + [ + 'id=%s' % self.id, + 'flags=%s' % self.flags, + 'truncated=%s' % self.truncated, + 'n_q=%s' % self.num_questions, + 'n_ans=%s' % self.num_answers, + 'n_auth=%s' % self.num_authorities, + 'n_add=%s' % self.num_additionals, + 'questions=%s' % self.questions, + 'answers=%s' % self.answers, + ] + ) + + def unpack(self, format_: bytes, length: int) -> tuple: + self.offset += length + return struct.unpack(format_, self.data[self.offset - length : self.offset]) + + def read_header(self) -> None: + """Reads header portion of packet""" + ( + self.id, + self.flags, + self.num_questions, + self.num_answers, + self.num_authorities, + self.num_additionals, + ) = self.unpack(b'!6H', 12) + + def read_questions(self) -> None: + """Reads questions section of packet""" + self.questions = [ + DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions) + ] + + def read_character_string(self) -> bytes: + """Reads a character string from the packet""" + length = self.data[self.offset] + self.offset += 1 + return self.read_string(length) + + def read_string(self, length: int) -> bytes: + """Reads a string of a given length from the packet""" + info = self.data[self.offset : self.offset + length] + self.offset += length + return info + + def read_unsigned_short(self) -> int: + """Reads an unsigned short from the packet""" + return cast(int, self.unpack(b'!H', 2)[0]) + + def read_others(self) -> None: + """Reads the answers, authorities and additionals section of the + packet""" + self._read_others = True + n = self.num_answers + self.num_authorities + self.num_additionals + for _ in range(n): + domain = self.read_name() + type_, class_, ttl, length = self.unpack(b'!HHiH', 10) + end = self.offset + length + rec = None + try: + rec = self.read_record(domain, type_, class_, ttl, length) + except DECODE_EXCEPTIONS: + # Skip records that fail to decode if we know the length + # If the packet is really corrupt read_name and the unpack + # above would fail and hit the exception catch in read_others + self.offset = end + log.debug( + 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r', + domain, + _TYPES.get(type_, type_), + self.offset, + self.data, + exc_info=True, + ) + if rec is not None: + self._answers.append(rec) + + def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: + """Read known records types and skip unknown ones.""" + if type_ == _TYPE_A: + return DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) + if type_ in (_TYPE_CNAME, _TYPE_PTR): + return DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) + if type_ == _TYPE_TXT: + return DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) + if type_ == _TYPE_SRV: + return DNSService( + domain, + type_, + class_, + ttl, + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_unsigned_short(), + self.read_name(), + self.now, + ) + if type_ == _TYPE_HINFO: + return DNSHinfo( + domain, + type_, + class_, + ttl, + self.read_character_string().decode('utf-8'), + self.read_character_string().decode('utf-8'), + self.now, + ) + if type_ == _TYPE_AAAA: + return DNSAddress( + domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id + ) + if type_ == _TYPE_NSEC: + name_start = self.offset + return DNSNsec( + domain, + type_, + class_, + ttl, + self.read_name(), + self.read_bitmap(name_start + length), + self.now, + ) + # Try to ignore types we don't know about + # Skip the payload for the resource record so the next + # records can be parsed correctly + self.offset += length + return None + + def read_bitmap(self, end: int) -> List[int]: + """Reads an NSEC bitmap from the packet.""" + rdtypes = [] + while self.offset < end: + window = self.data[self.offset] + bitmap_length = self.data[self.offset + 1] + for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]): + for bit in range(0, 8): + if byte & (0x80 >> bit): + rdtypes.append(bit + window * 256 + i * 8) + self.offset += 2 + bitmap_length + return rdtypes + + def read_name(self) -> str: + """Reads a domain name from the packet.""" + labels: List[str] = [] + seen_pointers: Set[int] = set() + self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers) + name = ".".join(labels) + "." + if len(name) > MAX_NAME_LENGTH: + raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") + return name + + def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int: + # This is a tight loop that is called frequently, small optimizations can make a difference. + while off < self.data_len: + length = self.data[off] + if length == 0: + return off + DNS_COMPRESSION_HEADER_LEN + + if length < 0x40: + label_idx = off + DNS_COMPRESSION_HEADER_LEN + labels.append(str(self.data[label_idx : label_idx + length], 'utf-8', 'replace')) + off += DNS_COMPRESSION_HEADER_LEN + length + continue + + if length < 0xC0: + raise IncomingDecodeError(f"DNS compression type {length} is unknown at {off}") + + # We have a DNS compression pointer + link = (length & 0x3F) * 256 + self.data[off + 1] + if link > self.data_len: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet") + if link == off: + raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself") + if link in seen_pointers: + raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again") + seen_pointers.add(link) + linked_labels = self.name_cache.get(link, []) + if not linked_labels: + self._decode_labels_at_offset(link, linked_labels, seen_pointers) + self.name_cache[link] = linked_labels + labels.extend(linked_labels) + if len(labels) > MAX_DNS_LABELS: + raise IncomingDecodeError(f"Maximum dns labels reached while processing pointer at {off}") + return off + DNS_COMPRESSION_POINTER_LEN + + raise IncomingDecodeError("Corrupt packet received while decoding name") diff --git a/zeroconf/_protocol.py b/zeroconf/_protocol/outgoing.py similarity index 60% rename from zeroconf/_protocol.py rename to zeroconf/_protocol/outgoing.py index b6ef7bcb..21ff4b64 100644 --- a/zeroconf/_protocol.py +++ b/zeroconf/_protocol/outgoing.py @@ -22,320 +22,22 @@ import enum import struct -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Tuple, Union, cast - - -from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText -from ._exceptions import IncomingDecodeError, NamePartTooLongException -from ._logger import QuietLogger, log -from ._utils.time import current_time_millis -from .const import ( +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from . import DNSMessage +from .incoming import DNSIncoming +from .._cache import DNSCache +from .._dns import DNSPointer, DNSQuestion, DNSRecord +from .._exceptions import NamePartTooLongException +from .._logger import log +from ..const import ( _CLASS_UNIQUE, _DNS_PACKET_HEADER_LEN, - _FLAGS_QR_MASK, - _FLAGS_QR_QUERY, - _FLAGS_QR_RESPONSE, _FLAGS_TC, _MAX_MSG_ABSOLUTE, _MAX_MSG_TYPICAL, - _TYPES, - _TYPE_A, - _TYPE_AAAA, - _TYPE_CNAME, - _TYPE_HINFO, - _TYPE_NSEC, - _TYPE_PTR, - _TYPE_SRV, - _TYPE_TXT, ) -DNS_COMPRESSION_HEADER_LEN = 1 -DNS_COMPRESSION_POINTER_LEN = 2 -MAX_DNS_LABELS = 128 -MAX_NAME_LENGTH = 253 - -DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) - -if TYPE_CHECKING: - from ._cache import DNSCache - - -class DNSMessage: - """A base class for DNS messages.""" - - __slots__ = ('flags',) - - def __init__(self, flags: int) -> None: - """Construct a DNS message.""" - self.flags = flags - - def is_query(self) -> bool: - """Returns true if this is a query.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_QUERY - - def is_response(self) -> bool: - """Returns true if this is a response.""" - return (self.flags & _FLAGS_QR_MASK) == _FLAGS_QR_RESPONSE - - @property - def truncated(self) -> bool: - """Returns true if this is a truncated.""" - return (self.flags & _FLAGS_TC) == _FLAGS_TC - - -class DNSIncoming(DNSMessage, QuietLogger): - - """Object representation of an incoming DNS packet""" - - __slots__ = ( - 'offset', - 'data', - 'data_len', - 'name_cache', - 'questions', - '_answers', - 'id', - 'num_questions', - 'num_answers', - 'num_authorities', - 'num_additionals', - 'valid', - 'now', - 'scope_id', - ) - - def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None: - """Constructor from string holding bytes of packet""" - super().__init__(0) - self.offset = 0 - self.data = data - self.data_len = len(data) - self.name_cache: Dict[int, List[str]] = {} - self.questions: List[DNSQuestion] = [] - self._answers: List[DNSRecord] = [] - self.id = 0 - self.num_questions = 0 - self.num_answers = 0 - self.num_authorities = 0 - self.num_additionals = 0 - self.valid = False - self._read_others = False - self.now = now or current_time_millis() - self.scope_id = scope_id - self._parse_data(self._initial_parse) - - def _initial_parse(self) -> None: - """Parse the data needed to initalize the packet object.""" - self.read_header() - self.read_questions() - if not self.num_questions: - self.read_others() - self.valid = True - - def _parse_data(self, parser_call: Callable) -> None: - """Parse part of the packet and catch exceptions.""" - try: - parser_call() - except DECODE_EXCEPTIONS: - self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, self.data) - - @property - def answers(self) -> List[DNSRecord]: - """Answers in the packet.""" - if not self._read_others: - self._parse_data(self.read_others) - return self._answers - - def __repr__(self) -> str: - return '' % ', '.join( - [ - 'id=%s' % self.id, - 'flags=%s' % self.flags, - 'truncated=%s' % self.truncated, - 'n_q=%s' % self.num_questions, - 'n_ans=%s' % self.num_answers, - 'n_auth=%s' % self.num_authorities, - 'n_add=%s' % self.num_additionals, - 'questions=%s' % self.questions, - 'answers=%s' % self.answers, - ] - ) - - def unpack(self, format_: bytes, length: int) -> tuple: - self.offset += length - return struct.unpack(format_, self.data[self.offset - length : self.offset]) - - def read_header(self) -> None: - """Reads header portion of packet""" - ( - self.id, - self.flags, - self.num_questions, - self.num_answers, - self.num_authorities, - self.num_additionals, - ) = self.unpack(b'!6H', 12) - - def read_questions(self) -> None: - """Reads questions section of packet""" - self.questions = [ - DNSQuestion(self.read_name(), *self.unpack(b'!HH', 4)) for _ in range(self.num_questions) - ] - - def read_character_string(self) -> bytes: - """Reads a character string from the packet""" - length = self.data[self.offset] - self.offset += 1 - return self.read_string(length) - - def read_string(self, length: int) -> bytes: - """Reads a string of a given length from the packet""" - info = self.data[self.offset : self.offset + length] - self.offset += length - return info - - def read_unsigned_short(self) -> int: - """Reads an unsigned short from the packet""" - return cast(int, self.unpack(b'!H', 2)[0]) - - def read_others(self) -> None: - """Reads the answers, authorities and additionals section of the - packet""" - self._read_others = True - n = self.num_answers + self.num_authorities + self.num_additionals - for _ in range(n): - domain = self.read_name() - type_, class_, ttl, length = self.unpack(b'!HHiH', 10) - end = self.offset + length - rec = None - try: - rec = self.read_record(domain, type_, class_, ttl, length) - except DECODE_EXCEPTIONS: - # Skip records that fail to decode if we know the length - # If the packet is really corrupt read_name and the unpack - # above would fail and hit the exception catch in read_others - self.offset = end - log.debug( - 'Unable to parse; skipping record for %s with type %s at offset %d while unpacking %r', - domain, - _TYPES.get(type_, type_), - self.offset, - self.data, - exc_info=True, - ) - if rec is not None: - self._answers.append(rec) - - def read_record(self, domain: str, type_: int, class_: int, ttl: int, length: int) -> Optional[DNSRecord]: - """Read known records types and skip unknown ones.""" - if type_ == _TYPE_A: - return DNSAddress(domain, type_, class_, ttl, self.read_string(4), created=self.now) - if type_ in (_TYPE_CNAME, _TYPE_PTR): - return DNSPointer(domain, type_, class_, ttl, self.read_name(), self.now) - if type_ == _TYPE_TXT: - return DNSText(domain, type_, class_, ttl, self.read_string(length), self.now) - if type_ == _TYPE_SRV: - return DNSService( - domain, - type_, - class_, - ttl, - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_unsigned_short(), - self.read_name(), - self.now, - ) - if type_ == _TYPE_HINFO: - return DNSHinfo( - domain, - type_, - class_, - ttl, - self.read_character_string().decode('utf-8'), - self.read_character_string().decode('utf-8'), - self.now, - ) - if type_ == _TYPE_AAAA: - return DNSAddress( - domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id - ) - if type_ == _TYPE_NSEC: - name_start = self.offset - return DNSNsec( - domain, - type_, - class_, - ttl, - self.read_name(), - self.read_bitmap(name_start + length), - self.now, - ) - # Try to ignore types we don't know about - # Skip the payload for the resource record so the next - # records can be parsed correctly - self.offset += length - return None - - def read_bitmap(self, end: int) -> List[int]: - """Reads an NSEC bitmap from the packet.""" - rdtypes = [] - while self.offset < end: - window = self.data[self.offset] - bitmap_length = self.data[self.offset + 1] - for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]): - for bit in range(0, 8): - if byte & (0x80 >> bit): - rdtypes.append(bit + window * 256 + i * 8) - self.offset += 2 + bitmap_length - return rdtypes - - def read_name(self) -> str: - """Reads a domain name from the packet.""" - labels: List[str] = [] - seen_pointers: Set[int] = set() - self.offset = self._decode_labels_at_offset(self.offset, labels, seen_pointers) - name = ".".join(labels) + "." - if len(name) > MAX_NAME_LENGTH: - raise IncomingDecodeError(f"DNS name {name} exceeds maximum length of {MAX_NAME_LENGTH}") - return name - - def _decode_labels_at_offset(self, off: int, labels: List[str], seen_pointers: Set[int]) -> int: - # This is a tight loop that is called frequently, small optimizations can make a difference. - while off < self.data_len: - length = self.data[off] - if length == 0: - return off + DNS_COMPRESSION_HEADER_LEN - - if length < 0x40: - label_idx = off + DNS_COMPRESSION_HEADER_LEN - labels.append(str(self.data[label_idx : label_idx + length], 'utf-8', 'replace')) - off += DNS_COMPRESSION_HEADER_LEN + length - continue - - if length < 0xC0: - raise IncomingDecodeError(f"DNS compression type {length} is unknown at {off}") - - # We have a DNS compression pointer - link = (length & 0x3F) * 256 + self.data[off + 1] - if link > self.data_len: - raise IncomingDecodeError(f"DNS compression pointer at {off} points to {link} beyond packet") - if link == off: - raise IncomingDecodeError(f"DNS compression pointer at {off} points to itself") - if link in seen_pointers: - raise IncomingDecodeError(f"DNS compression pointer at {off} was seen again") - seen_pointers.add(link) - linked_labels = self.name_cache.get(link, []) - if not linked_labels: - self._decode_labels_at_offset(link, linked_labels, seen_pointers) - self.name_cache[link] = linked_labels - labels.extend(linked_labels) - if len(labels) > MAX_DNS_LABELS: - raise IncomingDecodeError(f"Maximum dns labels reached while processing pointer at {off}") - return off + DNS_COMPRESSION_POINTER_LEN - - raise IncomingDecodeError("Corrupt packet received while decoding name") - class DNSOutgoing(DNSMessage): diff --git a/zeroconf/_services/browser.py b/zeroconf/_services/browser.py index d47e42e9..f6448fd2 100644 --- a/zeroconf/_services/browser.py +++ b/zeroconf/_services/browser.py @@ -30,7 +30,7 @@ from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord from .._logger import log -from .._protocol import DNSOutgoing +from .._protocol.outgoing import DNSOutgoing from .._services import ( ServiceListener, ServiceStateChange, diff --git a/zeroconf/_services/info.py b/zeroconf/_services/info.py index 7aaea1b6..beaf0678 100644 --- a/zeroconf/_services/info.py +++ b/zeroconf/_services/info.py @@ -27,7 +27,7 @@ from .._dns import DNSAddress, DNSPointer, DNSQuestionType, DNSRecord, DNSService, DNSText from .._exceptions import BadTypeInNameException -from .._protocol import DNSOutgoing +from .._protocol.outgoing import DNSOutgoing from .._updates import RecordUpdate, RecordUpdateListener from .._utils.asyncio import get_running_loop, run_coro_with_timeout from .._utils.name import service_type_name From aebabe95c59e34f703307340e087b3eab5339a06 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 18 Sep 2021 17:04:45 -1000 Subject: [PATCH 592/608] Update changelog for 0.36.5 (#989) --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index 3c6f36c2..6b4ee99c 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,11 @@ See examples directory for more. Changelog ========= +0.36.5 +====== + +* Reduced memory usage for incoming and outgoing packets (#987) @bdraco + 0.36.4 ====== From 34f4a26c9254d6002bdccb1a003d9822a8798c04 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sat, 18 Sep 2021 22:05:46 -0500 Subject: [PATCH 593/608] =?UTF-8?q?Bump=20version:=200.36.4=20=E2=86=92=20?= =?UTF-8?q?0.36.5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index b467a89c..36703f95 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.4 +current_version = 0.36.5 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index e6f92d2a..055d0439 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -80,7 +80,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.4' +__version__ = '0.36.5' __license__ = 'LGPL' From 1887c554b3f9d0b90a1c01798d7f06a7e4de6900 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 19 Sep 2021 12:49:47 -1000 Subject: [PATCH 594/608] Simplify the can_send_to check (#990) --- zeroconf/__init__.py | 1 - zeroconf/_core.py | 9 +++++---- zeroconf/_utils/net.py | 10 +++++++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 055d0439..5c04eb26 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -67,7 +67,6 @@ from ._utils.name import service_type_name # noqa # import needed for backwards compat from ._utils.net import ( # noqa # import needed for backwards compat add_multicast_member, - can_send_to, autodetect_ip_version, create_sockets, get_all_addresses_v6, diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 15852ffb..74b1828f 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -836,12 +836,13 @@ def _async_send_transport( v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), ) -> None: s = transport.get_extra_info('socket') + ipv6_socket = s.family == socket.AF_INET6 if addr is None: - real_addr = _MDNS_ADDR6 if s.family == socket.AF_INET6 else _MDNS_ADDR + real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR else: real_addr = addr - if not can_send_to(s, real_addr): - return + if not can_send_to(ipv6_socket, real_addr): + return log.debug( 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', real_addr, @@ -855,7 +856,7 @@ def _async_send_transport( ) # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families - if s.family == socket.AF_INET6 and not v6_flow_scope: + if ipv6_socket and not v6_flow_scope: _, _, sock_flowinfo, sock_scopeid = s.getsockname() v6_flow_scope = (sock_flowinfo, sock_scopeid) transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) diff --git a/zeroconf/_utils/net.py b/zeroconf/_utils/net.py index bfae9db4..c53ec978 100644 --- a/zeroconf/_utils/net.py +++ b/zeroconf/_utils/net.py @@ -379,9 +379,13 @@ def get_errno(e: Exception) -> int: return cast(int, e.args[0]) -def can_send_to(sock: socket.socket, address: str) -> bool: - addr = ipaddress.ip_address(address) - return cast(bool, addr.version == 6 if sock.family == socket.AF_INET6 else addr.version == 4) +def can_send_to(ipv6_socket: bool, address: str) -> bool: + """Check if the address type matches the socket type. + + This function does not validate if the address is a valid + ipv6 or ipv4 address. + """ + return ":" in address if ipv6_socket else ":" not in address def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion: From 92f5f4a80b8a8e50df5ca06e3cc45480dc39b504 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 19 Sep 2021 12:51:38 -1000 Subject: [PATCH 595/608] Update changelog for 0.36.6 (#991) --- README.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.rst b/README.rst index 6b4ee99c..f936c3fe 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,11 @@ See examples directory for more. Changelog ========= +0.36.6 +====== + +* Improve performance of sending outgoing packets (#990) @bdraco + 0.36.5 ====== From 29f995fd3c09604f37980e74f2785b1a451da089 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 19 Sep 2021 12:52:45 -1000 Subject: [PATCH 596/608] Fix tense of 0.36.6 changelog (#992) --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f936c3fe..01a81c20 100644 --- a/README.rst +++ b/README.rst @@ -141,7 +141,7 @@ Changelog 0.36.6 ====== -* Improve performance of sending outgoing packets (#990) @bdraco +* Improved performance of sending outgoing packets (#990) @bdraco 0.36.5 ====== From 0327a068250c85f3ff84d3f0b809b51f83321c47 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 19 Sep 2021 17:52:59 -0500 Subject: [PATCH 597/608] =?UTF-8?q?Bump=20version:=200.36.5=20=E2=86=92=20?= =?UTF-8?q?0.36.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 36703f95..e8151de2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.5 +current_version = 0.36.6 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 5c04eb26..9a8ff7c7 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.5' +__version__ = '0.36.6' __license__ = 'LGPL' From 93ddf7cf9b47d7ff1e341b6c2875254b6f00eef1 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 18:51:23 -0500 Subject: [PATCH 598/608] Flush CI cache (#995) --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 3686d617..88980ff2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -# version: 1.0 +# version: 1.1 .PHONY: all virtualenv MAX_LINE_LENGTH=110 From 762236547d4838f2b6a94cfa20221dfdd03e9b94 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 22:27:48 -0500 Subject: [PATCH 599/608] Refactor service registry to avoid use of getattr (#996) --- zeroconf/_services/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/zeroconf/_services/registry.py b/zeroconf/_services/registry.py index 4e64c8d7..203b3b39 100644 --- a/zeroconf/_services/registry.py +++ b/zeroconf/_services/registry.py @@ -69,15 +69,15 @@ def async_get_types(self) -> List[str]: def async_get_infos_type(self, type_: str) -> List[ServiceInfo]: """Return all ServiceInfo matching type.""" - return self._async_get_by_index("types", type_) + return self._async_get_by_index(self.types, type_) def async_get_infos_server(self, server: str) -> List[ServiceInfo]: """Return all ServiceInfo matching server.""" - return self._async_get_by_index("servers", server) + return self._async_get_by_index(self.servers, server) - def _async_get_by_index(self, attr: str, key: str) -> List[ServiceInfo]: + def _async_get_by_index(self, records: Dict[str, List], key: str) -> List[ServiceInfo]: """Return all ServiceInfo matching the index.""" - return [self._services[name] for name in getattr(self, attr).get(key.lower(), [])] + return [self._services[name] for name in records.get(key.lower(), [])] def _add(self, info: ServiceInfo) -> None: """Add a new service under the lock.""" From 7fa51de5b71d03470643a83004b9f6f8d4017214 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 22:35:29 -0500 Subject: [PATCH 600/608] Reduce overhead to compare dns records (#997) --- zeroconf/_dns.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 7ef6b6a9..35594b37 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -127,7 +127,7 @@ def __hash__(self) -> int: def __eq__(self, other: Any) -> bool: """Tests equality on dns question.""" - return isinstance(other, DNSQuestion) and DNSEntry.__eq__(self, other) + return isinstance(other, DNSQuestion) and dns_entry_matches(other, self.key, self.type, self.class_) @property def max_size(self) -> int: @@ -260,7 +260,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSAddress) and self.address == other.address and self.scope_id == other.scope_id - and DNSEntry.__eq__(self, other) + and dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -304,7 +304,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSHinfo) and self.cpu == other.cpu and self.os == other.os - and DNSEntry.__eq__(self, other) + and dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -345,7 +345,11 @@ def write(self, out: 'DNSOutgoing') -> None: def __eq__(self, other: Any) -> bool: """Tests equality on alias""" - return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other) + return ( + isinstance(other, DNSPointer) + and self.alias == other.alias + and dns_entry_matches(other, self.key, self.type, self.class_) + ) def __hash__(self) -> int: """Hash to compare like DNSPointer.""" @@ -380,7 +384,11 @@ def __hash__(self) -> int: def __eq__(self, other: Any) -> bool: """Tests equality on text""" - return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other) + return ( + isinstance(other, DNSText) + and self.text == other.text + and dns_entry_matches(other, self.key, self.type, self.class_) + ) def __repr__(self) -> str: """String representation""" @@ -429,7 +437,7 @@ def __eq__(self, other: Any) -> bool: and self.weight == other.weight and self.port == other.port and self.server == other.server - and DNSEntry.__eq__(self, other) + and dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: @@ -484,7 +492,7 @@ def __eq__(self, other: Any) -> bool: isinstance(other, DNSNsec) and self.next_name == other.next_name and self.rdtypes == other.rdtypes - and DNSEntry.__eq__(self, other) + and dns_entry_matches(other, self.key, self.type, self.class_) ) def __hash__(self) -> int: From 7df7e4a68e33c3e3a5bddf0168e248a4542a788f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 22:45:48 -0500 Subject: [PATCH 601/608] Reduce logging overhead (#994) --- tests/test_core.py | 34 +++++++++++++++++++++++++++ zeroconf/_core.py | 57 +++++++++++++++++++++++++--------------------- 2 files changed, 65 insertions(+), 26 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index a5c22065..eab769be 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -398,6 +398,40 @@ def test_register_service_with_custom_ttl(): zc.close() +def test_logging_packets(caplog): + """Test packets are only logged with debug logging.""" + + # instantiate a zeroconf instance + zc = Zeroconf(interfaces=['127.0.0.1']) + + # start a browser + type_ = "_logging._tcp.local." + name = "TLD" + info_service = r.ServiceInfo( + type_, + f'{name}.{type_}', + 80, + 0, + 0, + {'path': '/~paulsm/'}, + "ash-90.local.", + addresses=[socket.inet_aton("10.0.1.2")], + ) + + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + caplog.clear() + zc.register_service(info_service, ttl=3000) + assert "Sending to" in caplog.text + assert zc.cache.get(info_service.dns_pointer()).ttl == 3000 + logging.getLogger('zeroconf').setLevel(logging.INFO) + caplog.clear() + zc.unregister_service(info_service) + assert "Sending to" not in caplog.text + logging.getLogger('zeroconf').setLevel(logging.DEBUG) + + zc.close() + + def test_get_service_info_failure_path(): """Verify get_service_info return None when the underlying call returns False.""" zc = Zeroconf(interfaces=['127.0.0.1']) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 74b1828f..0d59b3d7 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -22,6 +22,7 @@ import asyncio import itertools +import logging import random import socket import sys @@ -217,6 +218,7 @@ def __init__(self, zc: 'Zeroconf') -> None: self.last_time: float = 0 self.transport: Optional[asyncio.DatagramTransport] = None self.sock_name: Optional[str] = None + self.sock_description: Optional[str] = None self.sock_fileno: Optional[int] = None self._deferred: Dict[str, List[DNSIncoming]] = {} self._timers: Dict[str, asyncio.TimerHandle] = {} @@ -236,6 +238,8 @@ def datagram_received( ) -> None: assert self.transport is not None v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () + data_len = len(data) + if len(addrs) == 2: # https://github.com/python/mypy/issues/1178 addr, port = addrs # type: ignore @@ -253,19 +257,19 @@ def datagram_received( 'Ignoring duplicate message received from %r:%r [socket %s] (%d bytes) as [%r]', addr, port, - self._socket_description, - len(data), + self.sock_description, + data_len, data, ) return - if len(data) > _MAX_MSG_ABSOLUTE: + if data_len > _MAX_MSG_ABSOLUTE: # Guard against oversized packets to ensure bad implementations cannot overwhelm # the system. log.debug( "Discarding incoming packet with length %s, which is larger " "than the absolute maximum size of %s", - len(data), + data_len, _MAX_MSG_ABSOLUTE, ) return @@ -276,9 +280,9 @@ def datagram_received( 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', addr, port, - self._socket_description, + self.sock_description, msg, - len(data), + data_len, data, ) else: @@ -286,8 +290,8 @@ def datagram_received( 'Received from %r:%r [socket %s]: (%d bytes) [%r]', addr, port, - self._socket_description, - len(data), + self.sock_description, + data_len, data, ) return @@ -346,24 +350,20 @@ def _respond_query( self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) - @property - def _socket_description(self) -> str: - """A human readable description of the socket.""" - return f"{self.sock_fileno} ({self.sock_name})" - def error_received(self, exc: Exception) -> None: """Likely socket closed or IPv6.""" # We preformat the message string with the socket as we want # log_exception_once to log a warrning message once PER EACH # different socket in case there are problems with multiple # sockets - msg_str = f"Error with socket {self._socket_description}): %s" + msg_str = f"Error with socket {self.sock_description}): %s" self.log_exception_once(exc, msg_str, exc) def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) self.sock_name = self.transport.get_extra_info('sockname') self.sock_fileno = self.transport.get_extra_info('socket').fileno() + self.sock_description = f"{self.sock_fileno} ({self.sock_name})" def connection_lost(self, exc: Optional[Exception]) -> None: """Handle connection lost.""" @@ -817,16 +817,20 @@ def async_send( # If no transport is specified, we send to all the ones # with the same address family transports = [transport] if transport else self.engine.senders + log_debug = log.isEnabledFor(logging.DEBUG) for packet_num, packet in enumerate(out.packets()): if len(packet) > _MAX_MSG_ABSOLUTE: self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return for send_transport in transports: - self._async_send_transport(send_transport, packet, packet_num, out, addr, port, v6_flow_scope) + self._async_send_transport( + log_debug, send_transport, packet, packet_num, out, addr, port, v6_flow_scope + ) def _async_send_transport( self, + log_debug: bool, transport: asyncio.DatagramTransport, packet: bytes, packet_num: int, @@ -843,17 +847,18 @@ def _async_send_transport( real_addr = addr if not can_send_to(ipv6_socket, real_addr): return - log.debug( - 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', - real_addr, - port or _MDNS_PORT, - s.fileno(), - transport.get_extra_info('sockname'), - len(packet), - packet_num + 1, - out, - packet, - ) + if log_debug: + log.debug( + 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', + real_addr, + port or _MDNS_PORT, + s.fileno(), + transport.get_extra_info('sockname'), + len(packet), + packet_num + 1, + out, + packet, + ) # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families if ipv6_socket and not v6_flow_scope: From b637846e7df3292d6dcdd38a8eb77b6fa3287c51 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 23:17:50 -0500 Subject: [PATCH 602/608] Improve log message when receiving an invalid or corrupt packet (#998) --- tests/test_protocol.py | 11 +++++++---- zeroconf/_core.py | 2 +- zeroconf/_protocol/incoming.py | 19 ++++++++++++++++--- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 6ad3303b..55dbbe4d 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -822,16 +822,18 @@ def test_dns_compression_invalid_skips_bad_name_compress_in_question(): assert len(parsed.questions) == 4 -def test_dns_compression_all_invalid(): +def test_dns_compression_all_invalid(caplog): """Test our wire parser can skip all invalid data.""" packet = ( b'\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00!roborock-vacuum-s5e_miio416' b'112328\x00\x00/\x80\x01\x00\x00\x00x\x00\t\xc0P\x00\x05@\x00\x00\x00\x00' ) - parsed = r.DNSIncoming(packet) + parsed = r.DNSIncoming(packet, ("2.4.5.4", 5353)) assert len(parsed.questions) == 0 assert len(parsed.answers) == 0 + assert " Unable to parse; skipping record" in caplog.text + def test_invalid_next_name_ignored(): """Test our wire parser does not throw an an invalid next name. @@ -918,15 +920,16 @@ def test_dns_compression_points_beyond_packet(): assert len(parsed.answers) == 1 -def test_dns_compression_generic_failure(): +def test_dns_compression_generic_failure(caplog): """Test our wire parser does not loop forever when dns compression is corrupt.""" packet = ( b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x06domain\x05local\x00\x00\x01' b'\x80\x01\x00\x00\x00\x01\x00\x04\xc0\xa8\xd0\x05-\x0c\x00\x01\x80\x01\x00\x00' b'\x00\x01\x00\x04\xc0\xa8\xd0\x06' ) - parsed = r.DNSIncoming(packet) + parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353)) assert len(parsed.answers) == 1 + assert "Received invalid packet from ('1.2.3.4', 5353)" in caplog.text def test_label_length_attack(): diff --git a/zeroconf/_core.py b/zeroconf/_core.py index 0d59b3d7..e609e4ad 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -274,7 +274,7 @@ def datagram_received( ) return - msg = DNSIncoming(data, scope, now) + msg = DNSIncoming(data, (addr, port), scope, now) if msg.valid: log.debug( 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', diff --git a/zeroconf/_protocol/incoming.py b/zeroconf/_protocol/incoming.py index bffbb4bc..6d7a6153 100644 --- a/zeroconf/_protocol/incoming.py +++ b/zeroconf/_protocol/incoming.py @@ -21,7 +21,7 @@ """ import struct -from typing import Callable, Dict, List, Optional, Set, cast +from typing import Callable, Dict, List, Optional, Set, Tuple, cast from . import DNSMessage from .._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText @@ -67,9 +67,16 @@ class DNSIncoming(DNSMessage, QuietLogger): 'valid', 'now', 'scope_id', + 'source', ) - def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[float] = None) -> None: + def __init__( + self, + data: bytes, + source: Optional[Tuple[str, int]] = None, + scope_id: Optional[int] = None, + now: Optional[float] = None, + ) -> None: """Constructor from string holding bytes of packet""" super().__init__(0) self.offset = 0 @@ -86,6 +93,7 @@ def __init__(self, data: bytes, scope_id: Optional[int] = None, now: Optional[fl self.valid = False self._read_others = False self.now = now or current_time_millis() + self.source = source self.scope_id = scope_id self._parse_data(self._initial_parse) @@ -102,7 +110,12 @@ def _parse_data(self, parser_call: Callable) -> None: try: parser_call() except DECODE_EXCEPTIONS: - self.log_exception_warning('Choked at offset %d while unpacking %r', self.offset, self.data) + self.log_exception_warning( + 'Received invalid packet from %s at offset %d while unpacking %r', + self.source, + self.offset, + self.data, + ) @property def answers(self) -> List[DNSRecord]: From d2853c31db9ece28fb258c4146ba61cf0e6a6592 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 23:21:42 -0500 Subject: [PATCH 603/608] Update changelog for 0.36.7 (#999) --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index 01a81c20..ea7c3d1c 100644 --- a/README.rst +++ b/README.rst @@ -138,6 +138,12 @@ See examples directory for more. Changelog ========= +0.36.7 +====== + +* Improved performance of responding to queries (#994) (#996) (#997) @bdraco +* Improved log message when receiving an invalid or corrupt packet (#998) @bdraco + 0.36.6 ====== From f44b40e26ea8872151ea9ee4762b95ca25790089 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 22 Sep 2021 23:22:00 -0500 Subject: [PATCH 604/608] =?UTF-8?q?Bump=20version:=200.36.6=20=E2=86=92=20?= =?UTF-8?q?0.36.7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.cfg | 2 +- zeroconf/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index e8151de2..67e50408 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.36.6 +current_version = 0.36.7 commit = True tag = True tag_name = {new_version} diff --git a/zeroconf/__init__.py b/zeroconf/__init__.py index 9a8ff7c7..4821cbb8 100644 --- a/zeroconf/__init__.py +++ b/zeroconf/__init__.py @@ -79,7 +79,7 @@ __author__ = 'Paul Scott-Murphy, William McBrine' __maintainer__ = 'Jakub Stasiak ' -__version__ = '0.36.6' +__version__ = '0.36.7' __license__ = 'LGPL' From 8e45ea943be6490b2217f0eb01501e12a5221c16 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Sep 2021 08:43:42 -0500 Subject: [PATCH 605/608] Remove unused code in zeroconf._core (#1001) - Breakout functions without self-use --- zeroconf/_core.py | 89 +++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/zeroconf/_core.py b/zeroconf/_core.py index e609e4ad..1575eba2 100644 --- a/zeroconf/_core.py +++ b/zeroconf/_core.py @@ -212,17 +212,16 @@ class AsyncListener(asyncio.Protocol, QuietLogger): It requires registration with an Engine object in order to have the read() method called when a socket is available for reading.""" + __slots__ = ('zc', 'data', 'last_time', 'transport', 'sock_description', '_deferred', '_timers') + def __init__(self, zc: 'Zeroconf') -> None: self.zc = zc self.data: Optional[bytes] = None self.last_time: float = 0 self.transport: Optional[asyncio.DatagramTransport] = None - self.sock_name: Optional[str] = None self.sock_description: Optional[str] = None - self.sock_fileno: Optional[int] = None self._deferred: Dict[str, List[DNSIncoming]] = {} self._timers: Dict[str, asyncio.TimerHandle] = {} - super().__init__() def suppress_duplicate_packet(self, data: bytes, now: float) -> bool: @@ -361,14 +360,52 @@ def error_received(self, exc: Exception) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = cast(asyncio.DatagramTransport, transport) - self.sock_name = self.transport.get_extra_info('sockname') - self.sock_fileno = self.transport.get_extra_info('socket').fileno() - self.sock_description = f"{self.sock_fileno} ({self.sock_name})" + sock_name = self.transport.get_extra_info('sockname') + sock_fileno = self.transport.get_extra_info('socket').fileno() + self.sock_description = f"{sock_fileno} ({sock_name})" def connection_lost(self, exc: Optional[Exception]) -> None: """Handle connection lost.""" +def async_send_with_transport( + log_debug: bool, + transport: asyncio.DatagramTransport, + packet: bytes, + packet_num: int, + out: DNSOutgoing, + addr: Optional[str], + port: int, + v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), +) -> None: + s = transport.get_extra_info('socket') + ipv6_socket = s.family == socket.AF_INET6 + if addr is None: + real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR + else: + real_addr = addr + if not can_send_to(ipv6_socket, real_addr): + return + if log_debug: + log.debug( + 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', + real_addr, + port or _MDNS_PORT, + s.fileno(), + transport.get_extra_info('sockname'), + len(packet), + packet_num + 1, + out, + packet, + ) + # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 + # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families + if ipv6_socket and not v6_flow_scope: + _, _, sock_flowinfo, sock_scopeid = s.getsockname() + v6_flow_scope = (sock_flowinfo, sock_scopeid) + transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) + + class Zeroconf(QuietLogger): """Implementation of Zeroconf Multicast DNS Service Discovery @@ -824,48 +861,10 @@ def async_send( self.log_warning_once("Dropping %r over-sized packet (%d bytes) %r", out, len(packet), packet) return for send_transport in transports: - self._async_send_transport( + async_send_with_transport( log_debug, send_transport, packet, packet_num, out, addr, port, v6_flow_scope ) - def _async_send_transport( - self, - log_debug: bool, - transport: asyncio.DatagramTransport, - packet: bytes, - packet_num: int, - out: DNSOutgoing, - addr: Optional[str], - port: int, - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), - ) -> None: - s = transport.get_extra_info('socket') - ipv6_socket = s.family == socket.AF_INET6 - if addr is None: - real_addr = _MDNS_ADDR6 if ipv6_socket else _MDNS_ADDR - else: - real_addr = addr - if not can_send_to(ipv6_socket, real_addr): - return - if log_debug: - log.debug( - 'Sending to (%s, %d) via [socket %s (%s)] (%d bytes #%d) %r as %r...', - real_addr, - port or _MDNS_PORT, - s.fileno(), - transport.get_extra_info('sockname'), - len(packet), - packet_num + 1, - out, - packet, - ) - # Get flowinfo and scopeid for the IPV6 socket to create a complete IPv6 - # address tuple: https://docs.python.org/3.6/library/socket.html#socket-families - if ipv6_socket and not v6_flow_scope: - _, _, sock_flowinfo, sock_scopeid = s.getsockname() - v6_flow_scope = (sock_flowinfo, sock_scopeid) - transport.sendto(packet, (real_addr, port or _MDNS_PORT, *v6_flow_scope)) - def _close(self) -> None: """Set global done and remove all service listeners.""" if self.done: From d3ed69107330f1a29f45d174caafdec1e894f666 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Sep 2021 08:44:08 -0500 Subject: [PATCH 606/608] Use more f-strings in zeroconf._dns (#1002) --- zeroconf/_dns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/_dns.py b/zeroconf/_dns.py index 35594b37..a551a7da 100644 --- a/zeroconf/_dns.py +++ b/zeroconf/_dns.py @@ -89,12 +89,12 @@ def __eq__(self, other: Any) -> bool: @staticmethod def get_class_(class_: int) -> str: """Class accessor""" - return _CLASSES.get(class_, "?(%s)" % class_) + return _CLASSES.get(class_, f"?({class_})") @staticmethod def get_type(t: int) -> str: """Type accessor""" - return _TYPES.get(t, "?(%s)" % t) + return _TYPES.get(t, f"?({t})") def entry_to_string(self, hdr: str, other: Optional[Union[bytes, str]]) -> str: """String representation with additional information""" From af4d082240a545ba3014eb7f1056c3b32ce2cb70 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Sep 2021 08:44:23 -0500 Subject: [PATCH 607/608] Breakout functions with no self-use in zeroconf._handlers (#1003) --- zeroconf/_handlers.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/zeroconf/_handlers.py b/zeroconf/_handlers.py index 5215f202..b4c31e2d 100644 --- a/zeroconf/_handlers.py +++ b/zeroconf/_handlers.py @@ -227,6 +227,19 @@ def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool: return bool(maybe_entry and self._now - maybe_entry.created < _ONE_SECOND) +def _get_address_and_nsec_records(service: ServiceInfo, now: float) -> Set[DNSRecord]: + """Build a set of address records and NSEC records for non-present record types.""" + seen_types: Set[int] = set() + records: Set[DNSRecord] = set() + for dns_address in service.dns_addresses(created=now): + seen_types.add(dns_address.type) + records.add(dns_address) + missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types + if missing_types: + records.add(construct_nsec_record(service.server, list(missing_types), now)) + return records + + class QueryHandler: """Query the ServiceRegistry.""" @@ -261,21 +274,9 @@ def _add_pointer_answers( if known_answers.suppresses(dns_pointer): continue additionals: Set[DNSRecord] = {service.dns_service(created=now), service.dns_text(created=now)} - additionals |= self._get_address_and_nsec_records(service, now) + additionals |= _get_address_and_nsec_records(service, now) answer_set[dns_pointer] = additionals - def _get_address_and_nsec_records(self, service: ServiceInfo, now: float) -> Set[DNSRecord]: - """Build a set of address records and NSEC records for non-present record types.""" - seen_types: Set[int] = set() - records: Set[DNSRecord] = set() - for dns_address in service.dns_addresses(created=now): - seen_types.add(dns_address.type) - records.add(dns_address) - missing_types: Set[int] = _ADDRESS_RECORD_TYPES - seen_types - if missing_types: - records.add(construct_nsec_record(service.server, list(missing_types), now)) - return records - def _add_address_answers( self, name: str, @@ -332,7 +333,7 @@ def _answer_question( # https://tools.ietf.org/html/rfc6763#section-12.2. dns_service = service.dns_service(created=now) if not known_answers.suppresses(dns_service): - answer_set[dns_service] = self._get_address_and_nsec_records(service, now) + answer_set[dns_service] = _get_address_and_nsec_records(service, now) if type_ in (_TYPE_TXT, _TYPE_ANY): dns_text = service.dns_text(created=now) if not known_answers.suppresses(dns_text): From 543558d0498ed03eb9dc4597c4c40484e16ee4e6 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 23 Sep 2021 08:44:32 -0500 Subject: [PATCH 608/608] Cleanup typing in zeroconf._protocol.outgoing (#1000) --- zeroconf/_protocol/outgoing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zeroconf/_protocol/outgoing.py b/zeroconf/_protocol/outgoing.py index 21ff4b64..59c8382e 100644 --- a/zeroconf/_protocol/outgoing.py +++ b/zeroconf/_protocol/outgoing.py @@ -158,7 +158,7 @@ def add_additional_answer(self, record: DNSRecord) -> None: self.additionals.append(record) def add_question_or_one_cache( - self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + self, cache: DNSCache, now: float, name: str, type_: int, class_: int ) -> None: """Add a question if it is not already cached.""" cached_entry = cache.get_by_details(name, type_, class_) @@ -168,7 +168,7 @@ def add_question_or_one_cache( self.add_answer_at_time(cached_entry, now) def add_question_or_all_cache( - self, cache: 'DNSCache', now: float, name: str, type_: int, class_: int + self, cache: DNSCache, now: float, name: str, type_: int, class_: int ) -> None: """Add a question if it is not already cached. This is currently only used for IPv6 addresses.