Skip to content

Commit

Permalink
SDK/Compiler - Invoke the op_transformers as early as possible (#1464)
Browse files Browse the repository at this point in the history
* Add reproducible test case

* Invoke the op_transformers as early as possible
  • Loading branch information
kvalev authored and k8s-ci-robot committed Jun 7, 2019
1 parent f8b0638 commit 381083a
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
11 changes: 8 additions & 3 deletions sdk/python/kfp/compiler/compiler.py
Expand Up @@ -453,7 +453,7 @@ def _group_to_template(self, group, inputs, outputs, dependencies):

def _create_templates(self, pipeline, op_transformers=None, op_to_templates_handler=None):
"""Create all groups and ops templates in the pipeline.
Args:
pipeline: Pipeline context object to get all the pipeline data from.
op_transformers: A list of functions that are applied to all ContainerOp instances that are being processed.
Expand All @@ -463,6 +463,13 @@ def _create_templates(self, pipeline, op_transformers=None, op_to_templates_hand
op_to_templates_handler = op_to_templates_handler or (lambda op : [_op_to_template(op)])
new_root_group = pipeline.groups[0]

# Call the transformation functions before determining the inputs/outputs, otherwise
# the user would not be able to use pipeline parameters in the container definition
# (for example as pod labels) - the generated template is invalid.
for op in pipeline.ops.values():
for transformer in op_transformers or []:
transformer(op)

# Generate core data structures to prepare for argo yaml generation
# op_groups: op name -> list of ancestor groups including the current op
# opsgroups: a dictionary of ospgroup.name -> opsgroup
Expand All @@ -486,8 +493,6 @@ def _create_templates(self, pipeline, op_transformers=None, op_to_templates_hand
templates.append(template)

for op in pipeline.ops.values():
for transformer in op_transformers or []:
op = transformer(op) or op
templates.extend(op_to_templates_handler(op))
return templates

Expand Down
6 changes: 5 additions & 1 deletion sdk/python/tests/compiler/compiler_tests.py
Expand Up @@ -366,6 +366,10 @@ def test_py_param_substitutions(self):
"""Test pipeline param_substitutions."""
self._test_py_compile_yaml('param_substitutions')

def test_py_param_op_transform(self):
"""Test pipeline param_op_transform."""
self._test_py_compile_yaml('param_op_transform')

def test_type_checking_with_consistent_types(self):
"""Test type check pipeline parameters against component metadata."""
@component
Expand Down Expand Up @@ -471,7 +475,7 @@ def op():
def pipeline():
task1 = op()
task2 = op().after(task1)

compiler.Compiler()._compile(pipeline)

def _test_op_to_template_yaml(self, ops, file_base_name):
Expand Down
28 changes: 28 additions & 0 deletions sdk/python/tests/compiler/testdata/param_op_transform.py
@@ -0,0 +1,28 @@
from typing import Callable

import kfp.dsl as dsl

def add_common_labels(param):

def _add_common_labels(op: dsl.ContainerOp) -> dsl.ContainerOp:
return op.add_pod_label('param', param)

return _add_common_labels

@dsl.pipeline(
name="Parameters in Op transformation functions",
description="Test that parameters used in Op transformation functions as pod labels "
"would be correcly identified and set as arguments in he generated yaml"
)
def param_substitutions(param = dsl.PipelineParam(name='param')):
dsl.get_pipeline_conf().op_transformers.append(add_common_labels(param))

op = dsl.ContainerOp(
name="cop",
image="image",
)


if __name__ == '__main__':
import kfp.compiler as compiler
compiler.Compiler().compile(param_substitutions, __file__ + '.yaml')
40 changes: 40 additions & 0 deletions sdk/python/tests/compiler/testdata/param_op_transform.yaml
@@ -0,0 +1,40 @@
apiVersion: argoproj.io/v1alpha1
kind: Workflow
metadata:
generateName: parameters-in-op-transformation-functions-
spec:
arguments:
parameters:
- name: param
entrypoint: parameters-in-op-transformation-functions
serviceAccountName: pipeline-runner
templates:
- container:
image: image
inputs:
parameters:
- name: param
metadata:
labels:
param: '{{inputs.parameters.param}}'
name: cop
outputs:
artifacts:
- name: mlpipeline-ui-metadata
optional: true
path: /mlpipeline-ui-metadata.json
- name: mlpipeline-metrics
optional: true
path: /mlpipeline-metrics.json
- dag:
tasks:
- arguments:
parameters:
- name: param
value: '{{inputs.parameters.param}}'
name: cop
template: cop
inputs:
parameters:
- name: param
name: parameters-in-op-transformation-functions

0 comments on commit 381083a

Please sign in to comment.