Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion edgedb/asyncio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,12 @@ async def _connect_addr(self, addr):
else:
try:
tr, pr = await self._loop.create_connection(
self._protocol_factory, *addr, ssl=self._params.ssl_ctx
self._protocol_factory,
*addr,
ssl=self._params.ssl_ctx,
server_hostname=(
self._params.tls_server_name or addr[0]
),
)
except ssl.CertificateError as e:
raise con_utils.wrap_error(e) from e
Expand Down
2 changes: 2 additions & 0 deletions edgedb/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def __init__(
tls_ca: str = None,
tls_ca_file: str = None,
tls_security: str = None,
tls_server_name: str = None,
wait_until_available: int = 30,
timeout: int = 10,
**kwargs,
Expand All @@ -662,6 +663,7 @@ def __init__(
"tls_ca": tls_ca,
"tls_ca_file": tls_ca_file,
"tls_security": tls_security,
"tls_server_name": tls_server_name,
"wait_until_available": wait_until_available,
}

Expand Down
5 changes: 4 additions & 1 deletion edgedb/blocking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ async def _connect_addr(self, sock, addr, sa, deadline):
sock.settimeout(time_left)
try:
sock = self._params.ssl_ctx.wrap_socket(
sock, server_hostname=addr[0]
sock,
server_hostname=(
self._params.tls_server_name or addr[0]
),
)
except ssl.CertificateError as e:
raise con_utils.wrap_error(e) from e
Expand Down
32 changes: 31 additions & 1 deletion edgedb/con_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
re.ASCII,
)
CLOUD_INSTANCE_NAME_RE = re.compile(
r'^([A-Za-z0-9](?:-?[A-Za-z0-9])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$',
r'^([A-Za-z0-9_-](?:-?[A-Za-z0-9_])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$',
re.ASCII,
)
DSN_RE = re.compile(
Expand Down Expand Up @@ -202,6 +202,7 @@ class ResolvedConnectConfig:
_tls_ca_data = None
_tls_ca_data_source = None

_tls_server_name = None
_tls_security = None
_tls_security_source = None

Expand Down Expand Up @@ -254,6 +255,9 @@ def read_ca_file(file_path):

self._set_param('tls_ca_data', ca_file, source, read_ca_file)

def set_tls_server_name(self, ca_data, source):
self._set_param('tls_server_name', ca_data, source)

def set_tls_security(self, security, source):
self._set_param('tls_security', security, source,
_validate_tls_security)
Expand Down Expand Up @@ -308,6 +312,10 @@ def password(self):
def secret_key(self):
return self._secret_key

@property
def tls_server_name(self):
return self._tls_server_name

@property
def tls_security(self):
tls_security = self._tls_security or 'default'
Expand Down Expand Up @@ -555,6 +563,7 @@ def _parse_connect_dsn_and_args(
tls_ca,
tls_ca_file,
tls_security,
tls_server_name,
server_settings,
wait_until_available,
):
Expand Down Expand Up @@ -618,6 +627,10 @@ def _parse_connect_dsn_and_args(
(tls_security, '"tls_security" option')
if tls_security is not None else None
),
tls_server_name=(
(tls_server_name, '"tls_server_name" option')
if tls_server_name is not None else None
),
server_settings=(
(server_settings, '"server_settings" option')
if server_settings is not None else None
Expand Down Expand Up @@ -655,6 +668,7 @@ def _parse_connect_dsn_and_args(
env_secret_key = os.getenv('EDGEDB_SECRET_KEY')
env_tls_ca = os.getenv('EDGEDB_TLS_CA')
env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE')
env_tls_server_name = os.getenv('EDGEDB_TLS_SERVER_NAME')
env_tls_security = os.getenv('EDGEDB_CLIENT_TLS_SECURITY')
env_wait_until_available = os.getenv('EDGEDB_WAIT_UNTIL_AVAILABLE')

Expand Down Expand Up @@ -717,6 +731,11 @@ def _parse_connect_dsn_and_args(
'"EDGEDB_CLIENT_TLS_SECURITY" environment variable')
if env_tls_security is not None else None
),
tls_server_name=(
(env_tls_server_name,
'"EDGEDB_TLS_SERVER_NAME" environment variable')
if env_tls_server_name is not None else None
),
wait_until_available=(
(
env_wait_until_available,
Expand Down Expand Up @@ -924,6 +943,12 @@ def strip_leading_slash(str):
resolved_config._tls_ca_data, resolved_config.set_tls_ca_file
)

handle_dsn_part(
'tls_server_name', None,
resolved_config._tls_server_name,
resolved_config.set_tls_server_name
)

handle_dsn_part(
'tls_security', None,
resolved_config._tls_security,
Expand Down Expand Up @@ -1017,6 +1042,7 @@ def _resolve_config_options(
tls_ca=None,
tls_ca_file=None,
tls_security=None,
tls_server_name=None,
server_settings=None,
wait_until_available=None,
cloud_profile=None,
Expand Down Expand Up @@ -1051,6 +1077,8 @@ def _resolve_config_options(
resolved_config.set_tls_ca_data(*tls_ca)
if tls_security is not None:
resolved_config.set_tls_security(*tls_security)
if tls_server_name is not None:
resolved_config.set_tls_server_name(*tls_server_name)
if server_settings is not None:
resolved_config.add_server_settings(server_settings[0])
if wait_until_available is not None:
Expand Down Expand Up @@ -1178,6 +1206,7 @@ def parse_connect_arguments(
tls_ca,
tls_ca_file,
tls_security,
tls_server_name,
timeout,
command_timeout,
wait_until_available,
Expand Down Expand Up @@ -1211,6 +1240,7 @@ def parse_connect_arguments(
tls_ca=tls_ca,
tls_ca_file=tls_ca_file,
tls_security=tls_security,
tls_server_name=tls_server_name,
server_settings=server_settings,
wait_until_available=wait_until_available,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/shared-client-testcases
10 changes: 9 additions & 1 deletion tests/test_con_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def run_testcase(self, testcase):
tls_ca = opts.get('tlsCA')
tls_ca_file = opts.get('tlsCAFile')
tls_security = opts.get('tlsSecurity')
tls_server_name = opts.get('tlsServerName')
server_settings = opts.get('serverSettings')
wait_until_available = opts.get('waitUntilAvailable')

Expand Down Expand Up @@ -241,6 +242,7 @@ def mocked_open(filepath, *args, **kwargs):
tls_ca=tls_ca,
tls_ca_file=tls_ca_file,
tls_security=tls_security,
tls_server_name=tls_server_name,
timeout=timeout,
command_timeout=command_timeout,
server_settings=server_settings,
Expand All @@ -259,7 +261,10 @@ def mocked_open(filepath, *args, **kwargs):
'tlsCAData': connect_config._tls_ca_data,
'tlsSecurity': connect_config.tls_security,
'serverSettings': connect_config.server_settings,
'waitUntilAvailable': client_config.wait_until_available,
'waitUntilAvailable': float(
client_config.wait_until_available
),
'tlsServerName': connect_config.tls_server_name,
}

if expected is not None:
Expand Down Expand Up @@ -312,6 +317,7 @@ def test_test_connect_params_run_testcase_01(self):
'tlsSecurity': 'strict',
'serverSettings': {},
'waitUntilAvailable': 30,
'tlsServerName': None,
},
})

Expand All @@ -336,6 +342,7 @@ def test_test_connect_params_run_testcase_02(self):
'tlsSecurity': 'strict',
'serverSettings': {},
'waitUntilAvailable': 30,
'tlsServerName': None,
},
})

Expand Down Expand Up @@ -431,6 +438,7 @@ def test_project_config(self):
tls_ca=None,
tls_ca_file=None,
tls_security=None,
tls_server_name=None,
timeout=10,
command_timeout=None,
server_settings=None,
Expand Down