From f669f4ef90ac77c924d624b89456e187eaaf1cba Mon Sep 17 00:00:00 2001 From: Daniel Ng Date: Mon, 11 May 2026 13:41:44 -0700 Subject: [PATCH] Internal Changes PiperOrigin-RevId: 913861453 --- .../experimental/tiering_service/db_schema.py | 411 ++++++++++++ .../tiering_service/db_schema_test.py | 599 ++++++++++++++++++ checkpoint/pyproject.toml | 4 + 3 files changed, 1014 insertions(+) create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py create mode 100644 checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py new file mode 100644 index 000000000..4d6f717cf --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py @@ -0,0 +1,411 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint Tiering Service (CTS) database schema definition. + +Provides SQLAlchemy models for tracking assets, tier paths, and job queues. 
"""

import enum
import itertools
import uuid

import sqlalchemy.orm

# Declarative base shared by every CTS table model defined in this module.
Base = sqlalchemy.orm.declarative_base()


class AssetState(enum.IntEnum):
  """The lifecycle state of an asset tracked by CTS."""

  ASSET_STATE_UNSPECIFIED = 0
  ASSET_STATE_ACTIVE_WRITE = 1
  ASSET_STATE_STORED = 2
  ASSET_STATE_DELETED = 3
  ASSET_STATE_INCOMPLETE = 4


class BackendType(enum.IntEnum):
  """The storage backend type for a tier path."""

  BACKEND_TYPE_UNSPECIFIED = 0
  BACKEND_TYPE_LUSTRE = 1
  BACKEND_TYPE_GCS = 2


class JobStatus(enum.IntEnum):
  """The execution status of an asset job."""

  JOB_STATUS_UNSPECIFIED = 0
  JOB_STATUS_QUEUED = 1
  JOB_STATUS_PROCESSING = 2
  JOB_STATUS_COMPLETED = 3
  JOB_STATUS_FAILED = 4


class RequestType(enum.IntEnum):
  """The operation type requested for an asset job."""

  REQUEST_TYPE_UNSPECIFIED = 0
  REQUEST_TYPE_COPY = 1
  REQUEST_TYPE_DELETE_FROM_INSTANCE = 2
  REQUEST_TYPE_DELETE_FROM_ALL_TIERS = 3


class Asset(Base):
  """A CTS asset representing a complete checkpoint.

  Acts as the primary entity holding assets' metadata and latest storage state.
  Asset paths are expected to be unique within the active/stored states.
  Duplicates are allowed for deleted or incomplete states.

  Attributes:
    asset_uuid: A unique identifier for the asset (Primary Key).
    path: The user-defined path identifying the asset.
    user: The user who owns or created the asset.
    tags: Optional JSON field for storing arbitrary tags.
    state: The current lifecycle state of the asset, an AssetState enum.
    created_at: Timestamp when the asset record was created.
    finalized_at: Timestamp when the asset was marked as finalized.
    deleted_at: Timestamp when the asset was marked as deleted.
    updated_at: Timestamp of the last update to the asset record.
    tier_paths: A relationship to the TierPath objects associated with this
      asset.
    jobs: A relationship to the AssetJob objects associated with this asset.
  """

  __tablename__ = "assets"

  # Default is a callable so every new row gets its own fresh UUID4 string.
  asset_uuid = sqlalchemy.Column(
      sqlalchemy.String,
      primary_key=True,
      default=lambda: str(uuid.uuid4()),
  )
  path = sqlalchemy.Column(sqlalchemy.String, index=True, nullable=False)
  user = sqlalchemy.Column(sqlalchemy.String, nullable=False)
  tags = sqlalchemy.Column(sqlalchemy.JSON, nullable=True)
  state = sqlalchemy.Column(
      sqlalchemy.Enum(AssetState), default=AssetState.ASSET_STATE_UNSPECIFIED
  )
  created_at = sqlalchemy.Column(
      sqlalchemy.DateTime,
      server_default=sqlalchemy.sql.func.now(),
      nullable=False,
  )
  finalized_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  deleted_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  updated_at = sqlalchemy.Column(
      sqlalchemy.DateTime,
      server_default=sqlalchemy.sql.func.now(),
      onupdate=sqlalchemy.sql.func.now(),
      nullable=False,
  )

  # Children are removed with the asset ("delete-orphan" mirrors the FK
  # ON DELETE CASCADE declared on the child tables).
  tier_paths = sqlalchemy.orm.relationship(
      "TierPath", back_populates="asset", cascade="all, delete-orphan"
  )
  jobs = sqlalchemy.orm.relationship(
      "AssetJob", back_populates="asset", cascade="all, delete-orphan"
  )

  __table_args__ = (
      # Enforce path only for live assets (ACTIVE_WRITE, STORED).
      # Duplicates are allowed for DELETED or INCOMPLETE states.
      # The Enum column persists member *names*, hence `.name` in the filter.
      sqlalchemy.Index(
          "idx_assets_unique_path_active_stored",
          "path",
          unique=True,
          sqlite_where=sqlalchemy.column("state").in_([
              AssetState.ASSET_STATE_ACTIVE_WRITE.name,
              AssetState.ASSET_STATE_STORED.name,
          ]),
          postgresql_where=sqlalchemy.column("state").in_([
              AssetState.ASSET_STATE_ACTIVE_WRITE.name,
              AssetState.ASSET_STATE_STORED.name,
          ]),
      ),
  )

  def __repr__(self):
    return (
        f"Asset(asset_uuid={self.asset_uuid!r},"
        f" path={self.path!r}, state={self.state.name!r},"
        f" user={self.user!r})"
    )


class StorageBackend(Base):
  """A system-wide available storage instance.

  The table should be populated once by a single CTS server during
  initialization in one transaction.
Afterwards, the content is only validated + against the server configuration. + + Attributes: + id: Primary key for the storage backend entry. + level: An integer representing the tiering level. + zone: The zone where the storage backend resides. + region: The region where the storage backend resides. + multi_regions: A list of regions forming a multi-region deployment. + backend_type: The type of storage (e.g., Lustre, GCS). + tier_paths: Relationship to the TierPath objects utilizing this backend. + """ + + __tablename__ = "storage_backends" + + id = sqlalchemy.Column( + sqlalchemy.Integer, primary_key=True, autoincrement=True + ) + level = sqlalchemy.Column(sqlalchemy.Integer, nullable=False) + zone = sqlalchemy.Column(sqlalchemy.String, nullable=True) + region = sqlalchemy.Column(sqlalchemy.String, nullable=True) + multi_regions = sqlalchemy.Column(sqlalchemy.JSON, nullable=True) + backend_type = sqlalchemy.Column( + sqlalchemy.Enum(BackendType), default=BackendType.BACKEND_TYPE_UNSPECIFIED + ) + + tier_paths = sqlalchemy.orm.relationship( + "TierPath", back_populates="storage_backend", cascade="all, delete-orphan" + ) + + __table_args__ = ( + # Enforce that only one of zone, region, or multi_regions is set. + sqlalchemy.CheckConstraint( + "(CASE WHEN zone IS NOT NULL THEN 1 ELSE 0 END + " + "CASE WHEN region IS NOT NULL THEN 1 ELSE 0 END + " + "CASE WHEN multi_regions IS NOT NULL THEN 1 ELSE 0 END) = 1", + name="check_mutually_exclusive_locations", + ), + ) + + def __repr__(self): + if self.zone: + location = f"zone={self.zone!r}" + elif self.region: + location = f"region={self.region!r}" + elif self.multi_regions: + location = f"multi_regions={self.multi_regions!r}" + else: + location = "None" + return ( + f"StorageBackend(id={self.id}, level={self.level}, " + f"backend_type={self.backend_type.name!r}, {location})" + ) + + def validate_pre_commit(self) -> None: + """Validates StorageBackend constraints before a commit. + + This validates for: + 1. 
All StorageBackend entries at the same `level` must share the same + `backend_type`. + 2. Within the same `level`, each location identifier (`zone`, `region`, or + `multi_regions`) must be unique across all StorageBackend entries. + + The validation is performed against other StorageBackend objects currently + loaded or newly added within the same SQLAlchemy session. + + Raises: + ValueError: If any of the validation constraints are violated. + """ + session = sqlalchemy.orm.object_session(self) + if session is None: + # No session, so no need to validate. + raise ValueError("No session found") + + session_backends = [ + obj + for obj in set( + itertools.chain(session.new, session.identity_map.values()) + ) + if isinstance(obj, StorageBackend) and obj.level == self.level + ] + types = {b.backend_type for b in session_backends} + if len(types) > 1: + raise ValueError( + f"StorageBackend at level {self.level} must have the same" + f" backend_type, but found conflicting types: {types}" + ) + + seen_zones = set() + seen_regions = set() + seen_multis = set() + for b in session_backends: + if b.zone: + if b.zone in seen_zones: + raise ValueError(f"Duplicate zone[{b.zone}]") + seen_zones.add(b.zone) + if b.region: + if b.region in seen_regions: + raise ValueError(f"Duplicate region[{b.region}]") + seen_regions.add(b.region) + if b.multi_regions: + sorted_list = ( + sorted(b.multi_regions) + if isinstance(b.multi_regions, list) + else b.multi_regions + ) + mr_val = ( + tuple(sorted_list) if isinstance(sorted_list, list) else sorted_list + ) + if mr_val in seen_multis: + raise ValueError(f"Duplicate multi_regions[{mr_val}]") + seen_multis.add(mr_val) + + +@sqlalchemy.event.listens_for(StorageBackend, "before_insert") +@sqlalchemy.event.listens_for(StorageBackend, "before_update") +def _validate_storage_backend_before_flush( + mapper: sqlalchemy.orm.Mapper, + connection: sqlalchemy.engine.Connection, + target: StorageBackend, +) -> None: + del mapper, connection + 
  target.validate_pre_commit()


class TierPath(Base):
  """A storage location for an asset.

  Asset can be stored in multiple locations across different zones and regions,
  and different storage tiers.

  Attributes:
    id: Primary key for the tier path.
    asset_uuid: Foreign key linking to the `Asset`.
    storage_backend_id: Foreign key linking to the `StorageBackend`.
    path: The concrete storage path (e.g., GCS URI, Lustre path).
    ready_at: Timestamp when the asset became available at this tier path.
    expires_at: Timestamp when the asset is scheduled to expire from this tier
      path.
    asset: SQLAlchemy relationship to the `Asset` object.
    storage_backend: SQLAlchemy relationship to the `StorageBackend` object.
  """

  __tablename__ = "tier_paths"

  id = sqlalchemy.Column(
      sqlalchemy.Integer, primary_key=True, autoincrement=True
  )
  # Rows are removed when the parent asset row is deleted (ON DELETE CASCADE).
  asset_uuid = sqlalchemy.Column(
      sqlalchemy.String,
      sqlalchemy.ForeignKey("assets.asset_uuid", ondelete="CASCADE"),
      nullable=False,
  )
  storage_backend_id = sqlalchemy.Column(
      sqlalchemy.Integer,
      sqlalchemy.ForeignKey("storage_backends.id", ondelete="CASCADE"),
      nullable=False,
  )
  path = sqlalchemy.Column(sqlalchemy.String, nullable=False)
  ready_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)
  expires_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)

  asset = sqlalchemy.orm.relationship("Asset", back_populates="tier_paths")
  storage_backend = sqlalchemy.orm.relationship(
      "StorageBackend", back_populates="tier_paths"
  )

  __table_args__ = (
      # An asset can have at most one TierPath for a given storage backend.
      sqlalchemy.UniqueConstraint(
          "asset_uuid",
          "storage_backend_id",
          name="uq_tier_path_asset_backend",
      ),
  )

  def __repr__(self):
    return (
        f"TierPath(id={self.id}, asset_uuid='{self.asset_uuid}',"
        f" storage_backend_id={self.storage_backend_id}, path='{self.path}',"
        f" ready_at={self.ready_at}, expires_at={self.expires_at})"
    )


class AssetJob(Base):
  """A queued operation for an Asset.

  This table ensures that multiple servers can try to queue and update job
  status without race conditions.

  Attributes:
    id: Primary key for the job.
    asset_uuid: Foreign key to the target Asset.
    request_type: The requested operation type, an instance of RequestType.
    status: Current execution status of the job, an instance of JobStatus.
    target_tier_path_id: Foreign key to the targeted TierPath for operations
      such as COPY or DELETE_FROM_INSTANCE.
    created_at: Timestamp when the job was created.
    completed_at: Timestamp when the job was completed.
    asset: Relationship to the associated Asset.
    target_tier_path: Relationship to the targeted TierPath.
  """

  __tablename__ = "asset_jobs"

  id = sqlalchemy.Column(
      sqlalchemy.Integer, primary_key=True, autoincrement=True
  )
  asset_uuid = sqlalchemy.Column(
      sqlalchemy.String,
      sqlalchemy.ForeignKey("assets.asset_uuid", ondelete="CASCADE"),
      nullable=False,
  )
  request_type = sqlalchemy.Column(
      sqlalchemy.Enum(RequestType),
      default=RequestType.REQUEST_TYPE_UNSPECIFIED,
      nullable=False,
  )
  # Indexed to support lookups filtered by job status.
  status = sqlalchemy.Column(
      sqlalchemy.Enum(JobStatus),
      default=JobStatus.JOB_STATUS_QUEUED,
      index=True,
  )
  # Target tier path for COPY and DELETE_FROM_INSTANCE requests
  target_tier_path_id = sqlalchemy.Column(
      sqlalchemy.Integer,
      sqlalchemy.ForeignKey("tier_paths.id", ondelete="CASCADE"),
      nullable=True,
  )

  created_at = sqlalchemy.Column(
      sqlalchemy.DateTime,
      server_default=sqlalchemy.sql.func.now(),
      nullable=False,
  )
  completed_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=True)

  asset = sqlalchemy.orm.relationship("Asset", back_populates="jobs")
  target_tier_path = sqlalchemy.orm.relationship("TierPath")

  __table_args__ = (
      # target_tier_path is required in COPY and DELETE_FROM_INSTANCE requests.
      # The SQL below is written against the stored enum *names*, matching
      # the sqlalchemy.Enum(RequestType) column persistence format.
      sqlalchemy.CheckConstraint(
          """
          (request_type IN ('REQUEST_TYPE_COPY', 'REQUEST_TYPE_DELETE_FROM_INSTANCE') AND target_tier_path_id IS NOT NULL)
          OR
          (request_type IN ('REQUEST_TYPE_DELETE_FROM_ALL_TIERS', 'REQUEST_TYPE_UNSPECIFIED') AND target_tier_path_id IS NULL)
          """,
          name="check_asset_job_valid_payload",
      ),
  )

  def __repr__(self):
    return (
        f"AssetJob(id={self.id}, asset_uuid='{self.asset_uuid}',"
        f" request_type='{self.request_type.name}',"
        f" status='{self.status.name}',"
        f" target_tier_path_id={self.target_tier_path_id},"
        f" created_at={self.created_at}, completed_at={self.completed_at})"
    )
diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py
new file mode 100644
index 000000000..2942e31f5
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py
@@ -0,0 +1,599 @@
# Copyright 2026 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import datetime
import multiprocessing
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import aiosqlite  # pylint: disable=unused-import
import greenlet  # pylint: disable=unused-import
from orbax.checkpoint.experimental.tiering_service import db_schema
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker


class DbSchemaTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
  """Single-process tests for the CTS SQLAlchemy schema."""

  async def asyncSetUp(self) -> None:
    """Creates a fresh on-disk SQLite database (WAL mode) for each test."""
    await super().asyncSetUp()
    tmp_file = self.create_tempfile()
    self.db_path = tmp_file.full_path

    self.engine = create_async_engine(
        f"sqlite+aiosqlite:///{self.db_path}",
        echo=True,
    )

    async with self.engine.begin() as conn:
      await conn.exec_driver_sql("PRAGMA journal_mode=WAL")
      await conn.run_sync(db_schema.Base.metadata.create_all)

    self.session_maker = sessionmaker(
        self.engine, expire_on_commit=False, class_=AsyncSession
    )

  async def asyncTearDown(self) -> None:
    """Drops all tables and disposes the engine."""
    async with self.engine.begin() as conn:
      await conn.run_sync(db_schema.Base.metadata.drop_all)
    await self.engine.dispose()
    await super().asyncTearDown()

  async def test_create_asset(self) -> None:
    """An inserted Asset gets a generated UUID and round-trips its fields."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          path="/experiment/step1",
          user="testuser",
          tags=["tag1", "tag2"],
          state=db_schema.AssetState.ASSET_STATE_ACTIVE_WRITE,
          created_at=datetime.datetime(2026, 1, 1, 10, 0, 0),
      )
      session.add(asset)
      await session.commit()

      generated_uuid = asset.asset_uuid
      self.assertIsNotNone(generated_uuid)

      result = await session.execute(
          select(db_schema.Asset).filter_by(asset_uuid=generated_uuid)
      )
      fetched = result.scalars().first()
      self.assertIsNotNone(fetched)
      self.assertEqual(fetched.path, "/experiment/step1")
      self.assertEqual(fetched.tags, ["tag1", "tag2"])
      self.assertEqual(
          fetched.state, db_schema.AssetState.ASSET_STATE_ACTIVE_WRITE
      )

  async def test_update_asset_state(self) -> None:
    """An Asset's state can be transitioned and persisted."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-456",
          path="/experiment/step2",
          user="testuser",
          state=db_schema.AssetState.ASSET_STATE_ACTIVE_WRITE,
      )
      session.add(asset)
      await session.commit()

      result = await session.execute(
          select(db_schema.Asset).filter_by(asset_uuid="uuid-456")
      )
      asset_fetch = result.scalars().first()
      asset_fetch.state = db_schema.AssetState.ASSET_STATE_STORED
      await session.commit()

      result = await session.execute(
          select(db_schema.Asset).filter_by(asset_uuid="uuid-456")
      )
      fetched = result.scalars().first()
      self.assertEqual(fetched.state, db_schema.AssetState.ASSET_STATE_STORED)

  async def test_add_tier_path(self) -> None:
    """An Asset can hold tier paths on multiple storage backends."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-789",
          path="/experiment/step3",
          user="testuser",
      )
      backend0 = db_schema.StorageBackend(
          level=0,
          zone="us-east5-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
      )
      backend1 = db_schema.StorageBackend(
          level=1,
          multi_regions=["us-central1", "us-east1"],
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      tier_path0 = db_schema.TierPath(
          asset_uuid="uuid-789",
          storage_backend=backend0,
          path="/lustre/path/1",
      )
      tier_path1 = db_schema.TierPath(
          asset_uuid="uuid-789",
          storage_backend=backend1,
          path="/gcs/path/2",
      )
      session.add(asset)
      session.add(backend0)
      session.add(backend1)
      session.add(tier_path0)
      session.add(tier_path1)
      await session.commit()

      result = await session.execute(
          select(db_schema.Asset)
          .options(
              sqlalchemy.orm.selectinload(
                  db_schema.Asset.tier_paths
              ).selectinload(db_schema.TierPath.storage_backend)
          )
          .filter_by(asset_uuid="uuid-789")
      )
      fetched = result.scalars().first()
      self.assertLen(fetched.tier_paths, 2)
      tp0 = next(
          tp for tp in fetched.tier_paths if tp.storage_backend.level == 0
      )
      tp1 = next(
          tp for tp in fetched.tier_paths if tp.storage_backend.level == 1
      )
      self.assertEqual(tp0.path, "/lustre/path/1")
      self.assertEqual(tp0.storage_backend.zone, "us-east5-a")
      self.assertEqual(tp1.path, "/gcs/path/2")
      self.assertEqual(
          tp1.storage_backend.multi_regions,
          ["us-central1", "us-east1"],
      )

  async def test_add_tier_path_fails_multiple_locations(self) -> None:
    """A second TierPath on the same backend violates the unique constraint."""
    async with self.session_maker() as session:
      backend = db_schema.StorageBackend(
          level=0,
          zone="us-central1-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      session.add(backend)
      await session.commit()

      asset = db_schema.Asset(
          asset_uuid="uuid-dup-locations",
          path="/experiment/dup_locations",
          user="testuser",
      )

      tp1 = db_schema.TierPath(
          storage_backend=backend,
          path="/path1",
      )
      asset.tier_paths.append(tp1)
      session.add(asset)

      await session.commit()

      tp2 = db_schema.TierPath(
          storage_backend=backend,
          path="/dup_path",
      )
      asset.tier_paths.append(tp2)
      with self.assertRaisesRegex(
          sqlalchemy.exc.IntegrityError,
          "UNIQUE constraint failed: tier_paths.asset_uuid,"
          " tier_paths.storage_backend_id",
      ):
        await session.commit()

  async def test_storage_backend_fails_multiple_locations_zone(self) -> None:
    """Two same-level backends with the same zone fail pre-commit validation."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-distinct-backends-zone",
          path="/experiment/distinct_backends_zone",
          user="testuser",
      )
      b1 = db_schema.StorageBackend(
          level=0,
          zone="us-central1-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
      )
      b2 = db_schema.StorageBackend(
          level=0,
          zone="us-central1-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
      )
      tp1 = db_schema.TierPath(storage_backend=b1, path="/path1")
      tp2 = db_schema.TierPath(storage_backend=b2, path="/path2")
      asset.tier_paths.extend([tp1, tp2])
      session.add(asset)
      with self.assertRaisesRegex(ValueError, "Duplicate zone"):
        await session.commit()

  async def test_storage_backend_fails_multiple_locations_region(self) -> None:
    """Two same-level backends with the same region fail validation."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-distinct-backends-region",
          path="/experiment/distinct_backends_region",
          user="testuser",
      )
      b1 = db_schema.StorageBackend(
          level=0,
          region="us-central1",
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      b2 = db_schema.StorageBackend(
          level=0,
          region="us-central1",
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      tp1 = db_schema.TierPath(storage_backend=b1, path="/path1")
      tp2 = db_schema.TierPath(storage_backend=b2, path="/path2")
      asset.tier_paths.extend([tp1, tp2])
      session.add(asset)
      with self.assertRaisesRegex(ValueError, "Duplicate region"):
        await session.commit()

  async def test_storage_backend_fails_multiple_locations_multi_regions(
      self,
  ) -> None:
    """Same multi_regions (in any order) at the same level fail validation."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-distinct-backends-mr",
          path="/experiment/distinct_backends_mr",
          user="testuser",
      )
      b1 = db_schema.StorageBackend(
          level=0,
          multi_regions=["us-central1", "us-east1"],
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      # Order of regions shouldn't matter
      b2 = db_schema.StorageBackend(
          level=0,
          multi_regions=["us-east1", "us-central1"],
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      tp1 = db_schema.TierPath(storage_backend=b1, path="/path1")
      tp2 = db_schema.TierPath(storage_backend=b2, path="/path2")
      asset.tier_paths.extend([tp1, tp2])
      session.add(asset)
      with self.assertRaisesRegex(ValueError, "Duplicate multi_regions"):
        await session.commit()

  async def test_add_tier_path_fails_no_locations(self) -> None:
    """A StorageBackend with no location set violates the check constraint."""
    async with self.session_maker() as session:
      with self.assertRaisesRegex(
          sqlalchemy.exc.IntegrityError, "check_mutually_exclusive_locations"
      ):
        invalid_backend_empty = db_schema.StorageBackend(
            level=0,
        )
        session.add(invalid_backend_empty)
        await session.commit()

  async def test_asset_job_queue(self) -> None:
    """Jobs can be queued per asset and their status updated."""
    async with self.session_maker() as session:
      asset = db_schema.Asset(
          asset_uuid="uuid-queue",
          path="/experiment/queue",
          user="testuser",
      )
      backend = db_schema.StorageBackend(
          level=0,
          zone="us-central1-a",
          backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
      )
      tier_path = db_schema.TierPath(
          asset_uuid="uuid-queue", storage_backend=backend, path="/path1"
      )
      session.add_all([asset, backend, tier_path])
      # Flush (not commit) so tier_path.id is assigned for the jobs below.
      await session.flush()

      job1 = db_schema.AssetJob(
          asset_uuid="uuid-queue",
          request_type=db_schema.RequestType.REQUEST_TYPE_COPY,
          status=db_schema.JobStatus.JOB_STATUS_QUEUED,
          target_tier_path_id=tier_path.id,
      )
      job2 = db_schema.AssetJob(
          asset_uuid="uuid-queue",
          request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE,
          status=db_schema.JobStatus.JOB_STATUS_QUEUED,
          target_tier_path_id=tier_path.id,
      )
      session.add_all([job1, job2])
      await session.commit()

      result = await session.execute(
          select(db_schema.AssetJob)
          .filter_by(asset_uuid="uuid-queue")
          .order_by(db_schema.AssetJob.id)
      )
      jobs = result.scalars().all()
      self.assertLen(jobs, 2)
      self.assertEqual(
          jobs[0].request_type, db_schema.RequestType.REQUEST_TYPE_COPY
      )
      self.assertEqual(
          jobs[1].request_type,
          db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE,
      )

      jobs[0].status = db_schema.JobStatus.JOB_STATUS_COMPLETED
      await session.commit()

      result = await session.execute(
          select(db_schema.AssetJob).filter_by(id=job1.id)
      )
      fetched_job = result.scalars().first()
      self.assertEqual(
          fetched_job.status, db_schema.JobStatus.JOB_STATUS_COMPLETED
      )

  async def test_create_asset_duplicates_allowed_for_deleted_incomplete(self):
    # Verify we can have duplicate path for DELETED or INCOMPLETE states
    async with self.session_maker() as session:
      asset1 = db_schema.Asset(
          path="/experiment/dup_allow",
          user="testuser",
          state=db_schema.AssetState.ASSET_STATE_DELETED,
      )
      asset2 = db_schema.Asset(
          path="/experiment/dup_allow",
          user="testuser",
          state=db_schema.AssetState.ASSET_STATE_INCOMPLETE,
      )
      session.add(asset1)
      session.add(asset2)
      await session.commit()

  async def test_create_asset_duplicates_blocked_for_active_stored(
      self,
  ) -> None:
    """Duplicate paths across ACTIVE_WRITE/STORED hit the partial index."""
    async with self.session_maker() as session:
      asset3 = db_schema.Asset(
          path="/experiment/dup_block",
          user="testuser",
          state=db_schema.AssetState.ASSET_STATE_ACTIVE_WRITE,
      )
      session.add(asset3)
      await session.commit()

    async with self.session_maker() as session:
      asset4 = db_schema.Asset(
          path="/experiment/dup_block",
          user="testuser",
          state=db_schema.AssetState.ASSET_STATE_STORED,
      )
      session.add(asset4)
      with self.assertRaisesRegex(
          sqlalchemy.exc.IntegrityError,
          "UNIQUE constraint failed: assets.path",
      ):
        await session.commit()

  @parameterized.named_parameters(
      dict(
          testcase_name="same_backend_type",
          backend1=db_schema.StorageBackend(
              level=1,
              zone="us-central1-a",
              backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
          ),
          backend2=db_schema.StorageBackend(
              level=1,
              zone="us-central1-b",
              backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          ),
          expected_exception=ValueError,
          expected_regex="same backend_type",
      ),
      dict(
          testcase_name="duplicate_zone_at_same_level",
          backend1=db_schema.StorageBackend(
              level=1,
              zone="us-central1-a",
              backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
          ),
          backend2=db_schema.StorageBackend(
              level=1,
              zone="us-central1-a",
              backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE,
          ),
          expected_exception=ValueError,
          expected_regex="Duplicate zone",
      ),
      dict(
          testcase_name="duplicate_region_at_same_level",
          backend1=db_schema.StorageBackend(
              level=1,
              region="us-central1",
              backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          ),
          backend2=db_schema.StorageBackend(
              level=1,
              region="us-central1",
              backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          ),
          expected_exception=ValueError,
          expected_regex="Duplicate region",
      ),
      dict(
          testcase_name="duplicate_multi_regions_at_same_level",
          backend1=db_schema.StorageBackend(
              level=1,
              multi_regions=["us-central1", "us-east1"],
              backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          ),
          backend2=db_schema.StorageBackend(
              level=1,
              multi_regions=["us-east1", "us-central1"],
              backend_type=db_schema.BackendType.BACKEND_TYPE_GCS,
          ),
          expected_exception=ValueError,
          expected_regex="Duplicate multi_regions",
      ),
  )
  async def test_storage_backend_validation(
      self,
      backend1,
      backend2,
      expected_exception,
      expected_regex,
  ) -> None:
    """Pre-commit validation rejects conflicting same-level backends."""
    async with self.session_maker() as session:
      b1 = backend1
      b2 = backend2
      session.add(b1)
      session.add(b2)
      with self.assertRaisesRegex(expected_exception, expected_regex):
        await session.commit()


def _worker_add_job(
    db_path: str, request_type_val: int, tp_id: int | None
) -> tuple[int, bool]:
  """Subprocess worker: inserts one AssetJob; returns (type, success)."""
  engine = sqlalchemy.create_engine(
      f"sqlite:///{db_path}", connect_args={"timeout": 30}
  )
  db_session = sqlalchemy.orm.sessionmaker(engine)
  with db_session() as session:
    job = db_schema.AssetJob(
        asset_uuid="uuid-queue-multi",
        request_type=db_schema.RequestType(request_type_val),
        status=db_schema.JobStatus.JOB_STATUS_QUEUED,
        target_tier_path_id=tp_id,
    )
    try:
      session.add(job)
      session.commit()
      return (request_type_val, True)
    except sqlalchemy.exc.IntegrityError:
      return (request_type_val, False)


def _worker_create_asset(db_path: str) -> bool:
  """Subprocess worker: tries to insert one active Asset at a fixed path."""
  engine = sqlalchemy.create_engine(
      f"sqlite:///{db_path}", connect_args={"timeout": 30}
  )
  db_session = sqlalchemy.orm.sessionmaker(engine)
  with db_session() as session:
    asset = db_schema.Asset(
        path="/experiment/race_condition",
        user="testuser",
        state=db_schema.AssetState.ASSET_STATE_ACTIVE_WRITE,
    )
    session.add(asset)
    try:
      session.commit()
      return True
    except sqlalchemy.exc.IntegrityError:
      return False


class DbSchemaMultiprocessTest(
    absltest.TestCase, unittest.IsolatedAsyncioTestCase
):
  """Tests concurrent writers against one SQLite database file."""

  async def asyncSetUp(self) -> None:
    """Creates a fresh on-disk SQLite database (WAL mode) for each test."""
    await super().asyncSetUp()
    tmp_file = self.create_tempfile()
    self.db_path = tmp_file.full_path

    self.engine = create_async_engine(
        f"sqlite+aiosqlite:///{self.db_path}",
        echo=True,
    )

    async with self.engine.begin() as conn:
      await conn.exec_driver_sql("PRAGMA journal_mode=WAL")
      await conn.run_sync(db_schema.Base.metadata.create_all)

    self.session_maker = sessionmaker(
        self.engine, expire_on_commit=False, class_=AsyncSession
    )

  async def asyncTearDown(self) -> None:
    """Drops all tables and disposes the engine."""
    async with self.engine.begin() as conn:
      await conn.run_sync(db_schema.Base.metadata.drop_all)
    await self.engine.dispose()
    await super().asyncTearDown()

  def test_asset_job_queue_multiprocess(self) -> None:
    """Three processes each enqueue a job; all inserts must succeed."""
    async def _setup():
      async with self.session_maker() as session:
        asset = db_schema.Asset(
            asset_uuid="uuid-queue-multi",
            path="/experiment/queue-multi",
            user="testuser",
        )
        sb = db_schema.StorageBackend(level=0, zone="us-central1-a")
        tp = db_schema.TierPath(
            asset_uuid="uuid-queue-multi", storage_backend=sb, path="/path1"
        )
        session.add_all([asset, sb, tp])
        await session.commit()
        return tp.id

    tp_id = asyncio.run(_setup())

    job_types = [
        (int(db_schema.RequestType.REQUEST_TYPE_COPY), tp_id),
        (int(db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE), tp_id),
        (int(db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS), None),
    ]

    with multiprocessing.Pool(processes=3) as pool:
      results = pool.starmap(
          _worker_add_job,
          [(self.db_path, jt, target_id) for jt, target_id in job_types],
      )

    for request_type_val, success in results:
      with self.subTest(request_type=request_type_val):
        self.assertTrue(success)

    async def _verify():
      async with self.session_maker() as session:
        result = await session.execute(
            select(db_schema.AssetJob).filter_by(asset_uuid="uuid-queue-multi")
        )
        jobs = result.scalars().all()
        self.assertLen(jobs, 3)

        found_types = [j.request_type for j in jobs]
        expected_types = [
            db_schema.RequestType.REQUEST_TYPE_COPY,
            db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE,
            db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS,
        ]
        self.assertCountEqual(found_types, expected_types)

    asyncio.run(_verify())

  def test_create_asset_multiprocess(self) -> None:
    """Five racing writers of one live path: exactly one insert wins."""
    with multiprocessing.Pool(processes=5) as pool:
      results = pool.map(_worker_create_asset, [self.db_path] * 5)

    successes = results.count(True)
    self.assertEqual(successes, 1)


if __name__ == "__main__":
  absltest.main()
diff --git a/checkpoint/pyproject.toml b/checkpoint/pyproject.toml
index c5158708c..2e985edcb 100644
--- a/checkpoint/pyproject.toml
+++ b/checkpoint/pyproject.toml
@@ -76,7 +76,11 @@ testing = [
   'grain',
 ]
 tiering_service = [
+  'aiosqlite',
+  'greenlet',
   'grpcio-tools>=1.80.0',
+  'pysqlite3',
+  'sqlalchemy>=1.4.0',
 ]

 [tool.flit.sdist]