Skip to content

Commit

Permalink
Add explicit set_gpu_limit function
Browse files Browse the repository at this point in the history
  • Loading branch information
hongye-sun committed Nov 22, 2018
1 parent 3e85555 commit 1e70808
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
23 changes: 22 additions & 1 deletion sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,13 @@ def _validate_gpu_string(self, gpu_string):
"Validate a given string is valid for gpu limit."

try:
int(gpu_string)
gpu_value = int(gpu_string)
except ValueError:
raise ValueError('Invalid gpu string. Should be integer.')

if gpu_value <= 0:
raise ValueError('gpu must be positive integer.')

def add_resource_limit(self, resource_name, value):
"""Add the resource limit of the container.
Expand Down Expand Up @@ -182,6 +185,24 @@ def set_cpu_limit(self, cpu):
self._validate_cpu_string(cpu)
return self.add_resource_limit("cpu", cpu)

def set_gpu_limit(self, gpu, vendor = "nvidia"):
"""Set gpu limit for the operator. This function add '<vendor>.com/gpu' into resource limit.
Note that there is no need to add GPU request. GPUs are only supposed to be specified in
the limits section. See https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/.
Args:
gpu: A string which must be a positive number.
vendor: Optional. A string which is the vendor of the requested gpu. The supported values
are: 'nvidia' (default), and 'amd'.
"""

self._validate_gpu_string(gpu)
if vendor != 'nvidia' or vendor != 'amd':
raise ValueError('vendor can only be nvidia or amd.')

return self.add_resource_limit("%s.com/gpu" % vendor, gpu)


def add_volume(self, volume):
"""Add K8s volume to the container
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/tests/compiler/testdata/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ def save_most_frequent_word(message: dsl.PipelineParam, outputpath: dsl.Pipeline
message=counter.output,
output_path=outputpath)
saver.set_cpu_limit('0.5')
saver.add_resource_limit('nvidia.com/gpu', '2')
saver.set_gpu_limit('2')
saver.add_node_selector_constraint('cloud.google.com/gke-accelerator', 'nvidia-tesla-k80')

0 comments on commit 1e70808

Please sign in to comment.