/
srv_resolver.py
124 lines (107 loc) · 4.32 KB
/
srv_resolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
import ipaddress
import random
try:
from dns import resolver
_HAVE_DNSPYTHON = True
except ImportError:
_HAVE_DNSPYTHON = False
from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError
# dnspython can return bytes or str from various parts
# of its API depending on version. We always want str.
def maybe_decode(text):
    """Return *text* as ``str``, decoding it first when it is ``bytes``.

    Any value that is not a ``bytes`` instance is handed back unchanged.
    """
    return text.decode() if isinstance(text, bytes) else text
# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
def _resolve(*args, **kwargs):
    """Run a DNS lookup through whichever dnspython entry point exists.

    All positional and keyword arguments are forwarded untouched. The
    ``resolver`` module is inspected on every call instead of binding one of
    its functions at import time.
    """
    if not hasattr(resolver, 'resolve'):
        # dnspython 1.X only offers ``query``.
        return resolver.query(*args, **kwargs)
    # dnspython >= 2
    return resolver.resolve(*args, **kwargs)
_INVALID_HOST_MSG = (
"Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
"Did you mean to use 'mongodb://'?")
class _SrvResolver(object):
def __init__(self, fqdn,
connect_timeout, srv_service_name, srv_max_hosts=0):
self.__fqdn = fqdn
self.__srv = srv_service_name
self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
self.__srv_max_hosts = srv_max_hosts or 0
# Validate the fully qualified domain name.
try:
ipaddress.ip_address(fqdn)
raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
except ValueError:
pass
try:
self.__plist = self.__fqdn.split(".")[1:]
except Exception:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
self.__slen = len(self.__plist)
if self.__slen < 2:
raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
def get_options(self):
try:
results = _resolve(self.__fqdn, 'TXT',
lifetime=self.__connect_timeout)
except (resolver.NoAnswer, resolver.NXDOMAIN):
# No TXT records
return None
except Exception as exc:
raise ConfigurationError(str(exc))
if len(results) > 1:
raise ConfigurationError('Only one TXT record is supported')
return (
b'&'.join([b''.join(res.strings) for res in results])).decode(
'utf-8')
def _resolve_uri(self, encapsulate_errors):
try:
results = _resolve('_' + self.__srv + '._tcp.' + self.__fqdn,
'SRV', lifetime=self.__connect_timeout)
except Exception as exc:
if not encapsulate_errors:
# Raise the original error.
raise
# Else, raise all errors as ConfigurationError.
raise ConfigurationError(str(exc))
return results
def _get_srv_response_and_hosts(self, encapsulate_errors):
results = self._resolve_uri(encapsulate_errors)
# Construct address tuples
nodes = [
(maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
for res in results]
# Validate hosts
for node in nodes:
try:
nlist = node[0].split(".")[1:][-self.__slen:]
except Exception:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
if self.__plist != nlist:
raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
if self.__srv_max_hosts:
nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes)))
return results, nodes
def get_hosts(self):
_, nodes = self._get_srv_response_and_hosts(True)
return nodes
def get_hosts_and_min_ttl(self):
results, nodes = self._get_srv_response_and_hosts(False)
return nodes, results.rrset.ttl