From 1a1398cf9f3e00fc06719c400bd2d2b0e2f6f5b8 Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Mon, 20 Oct 2025 18:45:13 +0200 Subject: [PATCH] Parallel SSL cert retrieval When using insecure TLS, if the first IP address that DNS returns for a hostname fails, our code does not try the other addresses the way gRPC does later, so we cannot connect. This patch attempts to retrieve the certificate for insecure mode in parallel from all the resolved IP addresses. (cherry picked from commit 1d426530b1b3dba0e0aaf9419061d54259f6b0b9) --- .../jumpstarter/jumpstarter/common/grpc.py | 139 +++++++++++++++--- 1 file changed, 116 insertions(+), 23 deletions(-) diff --git a/packages/jumpstarter/jumpstarter/common/grpc.py b/packages/jumpstarter/jumpstarter/common/grpc.py index cf0387bc5..212978fef 100644 --- a/packages/jumpstarter/jumpstarter/common/grpc.py +++ b/packages/jumpstarter/jumpstarter/common/grpc.py @@ -1,5 +1,6 @@ import asyncio import base64 +import logging import os import socket import ssl @@ -12,32 +13,124 @@ from jumpstarter.common.exceptions import ConfigurationError, ConnectionError +logger = logging.getLogger(__name__) + + +async def _try_connect_and_extract_cert( + ip_address: str, port: int, ssl_context: ssl.SSLContext, hostname: str, timeout: float +) -> bytes: + """ + Try to connect to a single IP and extract its certificate chain. + + Returns the certificate chain in PEM format as bytes. + Raises exception on failure. 
+ """ + logger.debug(f"Attempting TLS connection to {ip_address}:{port} (timeout={timeout}s)") + _, writer = await asyncio.wait_for( + asyncio.open_connection(ip_address, port, ssl=ssl_context, server_hostname=hostname), + timeout=timeout, + ) + logger.debug(f"Successfully connected to {ip_address}:{port}") + try: + # Extract certificates + cert_chain = writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain() + root_certificates = "" + for cert in cert_chain: + root_certificates += cert.public_bytes() + logger.debug(f"Successfully extracted {len(cert_chain)} certificate(s) from {ip_address}:{port}") + + return root_certificates.encode() + finally: + writer.close() + + +async def _ssl_channel_credentials_insecure(target: str, timeout: float) -> grpc.ChannelCredentials: # noqa: C901 + """ + Extract TLS certificates from server without verification (insecure mode). + + Tries to connect to all resolved IPs in parallel and returns credentials + from the first successful connection. 
+ """ + try: + parsed = urlparse(f"//{target}") + port = parsed.port if parsed.port else 443 + except ValueError as e: + raise ConfigurationError(f"Failed parsing {target}") from e + + try: + with fail_after(timeout): + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + # Resolve all IP addresses for the hostname + loop = asyncio.get_running_loop() + addr_info = await loop.getaddrinfo( + parsed.hostname, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + + # Log resolved IPs + resolved_ips = [sockaddr[0] for _, _, _, _, sockaddr in addr_info] + logger.debug( + f"Resolved {parsed.hostname} to {len(resolved_ips)} IP(s): {', '.join(resolved_ips)}" + ) + + # Try all IPs in parallel - race for first success + # Wrap tasks to include IP info with results/exceptions + async def try_with_ip(ip_address: str): + """Wrapper that returns (ip, result) on success or (ip, exception) on failure.""" + try: + result = await _try_connect_and_extract_cert( + ip_address, port, ssl_context, parsed.hostname, timeout + ) + return (ip_address, result, None) + except Exception as e: + return (ip_address, None, e) + + tasks = [] + for _family, _type, _proto, _canonname, sockaddr in addr_info: + ip_address = sockaddr[0] + task = asyncio.create_task(try_with_ip(ip_address)) + tasks.append(task) + + # Process tasks as they complete + errors = {} + + try: + for future in asyncio.as_completed(tasks): + ip_address, root_certificates, error = await future + + if error is None: + # Success! 
Return immediately (cleanup in finally) + logger.debug(f"Using certificates from {ip_address}:{port}") + return grpc.ssl_channel_credentials(root_certificates=root_certificates) + + # This IP failed - log and continue trying other IPs + if isinstance(error, ssl.SSLError): + logger.error(f"SSL error on {ip_address}:{port}: {error}") + else: + logger.warning(f"Failed to connect to {ip_address}:{port}: {type(error).__name__}: {error}") + errors[ip_address] = error + + # All IPs failed + raise ConnectionError( + f"Failed connecting to {parsed.hostname}:{port} - all IPs exhausted. Errors: {errors}" + ) + finally: + # Cancel any remaining tasks + for task in tasks: + if not task.done(): + task.cancel() + except socket.gaierror as e: + raise ConnectionError(f"Failed resolving {parsed.hostname}") from e + except TimeoutError as e: + raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e + async def ssl_channel_credentials(target: str, tls_config, timeout=5): + """Get SSL channel credentials for gRPC connection.""" if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1": - try: - parsed = urlparse(f"//{target}") - port = parsed.port if parsed.port else 443 - except ValueError as e: - raise ConfigurationError(f"Failed parsing {target}") from e - - try: - with fail_after(timeout): - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - _, writer = await asyncio.open_connection(parsed.hostname, port, ssl=ssl_context) - root_certificates = "" - for cert in writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain(): - root_certificates += cert.public_bytes() - return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode()) - except socket.gaierror as e: - raise ConnectionError(f"Failed resolving {parsed.hostname}") from e - except ConnectionRefusedError as e: - raise ConnectionError(f"Failed connecting to {parsed.hostname}:{port}") from e - 
except TimeoutError as e: - raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e - + return await _ssl_channel_credentials_insecure(target, timeout) elif tls_config.ca != "": ca_certificate = base64.b64decode(tls_config.ca) return grpc.ssl_channel_credentials(ca_certificate)