Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions packages/jumpstarter/jumpstarter/common/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from urllib.parse import urlparse

import grpc
from anyio import fail_after
from anyio.to_thread import run_sync

from jumpstarter.common.exceptions import ConfigurationError, ConnectionError


def ssl_channel_credentials(target: str, tls_config):
async def ssl_channel_credentials(target: str, tls_config, timeout=5):
configure_grpc_env()
if tls_config.insecure or os.getenv("JUMPSTARTER_GRPC_INSECURE") == "1":
try:
Expand All @@ -21,12 +23,15 @@ def ssl_channel_credentials(target: str, tls_config):
raise ConfigurationError(f"Failed parsing {target}") from e

try:
root_certificates = ssl.get_server_certificate((parsed.hostname, port))
with fail_after(timeout):
root_certificates = await run_sync(ssl.get_server_certificate, (parsed.hostname, port))
return grpc.ssl_channel_credentials(root_certificates=root_certificates.encode())
except socket.gaierror as e:
raise ConnectionError(f"Failed resolving {parsed.hostname}") from e
except ConnectionRefusedError as e:
raise ConnectionError(f"Failed connecting to {parsed.hostname}:{port}") from e
except TimeoutError as e:
raise ConnectionError(f"Timeout connecting to {parsed.hostname}:{port}") from e

elif tls_config.ca != "":
ca_certificate = base64.b64decode(tls_config.ca)
Expand Down
2 changes: 1 addition & 1 deletion packages/jumpstarter/jumpstarter/common/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class StreamRequestMetadata(BaseModel):
@asynccontextmanager
async def connect_router_stream(endpoint, token, stream, tls_config, grpc_options):
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(endpoint, tls_config),
await ssl_channel_credentials(endpoint, tls_config),
grpc.access_token_call_credentials(token),
)

Expand Down
2 changes: 1 addition & 1 deletion packages/jumpstarter/jumpstarter/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ClientConfigV1Alpha1(BaseModel):

async def channel(self):
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(self.endpoint, self.tls),
await ssl_channel_credentials(self.endpoint, self.tls),
call_credentials("Client", self.metadata, self.token),
)

Expand Down
4 changes: 2 additions & 2 deletions packages/jumpstarter/jumpstarter/config/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ async def serve(self):
# dynamic import to avoid circular imports
from jumpstarter.exporter import Exporter

def channel_factory():
async def channel_factory():
credentials = grpc.composite_channel_credentials(
ssl_channel_credentials(self.endpoint, self.tls),
await ssl_channel_credentials(self.endpoint, self.tls),
call_credentials("Exporter", self.metadata, self.token),
)
return aio_secure_channel(self.endpoint, credentials, self.grpcOptions)
Expand Down
8 changes: 4 additions & 4 deletions packages/jumpstarter/jumpstarter/exporter/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Exporter(AbstractAsyncContextManager, Metadata):
grpc_options: dict[str, str] = field(default_factory=dict)

async def __aexit__(self, exc_type, exc_value, traceback):
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory())
logger.info("Unregistering exporter with controller")
await controller.Unregister(
jumpstarter_pb2.UnregisterRequest(
Expand All @@ -47,7 +47,7 @@ async def __handle(self, path, endpoint, token, tls_config, grpc_options):

@asynccontextmanager
async def session(self):
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory())
with Session(
uuid=self.uuid,
labels=self.labels,
Expand Down Expand Up @@ -76,7 +76,7 @@ async def listen(retries=5, backoff=3):
retries_left = retries
while True:
try:
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory())
async for request in controller.Listen(jumpstarter_pb2.ListenRequest(lease_name=lease_name)):
await listen_tx.send(request)
except Exception as e:
Expand Down Expand Up @@ -113,7 +113,7 @@ async def status(retries=5, backoff=3):
retries_left = retries
while True:
try:
controller = jumpstarter_pb2_grpc.ControllerServiceStub(self.channel_factory())
controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory())
async for status in controller.Status(jumpstarter_pb2.StatusRequest()):
await status_tx.send(status)
except Exception as e:
Expand Down