Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 116 additions & 23 deletions packages/jumpstarter/jumpstarter/common/grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import logging
import os
import socket
import ssl
Expand All @@ -12,32 +13,124 @@

from jumpstarter.common.exceptions import ConfigurationError, ConnectionError

logger = logging.getLogger(__name__)


async def _try_connect_and_extract_cert(
    ip_address: str, port: int, ssl_context: ssl.SSLContext, hostname: str, timeout: float
) -> bytes:
    """
    Try to connect to a single IP and extract its certificate chain.

    Args:
        ip_address: Address to open the TCP/TLS connection to.
        port: TCP port to connect to.
        ssl_context: SSL context used for the handshake.
        hostname: Name passed as ``server_hostname`` for the handshake.
        timeout: Maximum number of seconds to wait for the connection.

    Returns:
        The certificate chain in PEM format as bytes.

    Raises:
        ssl.SSLError: If the handshake fails or the peer exposes no certificate chain.
        Exception: Any other connection failure (timeout, refused, ...) is propagated.
    """
    logger.debug(f"Attempting TLS connection to {ip_address}:{port} (timeout={timeout}s)")
    _, writer = await asyncio.wait_for(
        asyncio.open_connection(ip_address, port, ssl=ssl_context, server_hostname=hostname),
        timeout=timeout,
    )
    logger.debug(f"Successfully connected to {ip_address}:{port}")
    try:
        # get_extra_info() returns None when the transport has no SSL object, and the
        # private chain accessor may hand back None instead of a list; treat both
        # (and an empty chain) as "no certificates" rather than crashing or
        # returning empty credentials.
        ssl_object = writer.get_extra_info("ssl_object")
        cert_chain = ssl_object._sslobj.get_unverified_chain() if ssl_object is not None else None
        if not cert_chain:
            raise ssl.SSLError(f"No certificate chain presented by {ip_address}:{port}")

        root_certificates = "".join(cert.public_bytes() for cert in cert_chain)
        logger.debug(f"Successfully extracted {len(cert_chain)} certificate(s) from {ip_address}:{port}")

        return root_certificates.encode()
    finally:
        # Always release the connection, whether extraction succeeded or not.
        writer.close()


async def _ssl_channel_credentials_insecure(target: str, timeout: float) -> grpc.ChannelCredentials:  # noqa: C901
    """
    Extract TLS certificates from server without verification (insecure mode).

    Resolves the target host, tries to connect to all resolved IPs in parallel
    and returns credentials built from the first successful connection.

    Args:
        target: Endpoint in ``host[:port]`` form; the port defaults to 443.
        timeout: Overall deadline in seconds; also passed to each connection attempt.

    Returns:
        gRPC channel credentials whose root certificates are the chain presented
        by the first server that answered.

    Raises:
        ConfigurationError: If ``target`` cannot be parsed or contains no host.
        ConnectionError: If name resolution fails, the deadline expires, or
            every resolved IP fails.
    """
    try:
        parsed = urlparse(f"//{target}")
        port = parsed.port if parsed.port else 443
    except ValueError as e:
        raise ConfigurationError(f"Failed parsing {target}") from e

    # A target such as ":443" parses without error but has no host; passing None
    # to getaddrinfo would resolve to loopback addresses instead of failing.
    if not parsed.hostname:
        raise ConfigurationError(f"Failed parsing {target}: missing hostname")

    try:
        with fail_after(timeout):
            # Verification is disabled on purpose: the goal is only to read the
            # certificates the server presents.
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

            # Resolve all IP addresses for the hostname
            loop = asyncio.get_running_loop()
            addr_info = await loop.getaddrinfo(
                parsed.hostname, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
            )

            # De-duplicate while keeping resolver order so each address is tried once.
            resolved_ips = list(dict.fromkeys(sockaddr[0] for *_, sockaddr in addr_info))
            logger.debug(
                f"Resolved {parsed.hostname} to {len(resolved_ips)} IP(s): {', '.join(resolved_ips)}"
            )

            # Try all IPs in parallel - race for first success.
            # Wrap attempts so each result carries the IP it belongs to.
            async def try_with_ip(ip_address: str):
                """Return (ip, certificates, None) on success or (ip, None, exception) on failure."""
                try:
                    result = await _try_connect_and_extract_cert(
                        ip_address, port, ssl_context, parsed.hostname, timeout
                    )
                    return (ip_address, result, None)
                except Exception as e:
                    return (ip_address, None, e)

            tasks = [asyncio.create_task(try_with_ip(ip_address)) for ip_address in resolved_ips]

            # Per-IP failures, reported together if no address succeeds.
            errors = {}

            try:
                for future in asyncio.as_completed(tasks):
                    ip_address, root_certificates, error = await future

                    if error is None:
                        # Success! Return immediately (cleanup in finally)
                        logger.debug(f"Using certificates from {ip_address}:{port}")
                        return grpc.ssl_channel_credentials(root_certificates=root_certificates)

                    # This IP failed - log and continue trying other IPs
                    if isinstance(error, ssl.SSLError):
                        logger.error(f"SSL error on {ip_address}:{port}: {error}")
                    else:
                        logger.warning(f"Failed to connect to {ip_address}:{port}: {type(error).__name__}: {error}")
                    errors[ip_address] = error

                # All IPs failed
                raise ConnectionError(
                    f"Failed connecting to {parsed.hostname}:{port} - all IPs exhausted. Errors: {errors}"
                )
            finally:
                # Cancel any attempts still in flight once a winner is found or we give up.
                for task in tasks:
                    if not task.done():
                        task.cancel()
    except socket.gaierror as e:
        raise ConnectionError(f"Failed resolving {parsed.hostname}") from e
    except TimeoutError as e:
        raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e


async def ssl_channel_credentials(target: str, tls_config, timeout=5):
"""Get SSL channel credentials for gRPC connection."""
if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
try:
parsed = urlparse(f"//{target}")
port = parsed.port if parsed.port else 443
except ValueError as e:
raise ConfigurationError(f"Failed parsing {target}") from e

try:
with fail_after(timeout):
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
_, writer = await asyncio.open_connection(parsed.hostname, port, ssl=ssl_context)
root_certificates = ""
for cert in writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain():
root_certificates += cert.public_bytes()
return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode())
except socket.gaierror as e:
raise ConnectionError(f"Failed resolving {parsed.hostname}") from e
except ConnectionRefusedError as e:
raise ConnectionError(f"Failed connecting to {parsed.hostname}:{port}") from e
except TimeoutError as e:
raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e

return await _ssl_channel_credentials_insecure(target, timeout)
elif tls_config.ca != "":
ca_certificate = base64.b64decode(tls_config.ca)
return grpc.ssl_channel_credentials(ca_certificate)
Expand Down
Loading