Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions contributing/BACKENDS.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ Add any dependencies required by your cloud provider to `setup.py`. Create a sep
Add a new enumeration member for your provider to `BackendType` (`src/dstack/_internal/core/models/backends/base.py`).
Use the name of the provider.

Then create a database [migration](MIGRATIONS.md) to reflect the new enum member.

##### 2.4.2. Create the backend directory

Create a new directory under `src/dstack/_internal/core/backends` with the name of the backend type.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Store BackendType as string

Revision ID: bc8ca4a505c6
Revises: 98d1b92988bc
Create Date: 2025-03-10 14:49:06.837118

"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "bc8ca4a505c6"
down_revision = "98d1b92988bc"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("backends", schema=None) as batch_op:
batch_op.alter_column(
"type",
existing_type=postgresql.ENUM(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
),
type_=sa.String(length=100),
existing_nullable=False,
)

with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.alter_column(
"backend",
existing_type=postgresql.ENUM(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
),
type_=sa.String(length=100),
existing_nullable=True,
)

sa.Enum(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
).drop(op.get_bind())
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
).create(op.get_bind())
with op.batch_alter_table("instances", schema=None) as batch_op:
batch_op.alter_column(
"backend",
existing_type=sa.String(length=100),
type_=postgresql.ENUM(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
),
existing_nullable=True,
postgresql_using="backend::VARCHAR::backendtype",
)

with op.batch_alter_table("backends", schema=None) as batch_op:
batch_op.alter_column(
"type",
existing_type=sa.String(length=100),
type_=postgresql.ENUM(
"AWS",
"AZURE",
"CUDO",
"DATACRUNCH",
"DSTACK",
"GCP",
"KUBERNETES",
"LAMBDA",
"LOCAL",
"REMOTE",
"NEBIUS",
"OCI",
"RUNPOD",
"TENSORDOCK",
"VASTAI",
"VULTR",
name="backendtype",
),
existing_nullable=False,
postgresql_using="type::VARCHAR::backendtype",
)

# ### end Alembic commands ###
34 changes: 31 additions & 3 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import uuid
from datetime import datetime
from typing import Callable, List, Optional, Union
Expand Down Expand Up @@ -112,7 +113,11 @@ def set_encrypt_decrypt(
cls._encrypt_func = encrypt_func
cls._decrypt_func = decrypt_func

def process_bind_param(self, value: Union[DecryptedString, str], dialect):
def process_bind_param(
self, value: Optional[Union[DecryptedString, str]], dialect
) -> Optional[str]:
if value is None:
return None
if isinstance(value, str):
# Passing string allows binding an encrypted value directly
# e.g. for comparisons
Expand All @@ -130,6 +135,29 @@ def process_result_value(self, value: Optional[str], dialect) -> Optional[Decryp
return DecryptedString(plaintext=None, decrypted=False, exc=e)


class EnumAsString(TypeDecorator):
"""
A custom type decorator that stores enums as strings in the DB.
"""

impl = String
cache_ok = True

def __init__(self, enum_class: type[enum.Enum], *args, **kwargs):
self.enum_class = enum_class
super().__init__(*args, **kwargs)

def process_bind_param(self, value: Optional[enum.Enum], dialect) -> Optional[str]:
if value is None:
return None
return value.name

def process_result_value(self, value: Optional[str], dialect) -> Optional[enum.Enum]:
if value is None:
return None
return self.enum_class[value]


constraint_naming_convention = {
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
Expand Down Expand Up @@ -222,7 +250,7 @@ class BackendModel(BaseModel):
)
project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
project: Mapped["ProjectModel"] = relationship()
type: Mapped[BackendType] = mapped_column(Enum(BackendType))
type: Mapped[BackendType] = mapped_column(EnumAsString(BackendType, 100))

config: Mapped[str] = mapped_column(String(20000))
auth: Mapped[DecryptedString] = mapped_column(EncryptedString(20000))
Expand Down Expand Up @@ -533,7 +561,7 @@ class InstanceModel(BaseModel):
last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime)

# backend
backend: Mapped[Optional[BackendType]] = mapped_column(Enum(BackendType))
backend: Mapped[Optional[BackendType]] = mapped_column(EnumAsString(BackendType, 100))
backend_data: Mapped[Optional[str]] = mapped_column(Text)

# offer
Expand Down
2 changes: 1 addition & 1 deletion src/dstack/_internal/server/services/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def validate_and_create_backend_model(
)
return BackendModel(
project_id=project.id,
type=configurator.TYPE.value,
type=configurator.TYPE,
config=backend_record.config,
auth=DecryptedString(plaintext=backend_record.auth),
)
Expand Down