Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix py38 aync agent service and add async agent test #1866

Merged
1 change: 1 addition & 0 deletions dev-requirements.in
Expand Up @@ -5,6 +5,7 @@ hypothesis
joblib
mock
pytest
pytest-asyncio
pytest-cov
mypy
pre-commit
Expand Down
15 changes: 8 additions & 7 deletions flytekit/extend/backend/agent_service.py
Expand Up @@ -98,12 +98,13 @@
return await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
)
return await asyncio.to_thread(
return await asyncio.get_running_loop().run_in_executor(

Check warning on line 101 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L101

Added line #L101 was not covered by tests
None,
agent.create,
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
context,
request.output_prefix,
tmp,
inputs,
)

@agent_exception_handler
Expand All @@ -112,12 +113,12 @@
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent.asynchronous:
return await agent.async_get(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.get, context, request.resource_meta)

Check warning on line 116 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L116

Added line #L116 was not covered by tests

@agent_exception_handler
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent.asynchronous:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta)
return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta)

Check warning on line 124 in flytekit/extend/backend/agent_service.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/agent_service.py#L124

Added line #L124 was not covered by tests
63 changes: 58 additions & 5 deletions tests/flytekit/unit/extend/test_agent.py
Expand Up @@ -66,7 +66,28 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT
return DeleteTaskResponse()


class AsyncDummyAgent(AgentBase):
def __init__(self):
super().__init__(task_type="async_dummy", asynchronous=True)

async def async_create(
self,
context: grpc.ServicerContext,
output_prefix: str,
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> CreateTaskResponse:
return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8"))

async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
return GetTaskResponse(resource=Resource(state=SUCCEEDED))

async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
return DeleteTaskResponse()


AgentRegistry.register(DummyAgent())
AgentRegistry.register(AsyncDummyAgent())

task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version")
task_metadata = task.TaskMetadata(
Expand Down Expand Up @@ -102,6 +123,14 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT
custom={},
)

async_dummy_template = TaskTemplate(
id=task_id,
metadata=task_metadata,
interface=interfaces,
type="async_dummy",
custom={},
)


def test_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
Expand All @@ -126,24 +155,48 @@ def __init__(self, **kwargs):
t.execute()


@pytest.mark.asyncio
async def test_async_dummy_agent():
ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent("async_dummy")
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")
res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs)
assert res.resource_meta == metadata_bytes
res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
res = await agent.async_delete(ctx, metadata_bytes)
assert res == DeleteTaskResponse()


@pytest.mark.asyncio
async def run_agent_server():
service = AsyncAgentService()
ctx = MagicMock(spec=grpc.ServicerContext)
request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl()
)

async_request = CreateTaskRequest(
inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl()
)
fake_agent = "fake"
metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")

res = await service.CreateTask(request, ctx)
assert res.resource_meta == metadata_bytes

res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
assert res.resource.state == SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx)
res = await service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx)
res = await service.CreateTask(async_request, ctx)
assert res.resource_meta == metadata_bytes
res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert res.resource.state == SUCCEEDED
res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx)
assert isinstance(res, DeleteTaskResponse)

assert res.resource.state == PERMANENT_FAILURE
res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx)
assert res is None


def test_agent_server():
Expand Down