diff --git a/tests/integration/test_update_task_v2_perf.py b/tests/integration/test_update_task_v2_perf.py
index de2a39d0..671427a5 100644
--- a/tests/integration/test_update_task_v2_perf.py
+++ b/tests/integration/test_update_task_v2_perf.py
@@ -73,17 +73,17 @@
 # ---------------------------------------------------------------------------
 
 @worker_task(task_definition_name="perf_type_a", thread_count=WORKER_THREADS, register_task_def=True)
-def perf_worker_a(task_index: int = 0) -> dict:
+async def perf_worker_a(task_index: int = 0) -> dict:
     return {"worker": "perf_type_a", "task_index": task_index}
 
 
 @worker_task(task_definition_name="perf_type_b", thread_count=WORKER_THREADS, register_task_def=True)
-def perf_worker_b(task_index: int = 0) -> dict:
+async def perf_worker_b(task_index: int = 0) -> dict:
     return {"worker": "perf_type_b", "task_index": task_index}
 
 
 @worker_task(task_definition_name="perf_type_c", thread_count=WORKER_THREADS, register_task_def=True)
-def perf_worker_c(task_index: int = 0) -> dict:
+async def perf_worker_c(task_index: int = 0) -> dict:
     return {"worker": "perf_type_c", "task_index": task_index}
 
 
diff --git a/tests/integration/test_v2_fallback_intg.py b/tests/integration/test_v2_fallback_intg.py
new file mode 100644
index 00000000..4f2258cf
--- /dev/null
+++ b/tests/integration/test_v2_fallback_intg.py
@@ -0,0 +1,174 @@
+"""
+Integration test for update-task-v2 graceful degradation.
+
+Verifies that the SDK auto-detects whether update-task-v2 is available and,
+when it is not, falls back to v1 while still completing workflows.
+
+Run:
+    python -m pytest tests/integration/test_v2_fallback_intg.py -v -s
+"""
+
+import logging
+import os
+import sys
+import time
+import threading
+import unittest
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+
+from conductor.client.automator.task_handler import TaskHandler
+from conductor.client.configuration.configuration import Configuration
+from conductor.client.http.models.start_workflow_request import StartWorkflowRequest
+from conductor.client.http.models.workflow_def import WorkflowDef
+from conductor.client.http.models.workflow_task import WorkflowTask
+from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient
+from conductor.client.orkes.orkes_workflow_client import OrkesWorkflowClient
+from conductor.client.worker.worker_task import worker_task
+
+logger = logging.getLogger(__name__)
+
+WORKFLOW_NAME = "test_v2_fallback_workflow"
+WORKFLOW_VERSION = 1
+
+
+# ---------------------------------------------------------------------------
+# Workers
+# ---------------------------------------------------------------------------
+
+@worker_task(task_definition_name="v2_fallback_task_a", thread_count=3, register_task_def=True)
+async def fallback_worker_a(task_index: int = 0) -> dict:
+    return {"worker": "v2_fallback_task_a", "task_index": task_index}
+
+
+@worker_task(task_definition_name="v2_fallback_task_b", thread_count=3, register_task_def=True)
+async def fallback_worker_b(task_index: int = 0) -> dict:
+    return {"worker": "v2_fallback_task_b", "task_index": task_index}
+
+
+# ---------------------------------------------------------------------------
+# Test
+# ---------------------------------------------------------------------------
+
+class TestV2FallbackIntegration(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        from tests.integration.conftest import skip_if_server_unavailable
+        skip_if_server_unavailable()
+
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s [%(process)d] %(name)-45s %(levelname)-10s %(message)s",
+        )
+        logging.getLogger("conductor.client").setLevel(logging.WARNING)
+
+        cls.config = Configuration()
+        cls.workflow_client = OrkesWorkflowClient(cls.config)
+        cls.metadata_client = OrkesMetadataClient(cls.config)
+
+    def test_0_register_workflow(self):
+        """Register workflow with 2 task types (3 tasks each)."""
+        tasks = []
+        idx = 0
+        for task_type, count in [("v2_fallback_task_a", 3), ("v2_fallback_task_b", 3)]:
+            for i in range(count):
+                idx += 1
+                tasks.append(
+                    WorkflowTask(
+                        name=task_type,
+                        task_reference_name=f"{task_type}_{i + 1}",
+                        input_parameters={"task_index": idx},
+                    )
+                )
+
+        workflow = WorkflowDef(name=WORKFLOW_NAME, version=WORKFLOW_VERSION)
+        workflow._tasks = tasks
+        try:
+            self.metadata_client.update_workflow_def(workflow, overwrite=True)
+        except Exception:
+            self.metadata_client.register_workflow_def(workflow, overwrite=True)
+        print(f"\n  Registered workflow '{WORKFLOW_NAME}' with {len(tasks)} tasks")
+
+    def test_1_workflows_complete_with_v2_or_fallback(self):
+        """Start workers and verify workflows complete regardless of v2 support.
+
+        This test doesn't force a 404 — it runs against the real server.
+        If v2 is available, it uses v2. If not, it auto-detects and falls back.
+        Either way, all workflows should complete successfully.
+        """
+        workflow_count = 5
+
+        handler_ready = threading.Event()
+        handler_ref = {"stop": threading.Event()}
+
+        def _run_workers():
+            handler = TaskHandler(
+                configuration=self.config,
+                scan_for_annotated_workers=True,
+            )
+            handler_ref["h"] = handler
+            handler.start_processes()
+            handler_ready.set()
+            # "stop" is created before this thread starts, so cleanup can always signal it.
+            handler_ref["stop"].wait()
+            handler.stop_processes()
+
+        worker_thread = threading.Thread(target=_run_workers, daemon=True)
+        worker_thread.start()
+        try:
+            handler_ready.wait(timeout=30)
+            self.assertTrue(handler_ready.is_set(), "Workers failed to start within 30s")
+            time.sleep(3)  # Warm up
+
+            # Submit workflows
+            workflow_ids = []
+            for i in range(workflow_count):
+                req = StartWorkflowRequest()
+                req.name = WORKFLOW_NAME
+                req.version = WORKFLOW_VERSION
+                req.input = {"run_index": i}
+                wf_id = self.workflow_client.start_workflow(start_workflow_request=req)
+                workflow_ids.append(wf_id)
+
+            print(f"\n  Submitted {len(workflow_ids)} workflows")
+
+            # Wait for completion
+            deadline = time.time() + 60  # 60s timeout
+            pending = set(workflow_ids)
+            completed = 0
+            failed = 0
+
+            while pending and time.time() < deadline:
+                still_pending = set()
+                for wf_id in pending:
+                    try:
+                        wf = self.workflow_client.get_workflow(wf_id, include_tasks=False)
+                    except Exception:
+                        still_pending.add(wf_id)
+                        continue
+
+                    if wf.status == "COMPLETED":
+                        completed += 1
+                    elif wf.status in ("FAILED", "TERMINATED", "TIMED_OUT"):
+                        failed += 1
+                        logger.warning("Workflow %s ended with status %s", wf_id, wf.status)
+                    else:
+                        still_pending.add(wf_id)
+
+                pending = still_pending
+                if pending:
+                    time.sleep(1)
+
+            print(f"  Results: {completed} completed, {failed} failed, {len(pending)} pending")
+
+            self.assertEqual(len(pending), 0, f"{len(pending)} workflows did not complete in time")
+            self.assertEqual(completed, workflow_count, f"Expected {workflow_count} completed, got {completed}")
+
+        finally:
+            handler_ref["stop"].set()
+            worker_thread.join(timeout=15)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/tests/unit/automator/test_v2_fallback.py b/tests/unit/automator/test_v2_fallback.py
new file mode 100644
index 00000000..3b22fb4f
--- /dev/null
+++ b/tests/unit/automator/test_v2_fallback.py
@@ -0,0 +1,400 @@
+"""
+Unit tests for update-task-v2 graceful degradation to v1.
+
+Tests both sync TaskRunner and async AsyncTaskRunner to verify:
+- On 404/405 from update_task_v2, falls back to update_task (v1)
+- The _use_update_v2 flag is set to False after first fallback
+- Subsequent calls go directly to v1 (skip v2)
+- The current task result is still persisted via v1 during fallback
+"""
+
+import asyncio
+import logging
+import unittest
+from unittest.mock import patch, Mock, AsyncMock
+
+from conductor.client.automator.task_runner import TaskRunner
+from conductor.client.automator.async_task_runner import AsyncTaskRunner
+from conductor.client.configuration.configuration import Configuration
+from conductor.client.http.api.task_resource_api import TaskResourceApi
+from conductor.client.http.models.task import Task
+from conductor.client.http.models.task_result import TaskResult
+from conductor.client.http.models.task_result_status import TaskResultStatus
+from conductor.client.http.rest import ApiException
+from conductor.client.worker.worker import Worker
+from tests.unit.resources.workers import ClassWorker
+
+
+class TestTaskRunnerV2Fallback(unittest.TestCase):
+    """Tests for sync TaskRunner v2 -> v1 fallback."""
+
+    TASK_ID = 'test_task_id'
+    WORKFLOW_INSTANCE_ID = 'test_workflow_id'
+
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_fallback_on_404(self):
+        """On 404 from update_task_v2, should fall back to update_task and return None."""
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            side_effect=ApiException(status=404, reason="Not Found")
+        ):
+            with patch.object(
+                TaskResourceApi, 'update_task',
+                return_value='task_id_confirmation'
+            ) as mock_v1:
+                runner = self._create_runner()
+                self.assertTrue(runner._use_update_v2)
+
+                result = runner._TaskRunner__update_task(self._create_task_result())
+
+                self.assertIsNone(result)
+                self.assertFalse(runner._use_update_v2)
+                mock_v1.assert_called_once()
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_fallback_on_405(self):
+        """On 405 from update_task_v2, should fall back to update_task and return None."""
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            side_effect=ApiException(status=405, reason="Method Not Allowed")
+        ):
+            with patch.object(
+                TaskResourceApi, 'update_task',
+                return_value='task_id_confirmation'
+            ) as mock_v1:
+                runner = self._create_runner()
+                result = runner._TaskRunner__update_task(self._create_task_result())
+
+                self.assertIsNone(result)
+                self.assertFalse(runner._use_update_v2)
+                mock_v1.assert_called_once()
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_subsequent_calls_use_v1_directly(self):
+        """After fallback, subsequent calls should go to v1 directly, skipping v2."""
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            side_effect=ApiException(status=404, reason="Not Found")
+        ) as mock_v2:
+            with patch.object(
+                TaskResourceApi, 'update_task',
+                return_value='ok'
+            ) as mock_v1:
+                runner = self._create_runner()
+
+                # First call triggers fallback
+                runner._TaskRunner__update_task(self._create_task_result())
+                self.assertEqual(mock_v2.call_count, 1)
+                self.assertEqual(mock_v1.call_count, 1)
+
+                # Second call should skip v2 entirely
+                runner._TaskRunner__update_task(self._create_task_result())
+                self.assertEqual(mock_v2.call_count, 1)  # Still 1 — not called again
+                self.assertEqual(mock_v1.call_count, 2)
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_v2_success_no_fallback(self):
+        """When v2 succeeds, should return next task and not touch v1."""
+        next_task = Task(task_id='next_task', workflow_instance_id='wf_2')
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            return_value=next_task
+        ):
+            with patch.object(
+                TaskResourceApi, 'update_task',
+                return_value='ok'
+            ) as mock_v1:
+                runner = self._create_runner()
+                result = runner._TaskRunner__update_task(self._create_task_result())
+
+                self.assertEqual(result, next_task)
+                self.assertTrue(runner._use_update_v2)
+                mock_v1.assert_not_called()
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_non_404_error_does_not_trigger_fallback(self):
+        """A 500 error should retry normally, not trigger v1 fallback."""
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            side_effect=ApiException(status=500, reason="Internal Server Error")
+        ):
+            runner = self._create_runner()
+            result = runner._TaskRunner__update_task(self._create_task_result())
+
+            # All retries exhausted, still _use_update_v2 (not a 404/405)
+            self.assertTrue(runner._use_update_v2)
+            self.assertIsNone(result)
+
+    @patch('time.sleep', Mock(return_value=None))
+    def test_v1_fallback_failure_retries(self):
+        """If v1 also fails during fallback, should retry with backoff."""
+        call_count = {'v1': 0}
+
+        def v1_side_effect(*args, **kwargs):
+            call_count['v1'] += 1
+            if call_count['v1'] <= 2:
+                raise Exception("v1 also down")
+            return 'ok'
+
+        with patch.object(
+            TaskResourceApi, 'update_task_v2',
+            side_effect=ApiException(status=404, reason="Not Found")
+        ):
+            with patch.object(
+                TaskResourceApi, 'update_task',
+                side_effect=v1_side_effect
+            ):
+                runner = self._create_runner()
+                result = runner._TaskRunner__update_task(self._create_task_result())
+
+                self.assertFalse(runner._use_update_v2)
+                # v1_side_effect raises on its first two calls and succeeds on the third
+                self.assertIsNone(result)
+
+    def _create_runner(self):
+        return TaskRunner(
+            configuration=Configuration(),
+            worker=ClassWorker('task')
+        )
+
+    def _create_task_result(self):
+        return TaskResult(
+            task_id=self.TASK_ID,
+            workflow_instance_id=self.WORKFLOW_INSTANCE_ID,
+            worker_id='test_worker',
+            status=TaskResultStatus.COMPLETED,
+            output_data={'result': 42}
+        )
+
+
+class TestAsyncTaskRunnerV2Fallback(unittest.TestCase):
+    """Tests for async AsyncTaskRunner v2 -> v1 fallback."""
+
+    TASK_ID = 'test_task_id'
+    WORKFLOW_INSTANCE_ID = 'test_workflow_id'
+    AUTH_TOKEN = 'test_token'
+
+    def setUp(self):
+        logging.disable(logging.CRITICAL)
+
+    def tearDown(self):
+        logging.disable(logging.NOTSET)
+
+    def test_fallback_on_404(self):
+        """On 404 from async update_task_v2, should fall back to update_task."""
+
+        async def simple_worker(value: int) -> dict:
+            return {'result': value}
+
+        worker = Worker(
+            task_definition_name='test_v2_fallback',
+            execute_function=simple_worker,
+            thread_count=1
+        )
+
+        config = Configuration()
+        config.AUTH_TOKEN = self.AUTH_TOKEN
+        runner = AsyncTaskRunner(worker=worker, configuration=config)
+
+        async def run_test():
+            runner.async_api_client = AsyncMock()
+            runner.async_task_client = AsyncMock()
+            runner._semaphore = asyncio.Semaphore(1)
+
+            runner.async_task_client.update_task_v2 = AsyncMock(
+                side_effect=ApiException(status=404, reason="Not Found")
+            )
+            runner.async_task_client.update_task = AsyncMock(return_value='ok')
+
+            self.assertTrue(runner._use_update_v2)
+
+            result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result())
+
+            self.assertIsNone(result)
+            self.assertFalse(runner._use_update_v2)
+            runner.async_task_client.update_task.assert_called_once()
+
+        asyncio.run(run_test())
+
+    def test_fallback_on_405(self):
+        """On 405 from async update_task_v2, should fall back to update_task."""
+
+        async def simple_worker(value: int) -> dict:
+            return {'result': value}
+
+        worker = Worker(
+            task_definition_name='test_v2_fallback_405',
+            execute_function=simple_worker,
+            thread_count=1
+        )
+
+        config = Configuration()
+        config.AUTH_TOKEN = self.AUTH_TOKEN
+        runner = AsyncTaskRunner(worker=worker, configuration=config)
+
+        async def run_test():
+            runner.async_api_client = AsyncMock()
+            runner.async_task_client = AsyncMock()
+            runner._semaphore = asyncio.Semaphore(1)
+
+            runner.async_task_client.update_task_v2 = AsyncMock(
+                side_effect=ApiException(status=405, reason="Method Not Allowed")
+            )
+            runner.async_task_client.update_task = AsyncMock(return_value='ok')
+
+            result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result())
+
+            self.assertIsNone(result)
+            self.assertFalse(runner._use_update_v2)
+            runner.async_task_client.update_task.assert_called_once()
+
+        asyncio.run(run_test())
+
+    def test_subsequent_calls_use_v1_directly(self):
+        """After fallback, subsequent async calls should go to v1 directly."""
+
+        async def simple_worker(value: int) -> dict:
+            return {'result': value}
+
+        worker = Worker(
+            task_definition_name='test_v2_subsequent',
+            execute_function=simple_worker,
+            thread_count=1
+        )
+
+        config = Configuration()
+        config.AUTH_TOKEN = self.AUTH_TOKEN
+        runner = AsyncTaskRunner(worker=worker, configuration=config)
+
+        async def run_test():
+            runner.async_api_client = AsyncMock()
+            runner.async_task_client = AsyncMock()
+            runner._semaphore = asyncio.Semaphore(1)
+
+            runner.async_task_client.update_task_v2 = AsyncMock(
+                side_effect=ApiException(status=404, reason="Not Found")
+            )
+            runner.async_task_client.update_task = AsyncMock(return_value='ok')
+
+            # First call triggers fallback
+            await runner._AsyncTaskRunner__async_update_task(self._create_task_result())
+            self.assertEqual(runner.async_task_client.update_task_v2.call_count, 1)
+            self.assertEqual(runner.async_task_client.update_task.call_count, 1)
+
+            # Second call skips v2
+            await runner._AsyncTaskRunner__async_update_task(self._create_task_result())
+            self.assertEqual(runner.async_task_client.update_task_v2.call_count, 1)  # Still 1
+            self.assertEqual(runner.async_task_client.update_task.call_count, 2)
+
+        asyncio.run(run_test())
+
+    def test_v2_success_no_fallback(self):
+        """When async v2 succeeds, should return next task and not touch v1."""
+
+        async def simple_worker(value: int) -> dict:
+            return {'result': value}
+
+        worker = Worker(
+            task_definition_name='test_v2_success',
+            execute_function=simple_worker,
+            thread_count=1
+        )
+
+        config = Configuration()
+        config.AUTH_TOKEN = self.AUTH_TOKEN
+        runner = AsyncTaskRunner(worker=worker, configuration=config)
+
+        next_task = Task(task_id='next_task', workflow_instance_id='wf_2')
+
+        async def run_test():
+            runner.async_api_client = AsyncMock()
+            runner.async_task_client = AsyncMock()
+            runner._semaphore = asyncio.Semaphore(1)
+
+            runner.async_task_client.update_task_v2 = AsyncMock(return_value=next_task)
+            runner.async_task_client.update_task = AsyncMock(return_value='ok')
+
+            result = await runner._AsyncTaskRunner__async_update_task(self._create_task_result())
+
+            self.assertEqual(result, next_task)
+            self.assertTrue(runner._use_update_v2)
+            runner.async_task_client.update_task.assert_not_called()
+
+        asyncio.run(run_test())
+
+    def test_end_to_end_with_fallback(self):
+        """Full end-to-end: poll -> execute -> update_v2 fails -> fallback to v1."""
+
+        async def async_worker_fn(value: int) -> dict:
+            return {'result': value * 2}
+
+        worker = Worker(
+            task_definition_name='test_e2e_fallback',
+            execute_function=async_worker_fn,
+            thread_count=1
+        )
+
+        config = Configuration()
+        config.AUTH_TOKEN = self.AUTH_TOKEN
+        runner = AsyncTaskRunner(worker=worker, configuration=config)
+
+        mock_task = Task()
+        mock_task.task_id = self.TASK_ID
+        mock_task.workflow_instance_id = self.WORKFLOW_INSTANCE_ID
+        mock_task.task_def_name = 'test_e2e_fallback'
+        mock_task.input_data = {'value': 10}
+        mock_task.status = 'SCHEDULED'
+
+        async def run_test():
+            runner.async_api_client = AsyncMock()
+            runner.async_task_client = AsyncMock()
+            runner._semaphore = asyncio.Semaphore(1)
+
+            # batch_poll returns the task once, then empty (stops the tight loop)
+            runner.async_task_client.batch_poll = AsyncMock(
+                side_effect=[[mock_task], []]
+            )
+            runner.async_task_client.update_task_v2 = AsyncMock(
+                side_effect=ApiException(status=404, reason="Not Found")
+            )
+            runner.async_task_client.update_task = AsyncMock(return_value='ok')
+
+            await runner.run_once()
+            # Let the background coroutine finish
+            await asyncio.sleep(0.2)
+
+            # v2 was attempted, then fell back to v1
+            runner.async_task_client.update_task_v2.assert_called_once()
+            # v1 called once for the fallback, possibly once more for the
+            # re-polled empty batch (but at least once)
+            self.assertGreaterEqual(runner.async_task_client.update_task.call_count, 1)
+
+            # First v1 call should have the correct task result
+            v1_call = runner.async_task_client.update_task.call_args_list[0]
+            task_result = v1_call.kwargs['body']
+            self.assertEqual(task_result.status, TaskResultStatus.COMPLETED)
+            self.assertEqual(task_result.output_data, {'result': 20})
+
+            # Flag should be flipped
+            self.assertFalse(runner._use_update_v2)
+
+        asyncio.run(run_test())
+
+    def _create_task_result(self):
+        return TaskResult(
+            task_id=self.TASK_ID,
+            workflow_instance_id=self.WORKFLOW_INSTANCE_ID,
+            worker_id='test_worker',
+            status=TaskResultStatus.COMPLETED,
+            output_data={'result': 42}
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()