From 5de229821dbc224b2944080c0593055a330889d7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 15:10:02 +0500 Subject: [PATCH 1/3] Store BackendType as EnumAsString --- ...c8ca4a505c6_store_backendtype_as_string.py | 171 ++++++++++++++++++ src/dstack/_internal/server/models.py | 34 +++- .../server/services/backends/__init__.py | 2 +- 3 files changed, 203 insertions(+), 4 deletions(-) create mode 100644 src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py diff --git a/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py b/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py new file mode 100644 index 000000000..4690d4a32 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/bc8ca4a505c6_store_backendtype_as_string.py @@ -0,0 +1,171 @@ +"""Store BackendType as string + +Revision ID: bc8ca4a505c6 +Revises: 98d1b92988bc +Create Date: 2025-03-10 14:49:06.837118 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "bc8ca4a505c6" +down_revision = "98d1b92988bc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("backends", schema=None) as batch_op: + batch_op.alter_column( + "type", + existing_type=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + type_=sa.String(length=100), + existing_nullable=False, + ) + + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "backend", + existing_type=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + type_=sa.String(length=100), + existing_nullable=True, + ) + + sa.Enum( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ).drop(op.get_bind()) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ).create(op.get_bind()) + with op.batch_alter_table("instances", schema=None) as batch_op: + batch_op.alter_column( + "backend", + existing_type=sa.String(length=100), + type_=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + existing_nullable=True, + postgresql_using="backend::VARCHAR::backendtype", + ) + + with op.batch_alter_table("backends", schema=None) as batch_op: + batch_op.alter_column( + "type", + existing_type=sa.String(length=100), + type_=postgresql.ENUM( + "AWS", + "AZURE", + "CUDO", + "DATACRUNCH", + "DSTACK", + "GCP", + "KUBERNETES", + "LAMBDA", + "LOCAL", + "REMOTE", + "NEBIUS", + "OCI", + "RUNPOD", + "TENSORDOCK", + "VASTAI", + "VULTR", + name="backendtype", + ), + existing_nullable=False, + postgresql_using="type::VARCHAR::backendtype", + ) + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index ace2118e7..2d05a8519 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -1,3 +1,4 @@ +import enum import uuid from datetime import datetime from typing import Callable, List, Optional, Union @@ -112,7 +113,11 @@ def set_encrypt_decrypt( cls._encrypt_func = encrypt_func cls._decrypt_func = decrypt_func - def process_bind_param(self, value: Union[DecryptedString, str], dialect): + def process_bind_param( + self, value: Optional[Union[DecryptedString, str]], dialect + ) -> Optional[str]: + if value is None: + return None if isinstance(value, str): # Passing string allows binding an encrypted value directly # e.g. for comparisons @@ -130,6 +135,29 @@ def process_result_value(self, value: Optional[str], dialect) -> Optional[Decryp return DecryptedString(plaintext=None, decrypted=False, exc=e) +class EnumAsString(TypeDecorator): + """ + A custom type decorator that stores enums as strings in the DB. + """ + + impl = String + cache_ok = True + + def __init__(self, enum_class: type[enum.Enum], *args, **kwargs): + self.enum_class = enum_class + super().__init__(*args, **kwargs) + + def process_bind_param(self, value: Optional[enum.Enum], dialect) -> Optional[str]: + if value is None: + return None + return value.name + + def process_result_value(self, value: str, dialect) -> Optional[enum.Enum]: + if value is None: + return None + return self.enum_class[value] + + constraint_naming_convention = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", @@ -222,7 +250,7 @@ class BackendModel(BaseModel): ) project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship() - type: Mapped[BackendType] = mapped_column(Enum(BackendType)) + type: Mapped[BackendType] = mapped_column(EnumAsString(BackendType, 100)) config: Mapped[str] = mapped_column(String(20000)) auth: Mapped[DecryptedString] = mapped_column(EncryptedString(20000)) @@ -533,7 +561,7 @@ class InstanceModel(BaseModel): last_termination_retry_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) # backend - backend: Mapped[Optional[BackendType]] = mapped_column(Enum(BackendType)) + backend: Mapped[Optional[BackendType]] = mapped_column(EnumAsString(BackendType, 100)) backend_data: Mapped[Optional[str]] = mapped_column(Text) # offer diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index cf509cd1e..bf4bfc142 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -102,7 +102,7 @@ async def validate_and_create_backend_model( ) return BackendModel( project_id=project.id, - type=configurator.TYPE.value, + type=configurator.TYPE, config=backend_record.config, auth=DecryptedString(plaintext=backend_record.auth), ) From f05c3c6c0d4c717627a764e9031429645576c7b7 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 15:40:21 +0500 Subject: [PATCH 2/3] Update backend guide --- contributing/BACKENDS.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/contributing/BACKENDS.md b/contributing/BACKENDS.md index d4f0392ac..972938560 100644 --- a/contributing/BACKENDS.md +++ b/contributing/BACKENDS.md @@ -97,8 +97,6 @@ Add any dependencies required by your cloud provider to `setup.py`. Create a sep Add a new enumeration member for your provider to `BackendType` (`src/dstack/_internal/core/models/backends/base.py`). Use the name of the provider. -Then create a database [migration](MIGRATIONS.md) to reflect the new enum member. - ##### 2.4.2. Create the backend directory Create a new directory under `src/dstack/_internal/core/backends` with the name of the backend type. From 25325c546468d12b116ac0109f8eaba971717d22 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 10 Mar 2025 15:48:22 +0500 Subject: [PATCH 3/3] Fix type annotation --- src/dstack/_internal/server/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 2d05a8519..ed780a11a 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -152,7 +152,7 @@ def process_bind_param(self, value: Optional[enum.Enum], dialect) -> Optional[st return None return value.name - def process_result_value(self, value: str, dialect) -> Optional[enum.Enum]: + def process_result_value(self, value: Optional[str], dialect) -> Optional[enum.Enum]: if value is None: return None return self.enum_class[value]