Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,11 @@ class UserTokenCreds(CoreModel):
class UserWithCreds(User):
creds: UserTokenCreds
ssh_private_key: Optional[str] = None


class UserHookConfig(CoreModel):
"""
This class can be inherited to extend the user creation configuration passed to the hooks.
"""

pass
8 changes: 6 additions & 2 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dstack._internal.core.models.users import (
GlobalRole,
User,
UserHookConfig,
UserPermissions,
UserTokenCreds,
UserWithCreds,
Expand Down Expand Up @@ -79,6 +80,7 @@ async def create_user(
email: Optional[str] = None,
active: bool = True,
token: Optional[str] = None,
config: Optional[UserHookConfig] = None,
) -> UserModel:
validate_username(username)
user_model = await get_user_model_by_name(session=session, username=username, ignore_case=True)
Expand All @@ -101,7 +103,7 @@ async def create_user(
session.add(user)
await session.commit()
for func in _CREATE_USER_HOOKS:
await func(session, user)
await func(session, user, config)
return user


Expand Down Expand Up @@ -267,7 +269,9 @@ def is_valid_username(username: str) -> bool:
_CREATE_USER_HOOKS = []


def register_create_user_hook(func: Callable[[AsyncSession, UserModel], Awaitable[None]]):
def register_create_user_hook(
func: Callable[[AsyncSession, UserModel, Optional[UserHookConfig]], Awaitable[None]],
):
_CREATE_USER_HOOKS.append(func)


Expand Down