Skip to content

Commit

Permalink
Fix py38 aync agent service and add async agent test (#1866)
Browse files Browse the repository at this point in the history
Signed-off-by: Future Outlier <eric901201@gmai.com>
Co-authored-by: Future Outlier <eric901201@gmai.com>
  • Loading branch information
Future-Outlier and Future Outlier committed Oct 4, 2023
1 parent 16cee4d commit 71515a8
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 12 deletions.
1 change: 1 addition & 0 deletions dev-requirements.in
Expand Up @@ -5,6 +5,7 @@ hypothesis
joblib
mock
pytest
pytest-asyncio
pytest-cov
mypy
pre-commit
Expand Down
15 changes: 8 additions & 7 deletions flytekit/extend/backend/agent_service.py
Expand Up @@ -98,12 +98,13 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon
return await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
)
return await asyncio.to_thread(
return await asyncio.get_running_loop().run_in_executor(
None,
agent.create,
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
context,
request.output_prefix,
tmp,
inputs,
)

@agent_exception_handler
Expand All @@ -112,12 +113,12 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext)
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent.asynchronous:
return await agent.async_get(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.get, context, request.resource_meta)

@agent_exception_handler
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent.asynchronous:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta)
63 changes: 58 additions & 5 deletions tests/flytekit/unit/extend/test_agent.py
Expand Up @@ -66,7 +66,28 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT
return DeleteTaskResponse()


class AsyncDummyAgent(AgentBase):
def __init__(self):
super().__init__(task_type="async_dummy", asynchronous=True)

async def async_create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> CreateTaskResponse:
return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8"))

async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
return GetTaskResponse(resource=Resource(state=SUCCEEDED))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
return DeleteTaskResponse()


AgentRegistry.register(DummyAgent())
AgentRegistry.register(AsyncDummyAgent())

task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version")
task_metadata = task.TaskMetadata(
Expand Down Expand Up @@ -102,6 +123,14 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT
custom={},
)

async_dummy_template = TaskTemplate(
id=task_id,
metadata=task_metadata,
interface=interfaces,
type="async_dummy",
custom={},
)


def test_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
Expand All @@ -126,24 +155,48 @@ def __init__(self, **kwargs):
t.execute()


@pytest.mark.asyncio
async def test_async_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent("async_dummy")
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs)
assert res.resource_meta == metadata_bytes
res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
res = await agent.async_delete(ctx, metadata_bytes)
assert res == DeleteTaskResponse()


@pytest.mark.asyncio
async def run_agent_server():
service = AsyncAgentService()
ctx = MagicMock(spec=grpc.ServicerContext)
request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl()
)

async_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl()
)
fake_agent = "fake"
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")

res = await service.CreateTask(request, ctx)
assert res.resource_meta == metadata_bytes

res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
assert res.resource.state == SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
res = await service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx)
res = await service.CreateTask(async_request, ctx)
assert res.resource_meta == metadata_bytes
res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert res.resource.state == SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

assert res.resource.state == PERMANENT_FAILURE
res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None


def test_agent_server():
Expand Down

0 comments on commit 71515a8

Please sign in to comment.