Skip to content

Commit

Permalink
cleanup and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
chensun committed Sep 26, 2020
1 parent 4a9ffdd commit 573c9f3
Showing 1 changed file with 14 additions and 48 deletions.
62 changes: 14 additions & 48 deletions sdk/python/kfp/v2/compiler/compiler.py
Expand Up @@ -27,12 +27,12 @@
from kfp import dsl
from kfp.compiler._k8s_helper import sanitize_k8s_name
from kfp.compiler._op_to_template import _op_to_template
#$from ._default_transformers import add_pod_env

from kfp.components.structures import ComponentSpec, InputSpec
from kfp.components._yaml_utils import dump_yaml
from kfp.dsl._metadata import _extract_pipeline_metadata
from kfp.dsl._ops_group import OpsGroup
from kfp.dsl import ir_types
from kfp.ir import pipeline_spec_pb2
from google.protobuf.json_format import MessageToJson

Expand Down Expand Up @@ -290,44 +290,6 @@ def _get_inputs_outputs(
if loop_group.loop_args.name in param.name:
break


# Generate the input/output for recursive opsgroups
# It propagates the recursive opsgroups IO to their ancester opsgroups
def _get_inputs_outputs_recursive_opsgroup(group):
#TODO: refactor the following codes with the above
if group.recursive_ref:
params = [(param, False) for param in group.inputs]
params.extend([(param, True) for param in list(condition_params[group.name])])
for param, is_condition_param in params:
if param.value:
continue
full_name = self._pipelineparam_full_name(param)
if param.op_name:
upstream_op = pipeline.ops[param.op_name]
upstream_groups, downstream_groups = \
self._get_uncommon_ancestors(op_groups, opsgroup_groups, upstream_op, group)
for i, g in enumerate(downstream_groups):
if i == 0:
inputs[g].add((full_name, upstream_groups[0]))
# There is no need to pass the condition param as argument to the downstream ops.
#TODO: this might also apply to ops. add a TODO here and think about it.
elif i == len(downstream_groups) - 1 and is_condition_param:
continue
else:
inputs[g].add((full_name, None))
for i, g in enumerate(upstream_groups):
if i == len(upstream_groups) - 1:
outputs[g].add((full_name, None))
else:
outputs[g].add((full_name, upstream_groups[i+1]))
elif not is_condition_param:
for g in op_groups[group.name]:
inputs[g].add((full_name, None))
for subgroup in group.groups:
_get_inputs_outputs_recursive_opsgroup(subgroup)

_get_inputs_outputs_recursive_opsgroup(root_group)

return inputs, outputs

def _get_dependencies(self, pipeline, root_group, op_groups, opsgroups_groups, opsgroups, condition_params):
Expand Down Expand Up @@ -624,15 +586,17 @@ def _create_dag_templates(self, pipeline, pipeline_spec, op_to_templates_handler
template = self._group_to_dag_template(opsgroups[opsgroup], inputs, outputs, dependencies)
templates.append(template)


deployment_config = pipeline_spec_pb2.PipelineDeploymentConfig()
importer_tasks = []


def _get_input_artifact_type(input_name: str, component_spec: ComponentSpec) -> str:
def _get_input_artifact_type(
input_name: str,
component_spec: ComponentSpec,
) -> str:
"""Find the input artifact type by input name."""
for input in component_spec.inputs:
if input.name == input_name:
from kfp.dsl import ir_types
return ir_types._artifact_types_mapping.get(input.type.lower()) or ''
return ''

Expand Down Expand Up @@ -666,8 +630,12 @@ def _get_input_artifact_type(input_name: str, component_spec: ComponentSpec) ->

return templates

def _create_pipeline_spec(self, args, pipeline) -> pipeline_spec_pb2.PipelineSpec:
"""Create workflow for the pipeline."""

def _create_pipeline_spec(self,
args: List[dsl.PipelineParam],
pipeline: dsl.Pipeline,
) -> pipeline_spec_pb2.PipelineSpec:
"""Create the pipeline spec object."""

pipeline_spec = pipeline_spec_pb2.PipelineSpec()
pipeline_spec.pipeline_info.name = pipeline.name or 'Pipeline'
Expand All @@ -676,7 +644,6 @@ def _create_pipeline_spec(self, args, pipeline) -> pipeline_spec_pb2.PipelineSpe

# Pipeline Parameters
for arg in args:
pipeline_spec.runtime_parameters[arg.name].type = pipeline_spec_pb2.PrimitiveType().PRIMITIVE_TYPE_UNSPECIFIED
if arg.value is not None:
if isinstance(arg.value, int):
pipeline_spec.runtime_parameters[arg.name].type = pipeline_spec_pb2.PrimitiveType().INT
Expand Down Expand Up @@ -733,6 +700,7 @@ def _sanitize_and_inject_artifact(self, pipeline: dsl.Pipeline):
sanitized_ops[sanitized_name] = op
pipeline.ops = sanitized_ops


def _create_pipeline(self,
pipeline_func: Callable,
pipeline_name: Text=None,
Expand Down Expand Up @@ -800,9 +768,7 @@ def _create_pipeline(self,
dsl_pipeline,
)


print('###proto###\n', MessageToJson(pipeline_spec))

#print('###proto###\n', MessageToJson(pipeline_spec))
return pipeline_spec

def compile(self, pipeline_func, package_path, type_check=True):
Expand Down

0 comments on commit 573c9f3

Please sign in to comment.