diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py
index d0ad4ac64..91842665b 100644
--- a/src/dstack/_internal/core/services/ssh/attach.py
+++ b/src/dstack/_internal/core/services/ssh/attach.py
@@ -12,6 +12,7 @@
 from dstack._internal.core.services.ssh.client import get_ssh_client_info
 from dstack._internal.core.services.ssh.ports import PortsLock
 from dstack._internal.core.services.ssh.tunnel import SSHTunnel, ports_to_forwarded_sockets
+from dstack._internal.utils.logging import get_logger
 from dstack._internal.utils.path import FilePath, PathLike
 from dstack._internal.utils.ssh import (
     default_ssh_config_path,
@@ -21,6 +22,8 @@
     update_ssh_config,
 )
 
+logger = get_logger(__name__)
+
 # ssh -L option format: [bind_address:]port:host:hostport
 _SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P\d+):localhost:(?P\d+)")
 
@@ -68,6 +71,7 @@ def __init__(
         local_backend: bool = False,
         bind_address: Optional[str] = None,
     ):
+        self._attached = False
         self._ports_lock = ports_lock
         self.ports = ports_lock.dict()
         self.run_name = run_name
@@ -209,6 +213,7 @@ def attach(self):
         for i in range(max_retries):
             try:
                 self.tunnel.open()
+                self._attached = True
                 atexit.register(self.detach)
                 break
             except SSHError:
@@ -219,9 +224,14 @@ def attach(self):
             raise SSHError("Can't connect to the remote host")
 
     def detach(self):
+        if not self._attached:
+            logger.debug("Not attached")
+            return
         self.tunnel.close()
         for host in self.hosts:
             update_ssh_config(self.ssh_config_path, host, {})
+        self._attached = False
+        logger.debug("Detached")
 
     def __enter__(self):
         self.attach()
diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py
index abdb2ba52..e4f7f276e 100644
--- a/src/dstack/_internal/core/services/ssh/tunnel.py
+++ b/src/dstack/_internal/core/services/ssh/tunnel.py
@@ -204,6 +204,11 @@ async def aopen(self) -> None:
             raise get_ssh_error(stderr)
 
     def close(self) -> None:
+        if not os.path.exists(self.control_sock_path):
+            logger.debug(
+                "Control socket does not exist, it seems that ssh process has already exited"
+            )
+            return
         proc = subprocess.run(
             self.close_command(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT
         )
@@ -215,6 +220,11 @@ def close(self) -> None:
             )
 
     async def aclose(self) -> None:
+        if not os.path.exists(self.control_sock_path):
+            logger.debug(
+                "Control socket does not exist, it seems that ssh process has already exited"
+            )
+            return
         proc = await asyncio.create_subprocess_exec(
             *self.close_command(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT
         )