Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/dstack/_internal/core/services/ssh/key_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from dstack._internal.core.models.users import UserWithCreds

if TYPE_CHECKING:
from dstack.api.server import APIClient
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why put UserSSHKeyManager in core instead of api?


KEY_REFRESH_RATE = timedelta(minutes=10) # redownload the key periodically in case it was rotated


@dataclass
class UserSSHKey:
    """A user's SSH key pair: the public key text plus the path of the private key file."""

    # Contents of the public key.
    public_key: str
    # Filesystem location of the matching private key.
    private_key_path: Path


class UserSSHKeyManager:
    """
    Keeps a local copy of the user's SSH key pair obtained from the server.

    The private key is stored in `ssh_keys_dir` in a file named after the API
    client's token hash; the public key is stored next to it with the `.pub`
    suffix. The pair is downloaded again whenever either file is missing or
    the private key file is older than `KEY_REFRESH_RATE`.
    """

    def __init__(self, api_client: "APIClient", ssh_keys_dir: Path) -> None:
        self._api_client = api_client
        self._key_path = ssh_keys_dir / api_client.get_token_hash()
        self._pub_key_path = self._key_path.with_suffix(".pub")

    def get_user_key(self) -> Optional["UserSSHKey"]:
        """
        Return the up-to-date user key, or None if the user has no key (if created before 0.19.33)
        """
        if self._needs_refresh() and not self._download_user_key():
            return None
        return UserSSHKey(
            public_key=self._pub_key_path.read_text(), private_key_path=self._key_path
        )

    def _needs_refresh(self) -> bool:
        """Return True if the local key pair is missing or older than `KEY_REFRESH_RATE`."""
        if not self._pub_key_path.exists():
            return True
        try:
            # stat() directly instead of exists() + stat(): the file may
            # disappear between the two calls.
            modified_at = self._key_path.stat().st_mtime
        except FileNotFoundError:
            return True
        return datetime.now() - datetime.fromtimestamp(modified_at) > KEY_REFRESH_RATE

    def _download_user_key(self) -> bool:
        """
        Fetch the current user's key pair from the server and store it locally.

        Returns:
            True if the key pair was written, False if the server response
            carries no SSH key pair for the user.
        """
        user = self._api_client.users.get_my_user()
        if not (isinstance(user, UserWithCreds) and user.ssh_public_key and user.ssh_private_key):
            return False
        self._write_key_files(user.ssh_private_key, user.ssh_public_key)
        return True

    def _write_key_files(self, private_key: str, public_key: str) -> None:
        """Write the key pair to disk, creating the keys directory if needed."""
        # Create the keys directory on first use so that writing does not fail
        # with FileNotFoundError on a fresh client.
        self._key_path.parent.mkdir(parents=True, exist_ok=True)

        def key_opener(path, flags):
            # The private key file is created readable/writable by the owner only.
            # The mode is applied when the file is created.
            return os.open(path, flags, 0o600)

        with open(self._key_path, "w", opener=key_opener) as f:
            f.write(private_key)
        with open(self._pub_key_path, "w") as f:
            f.write(public_key)
3 changes: 3 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ class UserModel(BaseModel):
# deactivated users cannot access API
active: Mapped[bool] = mapped_column(Boolean, default=True)

# SSH keys can be null for users created before 0.19.33.
# Keys for those users are being gradually generated on /get_my_user calls.
# TODO: make keys required in a future version.
ssh_private_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
ssh_public_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

Expand Down
7 changes: 7 additions & 0 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,15 @@ async def list_users(

@router.post("/get_my_user", response_model=UserWithCreds)
async def get_my_user(
    session: AsyncSession = Depends(get_session),
    user: UserModel = Depends(Authenticated()),
):
    # Returns the authenticated user together with their credentials.
    # Users created before 0.19.33 may have no SSH key pair stored yet;
    # for them a key pair is generated here before responding.
    has_ssh_keys = user.ssh_private_key is not None and user.ssh_public_key is not None
    if not has_ssh_keys:
        refreshed_user = await users.refresh_ssh_key(
            session=session, user=user, username=user.name
        )
        if refreshed_user is None:
            raise ResourceNotExistsError()
        user = refreshed_user
    user_with_creds = users.user_model_to_user_with_creds(user)
    return CustomORJSONResponse(user_with_creds)


Expand Down
6 changes: 3 additions & 3 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from dstack._internal.server.models import DecryptedString, UserModel
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.utils.routers import error_forbidden
from dstack._internal.utils import crypto
from dstack._internal.utils.common import run_async
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -88,7 +88,7 @@ async def create_user(
raise ResourceExistsError()
if token is None:
token = str(uuid.uuid4())
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
user = UserModel(
id=uuid.uuid4(),
name=username,
Expand Down Expand Up @@ -135,7 +135,7 @@ async def refresh_ssh_key(
logger.debug("Refreshing SSH key for user [code]%s[/code]", username)
if user.global_role != GlobalRole.ADMIN and user.name != username:
raise error_forbidden()
private_bytes, public_bytes = await run_async(generate_rsa_key_pair_bytes, username)
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
await session.execute(
update(UserModel)
.where(UserModel.name == username)
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ async def create_user(
global_role: GlobalRole = GlobalRole.ADMIN,
token: Optional[str] = None,
email: Optional[str] = None,
ssh_public_key: Optional[str] = None,
ssh_private_key: Optional[str] = None,
active: bool = True,
) -> UserModel:
if token is None:
Expand All @@ -137,6 +139,8 @@ async def create_user(
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
email=email,
ssh_public_key=ssh_public_key,
ssh_private_key=ssh_private_key,
active=active,
)
session.add(user)
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/api/_public/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dstack._internal.utils.path import PathLike as PathLike
from dstack.api._public.backends import BackendCollection
from dstack.api._public.repos import RepoCollection
from dstack.api._public.runs import RunCollection, warn
from dstack.api._public.runs import RunCollection
from dstack.api.server import APIClient

logger = get_logger(__name__)
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
self._backends = BackendCollection(api_client, project_name)
self._runs = RunCollection(api_client, project_name, self)
if ssh_identity_file is not None:
warn(
logger.warning(
"[code]ssh_identity_file[/code] in [code]Client[/code] is deprecated and ignored; will be removed"
" since 0.19.40"
)
Expand Down
75 changes: 36 additions & 39 deletions src/dstack/api/_public/runs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import base64
import hashlib
import os
import queue
import tempfile
import threading
Expand All @@ -17,7 +15,6 @@
from websocket import WebSocketApp

import dstack.api as api
from dstack._internal.cli.utils.common import warn
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.errors import ClientError, ConfigurationError, ResourceNotExistsError
from dstack._internal.core.models.backends.base import BackendType
Expand Down Expand Up @@ -48,10 +45,10 @@
get_service_port,
)
from dstack._internal.core.models.runs import Run as RunModel
from dstack._internal.core.models.users import UserWithCreds
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.core.services.logs import URLReplacer
from dstack._internal.core.services.ssh.attach import SSHAttach
from dstack._internal.core.services.ssh.key_manager import UserSSHKeyManager
from dstack._internal.core.services.ssh.ports import PortsLock
from dstack._internal.server.schemas.logs import PollLogsRequest
from dstack._internal.utils.common import get_or_error, make_proxy_url
Expand Down Expand Up @@ -88,7 +85,7 @@ def __init__(
self._ports_lock: Optional[PortsLock] = ports_lock
self._ssh_attach: Optional[SSHAttach] = None
if ssh_identity_file is not None:
warn(
logger.warning(
"[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed"
" since 0.19.40"
)
Expand Down Expand Up @@ -281,31 +278,20 @@ def attach(
dstack.api.PortUsedError: If ports are in use or the run is attached by another process.
"""
if not ssh_identity_file:
user = self._api_client.users.get_my_user()
run_ssh_key_pub = self._run.run_spec.ssh_key_pub
config_manager = ConfigManager()
if isinstance(user, UserWithCreds) and user.ssh_public_key == run_ssh_key_pub:
token_hash = hashlib.sha1(user.creds.token.encode()).hexdigest()[:8]
config_manager.dstack_ssh_dir.mkdir(parents=True, exist_ok=True)
ssh_identity_file = config_manager.dstack_ssh_dir / token_hash

def key_opener(path, flags):
return os.open(path, flags, 0o600)

with open(ssh_identity_file, "wb", opener=key_opener) as f:
assert user.ssh_private_key
f.write(user.ssh_private_key.encode())
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
if (
user_key := key_manager.get_user_key()
) and user_key.public_key == self._run.run_spec.ssh_key_pub:
ssh_identity_file = user_key.private_key_path
else:
if config_manager.dstack_key_path.exists():
# TODO: Remove since 0.19.40
warn(
f"Using legacy [code]{config_manager.dstack_key_path}[/code]."
" Future versions will use the user SSH key from the server.",
)
logger.debug(f"Using legacy [code]{config_manager.dstack_key_path}[/code].")
ssh_identity_file = config_manager.dstack_key_path
else:
raise ConfigurationError(
f"User SSH key doen't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist"
f"User SSH key doesn't match; default SSH key ({config_manager.dstack_key_path}) doesn't exist"
)
ssh_identity_file = str(ssh_identity_file)

Expand Down Expand Up @@ -504,15 +490,19 @@ def get_run_plan(
ssh_key_pub = Path(ssh_identity_file).with_suffix(".pub").read_text()
else:
config_manager = ConfigManager()
if not config_manager.dstack_key_path.exists():
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
warn(
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
" Future versions will use the user SSH key from the server.",
)
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
# TODO: Uncomment after 0.19.40
# ssh_key_pub = None
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
if key_manager.get_user_key():
ssh_key_pub = None # using the server-managed user key
else:
if not config_manager.dstack_key_path.exists():
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
logger.warning(
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
" You will only be able to attach to the run from this client."
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
" automatically replicated to all clients.",
)
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
run_spec = RunSpec(
run_name=configuration.name,
repo_id=repo.repo_id,
Expand Down Expand Up @@ -760,12 +750,19 @@ def get_plan(
idle_duration=idle_duration, # type: ignore[assignment]
)
config_manager = ConfigManager()
if not config_manager.dstack_key_path.exists():
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
warn(
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
" Future versions will use the user SSH key from the server.",
)
key_manager = UserSSHKeyManager(self._api_client, config_manager.dstack_ssh_dir)
if key_manager.get_user_key():
ssh_key_pub = None # using the server-managed user key
else:
if not config_manager.dstack_key_path.exists():
generate_rsa_key_pair(private_key_path=config_manager.dstack_key_path)
logger.warning(
f"Using legacy [code]{config_manager.dstack_key_path.with_suffix('.pub')}[/code]."
" You will only be able to attach to the run from this client."
" Update the [code]dstack[/] server to [code]0.19.34[/]+ to switch to user keys"
" automatically replicated to all clients.",
)
ssh_key_pub = config_manager.dstack_key_path.with_suffix(".pub").read_text()
run_spec = RunSpec(
run_name=run_name,
repo_id=repo.repo_id,
Expand All @@ -775,7 +772,7 @@ def get_plan(
configuration_path=configuration_path,
configuration=configuration,
profile=profile,
ssh_key_pub=config_manager.dstack_key_path.with_suffix(".pub").read_text(),
ssh_key_pub=ssh_key_pub,
)
logger.debug("Getting run plan")
run_plan = self._api_client.runs.get_plan(self._project, run_spec)
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/api/server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import os
import pprint
import time
Expand Down Expand Up @@ -121,6 +122,9 @@ def volumes(self) -> VolumesAPIClient:
def files(self) -> FilesAPIClient:
return FilesAPIClient(self._request, self._logger)

def get_token_hash(self) -> str:
    """Return the first 8 hex characters of the SHA-1 digest of the client's token."""
    full_digest = hashlib.sha1(self._token.encode()).hexdigest()
    return full_digest[:8]

def _request(
self,
path: str,
Expand Down
Loading