diff --git a/DEVELOPING.md b/DEVELOPING.md index 0b455dfd..2033f3d0 100644 --- a/DEVELOPING.md +++ b/DEVELOPING.md @@ -45,7 +45,7 @@ pdm run pre-commit install To run unit tests: ``` -pdm run pytest +pdm run pytest tests ``` NOTE: The tests need a Postgres database running on `localhost:5432`. To start diff --git a/dbos/_client.py b/dbos/_client.py index deab4906..ec0e3e5d 100644 --- a/dbos/_client.py +++ b/dbos/_client.py @@ -158,6 +158,7 @@ def __init__( engine=system_database_engine, schema=dbos_system_schema, serializer=serializer, + executor_id=None, ) self._sys_db.check_connection() if application_database_url: diff --git a/dbos/_dbos.py b/dbos/_dbos.py index d906fce1..f1f69d92 100644 --- a/dbos/_dbos.py +++ b/dbos/_dbos.py @@ -460,6 +460,7 @@ def _launch(self, *, debug_mode: bool = False) -> None: debug_mode=debug_mode, schema=schema, serializer=self._serializer, + executor_id=GlobalParams.executor_id, ) assert self._config["database"]["db_engine_kwargs"] is not None if self._config["database_url"]: diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py index 97054801..72e05faa 100644 --- a/dbos/_sys_db.py +++ b/dbos/_sys_db.py @@ -351,6 +351,7 @@ def create( engine: Optional[sa.Engine], schema: Optional[str], serializer: Serializer, + executor_id: Optional[str], debug_mode: bool = False, ) -> "SystemDatabase": """Factory method to create the appropriate SystemDatabase implementation based on URL.""" @@ -363,6 +364,7 @@ def create( engine=engine, schema=schema, serializer=serializer, + executor_id=executor_id, debug_mode=debug_mode, ) else: @@ -374,6 +376,7 @@ def create( engine=engine, schema=schema, serializer=serializer, + executor_id=executor_id, debug_mode=debug_mode, ) @@ -385,6 +388,7 @@ def __init__( engine: Optional[sa.Engine], schema: Optional[str], serializer: Serializer, + executor_id: Optional[str], debug_mode: bool = False, ): import sqlalchemy.dialects.postgresql as pg @@ -410,6 +414,8 @@ def __init__( self.notifications_map = ThreadSafeConditionDict() self.workflow_events_map = ThreadSafeConditionDict() + self.executor_id = executor_id + self._listener_thread_lock = threading.Lock() # Now we can run background processes @@ -1069,6 +1075,27 @@ def _record_operation_result_txn( error = result["error"] output = result["output"] assert error is None or output is None, "Only one of error or output can be set" + wf_executor_id_row = conn.execute( + sa.select( + SystemSchema.workflow_status.c.executor_id, + ).where( + SystemSchema.workflow_status.c.workflow_uuid == result["workflow_uuid"] + ) + ).fetchone() + assert wf_executor_id_row is not None + wf_executor_id = wf_executor_id_row[0] + if self.executor_id is not None and wf_executor_id != self.executor_id: + dbos_logger.debug( + f'Resetting executor_id from {wf_executor_id} to {self.executor_id} for workflow {result["workflow_uuid"]}' + ) + conn.execute( + sa.update(SystemSchema.workflow_status) + .values(executor_id=self.executor_id) + .where( + SystemSchema.workflow_status.c.workflow_uuid + == result["workflow_uuid"] + ) + ) sql = sa.insert(SystemSchema.operation_outputs).values( workflow_uuid=result["workflow_uuid"], function_id=result["function_id"], diff --git a/dbos/cli/migration.py b/dbos/cli/migration.py index f3d03de8..bce65d93 100644 --- a/dbos/cli/migration.py +++ b/dbos/cli/migration.py @@ -24,6 +24,7 @@ def migrate_dbos_databases( engine=None, schema=schema, serializer=DefaultSerializer(), + executor_id=None, ) sys_db.run_migrations() if app_database_url: diff --git a/tests/test_dbos.py b/tests/test_dbos.py index 4089039d..72cb281b 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -117,6 +117,40 @@ def noop() -> None: assert updated_at >= created_at +def test_eid_reset(dbos: DBOS) -> None: + @DBOS.step() + def test_step() -> str: + return "hello" + + @DBOS.workflow() + def test_workflow() -> str: + DBOS.set_event("started", 1) + DBOS.recv("run_step") + return test_step() + + wfuuid = str(uuid.uuid4()) + with SetWorkflowID(wfuuid): + wfh = dbos.start_workflow(test_workflow) + DBOS.get_event(wfuuid, "started") + with dbos._sys_db.engine.connect() as c: + c.execute( + sa.update(SystemSchema.workflow_status) + .values(executor_id="some_other_executor") + .where(SystemSchema.workflow_status.c.workflow_uuid == wfuuid) + ) + c.commit() + DBOS.send(wfuuid, 1, "run_step") + wfh.get_result() + with dbos._sys_db.engine.connect() as c: + x = c.execute( + sa.select(SystemSchema.workflow_status.c.executor_id).where( + SystemSchema.workflow_status.c.workflow_uuid == wfuuid + ) + ).fetchone() + assert x is not None + assert x[0] == "local" + + def test_child_workflow(dbos: DBOS) -> None: txn_counter: int = 0 wf_counter: int = 0 diff --git a/tests/test_schema_migration.py b/tests/test_schema_migration.py index a409fe13..e14842c8 100644 --- a/tests/test_schema_migration.py +++ b/tests/test_schema_migration.py @@ -135,8 +135,9 @@ def test_sqlite_systemdb_migration() -> None: engine_kwargs={}, engine=None, schema=None, - debug_mode=False, + executor_id=None, serializer=DefaultSerializer(), + debug_mode=False, ) # Run migrations