Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Do an AAAA lookup on SRV record targets
Browse files Browse the repository at this point in the history
Support SRV records which point at AAAA records, as well as A records.

Fixes #2405
  • Loading branch information
richvdh committed Sep 22, 2017
1 parent c94ab59 commit c2407d7
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 24 deletions.
116 changes: 96 additions & 20 deletions synapse/http/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket

from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
Expand All @@ -30,7 +31,10 @@

SERVER_CACHE = {}


# our record of an individual server which can be tried to reach a destination.
#
# "host" is actually a dotted-quad or ipv6 address string. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
Expand Down Expand Up @@ -219,9 +223,10 @@ def pick_server(self):
return self.default_server
else:
raise ConnectError(
"Not server available for %s" % self.service_name
"No server available for %s" % self.service_name
)

# look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
Expand All @@ -231,11 +236,22 @@ def pick_server(self):

total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)

for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
# XXX: this looks totally dubious:
#
# (a) we never reuse a server until we have been through
# all of the servers at the same priority, so if the
# weights are A: 100, B:1, we always do ABABAB instead of
# AAAA...AAAB (approximately).
#
# (b) After using all the servers at the lowest priority,
# we move onto the next priority. We should only use the
# second priority if servers at the top priority are
# unreachable.
#
del self.servers[index]
self.used_servers.append(server)
return server
Expand Down Expand Up @@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue

payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl

try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
hosts = yield _get_hosts_for_srv_record(
dns_client, str(payload.target)
)

for answer in answers:
if answer.type == dns.A and answer.payload:
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
for (ip, ttl) in hosts:
host_ttl = min(answer.ttl, ttl)

servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))

servers.sort()
cache[service_name] = list(servers)
Expand All @@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e

defer.returnValue(servers)


@defer.inlineCallbacks
def _get_hosts_for_srv_record(dns_client, host):
"""Look up each of the hosts in a SRV record
Args:
dns_client (twisted.names.dns.IResolver):
host (basestring): host to look up
Returns:
Deferred[list[(str, int)]]: a list of (host, ttl) pairs
"""
ip4_servers = []
ip6_servers = []

def cb(res):
# lookupAddress and lookupIP6Address return a three-tuple
# giving the answer, authority, and additional sections of the
# response.
#
# we only care about the answers.

return res[0]

def eb(res):
res.trap(DNSNameError)
return []

# no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
results = yield defer.gatherResults([d1, d2], consumeErrors=True)

for result in results:
for answer in result:
if not answer.payload:
continue

try:
if answer.type == dns.A:
ip = answer.payload.dottedQuad()
ip4_servers.append((ip, answer.ttl))
elif answer.type == dns.AAAA:
ip = socket.inet_ntop(
socket.AF_INET6, answer.payload.address,
)
ip6_servers.append((ip, answer.ttl))
else:
# the most likely candidate here is a CNAME record.
# rfc2782 says srvs may not point to aliases.
logger.warn(
"Ignoring unexpected DNS record type %s for %s",
answer.type, host,
)
continue
except Exception as e:
logger.warn("Ignoring invalid DNS response for %s: %s",
host, e)
continue

# keep the ipv4 results before the ipv6 results, mostly to match historical
# behaviour.
defer.returnValue(ip4_servers + ip6_servers)
26 changes: 22 additions & 4 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,17 @@
from tests.utils import MockClock


@unittest.DEBUG
class DnsTestCase(unittest.TestCase):

@defer.inlineCallbacks
def test_resolve(self):
dns_client_mock = Mock()

service_name = "test_service.examle.com"
service_name = "test_service.example.com"
host_name = "example.com"
ip_address = "127.0.0.1"
ip6_address = "::1"

answer_srv = dns.RRHeader(
type=dns.SRV,
Expand All @@ -48,8 +50,22 @@ def test_resolve(self):
)
)

dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
answer_aaaa = dns.RRHeader(
type=dns.AAAA,
payload=dns.Record_AAAA(
address=ip6_address,
)
)

dns_client_mock.lookupService.return_value = defer.succeed(
([answer_srv], None, None),
)
dns_client_mock.lookupAddress.return_value = defer.succeed(
([answer_a], None, None),
)
dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
([answer_aaaa], None, None),
)

cache = {}

Expand All @@ -59,10 +75,12 @@ def test_resolve(self):

dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)

self.assertEquals(len(servers), 1)
self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address)
self.assertEquals(servers[1].host, ip6_address)

@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
Expand Down

0 comments on commit c2407d7

Please sign in to comment.