diff --git a/torchx/pipelines/kfp/adapter.py b/torchx/pipelines/kfp/adapter.py index efb439a7e..8da9e1c7a 100644 --- a/torchx/pipelines/kfp/adapter.py +++ b/torchx/pipelines/kfp/adapter.py @@ -7,7 +7,7 @@ import copy import os -from typing import Callable, Dict, List, Optional, Type +from typing import Callable, Dict, List, Optional, Protocol, Tuple, Type import yaml from kfp import components, dsl @@ -115,7 +115,7 @@ def output(self) -> dsl.PipelineParam: ... -def component_spec_from_app(app: api.Application) -> str: +def component_spec_from_app(app: api.Application) -> Tuple[str, api.Resource]: assert len(app.roles) == 1, f"KFP adapter only support one role, got {app.roles}" role = app.roles[0] @@ -126,9 +126,6 @@ def component_spec_from_app(app: api.Application) -> str: container = role.container assert container.base_image is None, "KFP adapter does not support base_image" - assert ( - container.resources == api.NULL_RESOURCE - ), "KFP adapter requires you to specify resources in the pipeline" assert len(container.port_map) == 0, "KFP adapter does not support port_map" command = [role.entrypoint, *role.args] @@ -144,10 +141,36 @@ def component_spec_from_app(app: api.Application) -> str: } }, } - return yaml.dump(spec) + return yaml.dump(spec), container.resources -# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. -def component_from_app(app: api.Application) -> Callable: - spec = component_spec_from_app(app) - return components.load_component_from_text(spec) +class ContainerFactory(Protocol): + def __call__(self, *args: object, **kwargs: object) -> dsl.ContainerOp: + ... + + +def component_from_app(app: api.Application) -> ContainerFactory: + resources: api.Resource + spec, resources = component_spec_from_app(app) + + assert ( + len(resources.capabilities) == 0 + ), f"KFP doesn't support capabilities, got {resources.capabilities}" + component_factory: ContainerFactory = components.load_component_from_text(spec) + + def factory_wrapper(*args: object, **kwargs: object) -> dsl.ContainerOp: + c = component_factory(*args, **kwargs) + container = c.container + if (cpu := resources.cpu) >= 0: + cpu_str = f"{int(cpu*1000)}m" + container.set_cpu_request(cpu_str) + container.set_cpu_limit(cpu_str) + if (mem := resources.memMB) >= 0: + mem_str = f"{int(mem)}M" + container.set_memory_request(mem_str) + container.set_memory_limit(mem_str) + if (gpu := resources.gpu) >= 0: + container.set_gpu_limit(str(gpu)) + return c + + return factory_wrapper diff --git a/torchx/pipelines/kfp/test/adapter_test.py b/torchx/pipelines/kfp/test/adapter_test.py index 42909b7bf..4e2180d87 100644 --- a/torchx/pipelines/kfp/test/adapter_test.py +++ b/torchx/pipelines/kfp/test/adapter_test.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, TypedDict from kfp import compiler, components, dsl +from kubernetes.client.models import V1ResourceRequirements from torchx.apps.io.copy import Copy from torchx.pipelines.kfp.adapter import ( TorchXComponent, @@ -117,7 +118,14 @@ class KFPSpecsTest(unittest.TestCase): """ def _test_app(self) -> api.Application: - container = api.Container(image="pytorch/torchx:latest") + container = api.Container( + image="pytorch/torchx:latest", + resources=api.Resource( + cpu=2, + memMB=3000, + gpu=4, + ), + ) trainer_role = ( api.Role(name="trainer") .runs( @@ -135,8 +143,9 @@ def _test_app(self) -> api.Application: def test_component_spec_from_app(self) -> None: app = self._test_app() - spec = component_spec_from_app(app) + spec, resources = component_spec_from_app(app) self.assertIsNotNone(components.load_component_from_text(spec)) + self.assertEqual(resources, app.roles[0].container.resources) self.assertEqual( spec, """description: KFP wrapper for TorchX component test, role trainer @@ -160,6 +169,22 @@ def test_pipeline(self) -> None: def pipeline() -> dsl.PipelineParam: a = kfp_copy() + resources: V1ResourceRequirements = a.container.resources + self.assertEqual( + resources, + V1ResourceRequirements( + limits={ + "cpu": "2000m", + "memory": "3000M", + "nvidia.com/gpu": "4", + }, + requests={ + "cpu": "2000m", + "memory": "3000M", + }, + ), + ) + b = kfp_copy() b.after(a) return b