From 1e708083a8f831edb6267b6e5e6271d26363c61e Mon Sep 17 00:00:00 2001 From: Hongye Sun Date: Wed, 21 Nov 2018 17:06:31 -0800 Subject: [PATCH] Add explicit set_gpu_limit function --- sdk/python/kfp/dsl/_container_op.py | 23 ++++++++++++++++++++- sdk/python/tests/compiler/testdata/basic.py | 2 +- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/sdk/python/kfp/dsl/_container_op.py b/sdk/python/kfp/dsl/_container_op.py index a8c2c69355f..f74c263b464 100644 --- a/sdk/python/kfp/dsl/_container_op.py +++ b/sdk/python/kfp/dsl/_container_op.py @@ -115,10 +115,13 @@ def _validate_gpu_string(self, gpu_string): "Validate a given string is valid for gpu limit." try: - int(gpu_string) + gpu_value = int(gpu_string) except ValueError: raise ValueError('Invalid gpu string. Should be integer.') + if gpu_value <= 0: + raise ValueError('gpu must be positive integer.') + def add_resource_limit(self, resource_name, value): """Add the resource limit of the container. @@ -182,6 +185,24 @@ def set_cpu_limit(self, cpu): self._validate_cpu_string(cpu) return self.add_resource_limit("cpu", cpu) + def set_gpu_limit(self, gpu, vendor = "nvidia"): + """Set gpu limit for the operator. This function add '.com/gpu' into resource limit. + Note that there is no need to add GPU request. GPUs are only supposed to be specified in + the limits section. See https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/. + + Args: + gpu: A string which must be a positive number. + vendor: Optional. A string which is the vendor of the requested gpu. The supported values + are: 'nvidia' (default), and 'amd'. + """ + + self._validate_gpu_string(gpu) + if vendor != 'nvidia' or vendor != 'amd': + raise ValueError('vendor can only be nvidia or amd.') + + return self.add_resource_limit("%s.com/gpu" % vendor, gpu) + + def add_volume(self, volume): """Add K8s volume to the container diff --git a/sdk/python/tests/compiler/testdata/basic.py b/sdk/python/tests/compiler/testdata/basic.py index 648ebcf56c8..3f078010000 100644 --- a/sdk/python/tests/compiler/testdata/basic.py +++ b/sdk/python/tests/compiler/testdata/basic.py @@ -85,5 +85,5 @@ def save_most_frequent_word(message: dsl.PipelineParam, outputpath: dsl.Pipeline message=counter.output, output_path=outputpath) saver.set_cpu_limit('0.5') - saver.add_resource_limit('nvidia.com/gpu', '2') + saver.set_gpu_limit('2') saver.add_node_selector_constraint('cloud.google.com/gke-accelerator', 'nvidia-tesla-k80')