diff --git a/contributing/BACKENDS.md b/contributing/BACKENDS.md index d4f0392ac..972938560 100644 --- a/contributing/BACKENDS.md +++ b/contributing/BACKENDS.md @@ -97,8 +97,6 @@ Add any dependencies required by your cloud provider to `setup.py`. Create a sep Add a new enumeration member for your provider to `BackendType` (`src/dstack/_internal/core/models/backends/base.py`). Use the name of the provider. -Then create a database [migration](MIGRATIONS.md) to reflect the new enum member. - ##### 2.4.2. Create the backend directory Create a new directory under `src/dstack/_internal/core/backends` with the name of the backend type. diff --git a/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py b/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py new file mode 100644 index 000000000..4690d4a32 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py @@ -0,0 +1,171 @@ +"""Store BackendType as string + +Revision ID: bc8ca4a505c6 +Revises: 98d1b92988bc +Create Date: 2025-03-10 14:49:06.837118 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "bc8ca4a505c6" +down_revision = "98d1b92988bc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("backends", schema=None) as batch_op: + batch_op.alter_column( + "type", + existing_type=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + type_=sa.String(length=100), + existing_nullable=False, + ) + + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "backend", + existing_type=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + type_=sa.String(length=100), + existing_nullable=True, + ) + + sa.Enum( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ).drop(op.get_bind()) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ).create(op.get_bind()) + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "backend", + existing_type=sa.String(length=100), + type_=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + existing_nullable=True, + postgresql_using="backend::VARCHAR::backendtype", + ) + + with op.batch_alter_table("backends", schema=None) as batch_op: + batch_op.alter_column( + "type", + existing_type=sa.String(length=100), + type_=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + existing_nullable=False, + postgresql_using="type::VARCHAR::backendtype", + ) + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index ace2118e7..ed780a11a 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -1,3 +1,4 @@ +import enum import uuid from datetime import datetime from typing import Callable, List, Optional, Union @@ -112,7 +113,11 @@ def set_encrypt_decrypt( cls._encrypt_func = encrypt_func cls._decrypt_func = decrypt_func - def process_bind_param(self, value: Union[DecryptedString, str], dialect): + def process_bind_param( + self, value: Optional[Union[DecryptedString, str]], dialect + ) -> Optional[str]: + if value is None: + return None if isinstance(value, str): # Passing string allows binding an encrypted value directly # e.g. for comparisons @@ -130,6 +135,29 @@ def process_result_value(self, value: Optional[str], dialect) -> Optional[Decryp return DecryptedString(plaintext=None, decrypted=False, exc=e) +class EnumAsString(TypeDecorator): + """ + A custom type decorator that stores enums as strings in the DB. + """ + + impl = String + cache_ok = True + + def __init__(self, enum_class: type[enum.Enum], *args, **kwargs): + self.enum_class = enum_class + super().__init__(*args, **kwargs) + + def process_bind_param(self, value: Optional[enum.Enum], dialect) -> Optional[str]: + if value is None: + return None + return value.name + + def process_result_value(self, value: Optional[str], dialect) -> Optional[enum.Enum]: + if value is None: + return None + return self.enum_class[value] + + constraint_naming_convention = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", @@ -222,7 +250,7 @@ class BackendModel(BaseModel): ) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship() - type: Mapped[BackendType] = mapped_column(Enum(BackendType)) + type: Mapped[BackendType] = mapped_column(EnumAsString(BackendType, 100)) config: Mapped[str] = mapped_column(String(20000)) auth: Mapped[DecryptedString] = mapped_column(EncryptedString(20000)) @@ -533,7 +561,7 @@ class InstanceModel(BaseModel): last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) # backend - backend: Mapped[Optional[BackendType]] = mapped_column(Enum(BackendType)) + backend: Mapped[Optional[BackendType]] = mapped_column(EnumAsString(BackendType, 100)) backend_data: Mapped[Optional[str]] = mapped_column(Text) # offer diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index cf509cd1e..bf4bfc142 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -102,7 +102,7 @@ async def validate_and_create_backend_model( ) return BackendModel( project_id=project.id, - type=configurator.TYPE.value, + type=configurator.TYPE, config=backend_record.config, auth=DecryptedString(plaintext=backend_record.auth), )