diff --git a/dev-requirements.in b/dev-requirements.in index 7159812e26..3cb16d8d3b 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -5,6 +5,7 @@ hypothesis joblib mock pytest +pytest-asyncio pytest-cov mypy pre-commit diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index b18610e145..73737f3a6c 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -98,12 +98,13 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon return await agent.async_create( context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp ) - return await asyncio.to_thread( + return await asyncio.get_running_loop().run_in_executor( + None, agent.create, - context=context, - inputs=inputs, - output_prefix=request.output_prefix, - task_template=tmp, + context, + request.output_prefix, + tmp, + inputs, ) @agent_exception_handler @@ -112,7 +113,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) logger.info(f"{agent.task_type} agent start checking the status of the job") if agent.asynchronous: return await agent.async_get(context=context, resource_meta=request.resource_meta) - return await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta) + return await asyncio.get_running_loop().run_in_executor(None, agent.get, context, request.resource_meta) @agent_exception_handler async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: @@ -120,4 +121,4 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon logger.info(f"{agent.task_type} agent start deleting the job") if agent.asynchronous: return await agent.async_delete(context=context, resource_meta=request.resource_meta) - return await asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta) + return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index d30d870cd8..e9555b2026 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -66,7 +66,28 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT return DeleteTaskResponse() +class AsyncDummyAgent(AgentBase): + def __init__(self): + super().__init__(task_type="async_dummy", asynchronous=True) + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + return DeleteTaskResponse() + + AgentRegistry.register(DummyAgent()) +AgentRegistry.register(AsyncDummyAgent()) task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") task_metadata = task.TaskMetadata( @@ -102,6 +123,14 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT custom={}, ) +async_dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="async_dummy", + custom={}, +) + def test_dummy_agent(): ctx = MagicMock(spec=grpc.ServicerContext) @@ -126,24 +155,48 @@ def __init__(self, **kwargs): t.execute() +@pytest.mark.asyncio +async def test_async_dummy_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("async_dummy") + metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) + assert res.resource_meta == metadata_bytes + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + res = await agent.async_delete(ctx, metadata_bytes) + assert res == DeleteTaskResponse() + + +@pytest.mark.asyncio async def run_agent_server(): service = AsyncAgentService() ctx = MagicMock(spec=grpc.ServicerContext) request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() ) - + async_request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() + ) + fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + res = await service.CreateTask(request, ctx) assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED + res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + assert isinstance(res, DeleteTaskResponse) - await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - res = await service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx) + res = await service.CreateTask(async_request, ctx) + assert res.resource_meta == metadata_bytes + res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + assert res.resource.state == SUCCEEDED + res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + assert isinstance(res, DeleteTaskResponse) - assert res.resource.state == PERMANENT_FAILURE + res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) + assert res is None def test_agent_server():