
feat(components): Support dynamic machine type parameters in CustomTrainingJobOp #10883

Merged: 12 commits merged into master from dynamic-customtrainingjobop on Jun 13, 2024

Conversation

@KevinGrantLee (Contributor) commented Jun 10, 2024

Description of your changes:

Enables setting the machine_type, accelerator_type, and accelerator_count dynamically from task outputs and pipeline inputs in CustomTrainingJobOp.

Example:

# Imports needed for this example:
from kfp import dsl
from google_cloud_pipeline_components.v1 import custom_job


@dsl.component
def machine_type() -> str:
    return 'n1-standard-4'


@dsl.component
def accelerator_type() -> str:
    return 'NVIDIA_TESLA_P4'


@dsl.component
def accelerator_count() -> int:
    # This can either be int or int string
    return 1



@dsl.pipeline
def pipeline(
    project: str,
    location: str,
    encryption_spec_key_name: str = '',
):
    machine_type_task = machine_type()
    accelerator_type_task = accelerator_type()
    accelerator_count_task = accelerator_count()

    custom_job.CustomTrainingJobOp(
        display_name='add-numbers',
        worker_pool_specs=[{
            'container_spec': {
                # doesn't need to be the container under test
                # just need an image within the VPC-SC perimeter
                'image_uri':
                    ('gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0'
                    ),
                'command': ['echo'],
                'args': ['foo'],
            },
            'machine_spec': {
                'machine_type': machine_type_task.output,
                'accelerator_type': accelerator_type_task.output,
                'accelerator_count': accelerator_count_task.output,
            },
            'replica_count': 1,
        }],
        project=project,
        location=location,
        encryption_spec_key_name=encryption_spec_key_name,
    )
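
For reference, here is a minimal sketch of compiling the example above with the KFP SDK compiler; the output file name is an arbitrary choice for illustration:

from kfp import compiler

# Compile the pipeline to a pipeline spec (IR YAML) that can be submitted to
# Vertex AI Pipelines.
compiler.Compiler().compile(
    pipeline_func=pipeline,
    package_path='dynamic_machine_spec_pipeline.yaml',
)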

This PR also enables the following behavior:

@dsl.pipeline
def pipeline(
    project: str,
    location: str,
    machine_type: str,
    accelerator_type: str,
    accelerator_count: int,
    encryption_spec_key_name: str = '',
):

    custom_job.CustomTrainingJobOp(
        display_name='add-numbers',
        worker_pool_specs=[{
            'container_spec': {
                # doesn't need to be the container under test
                # just need an image within the VPC-SC perimeter
                'image_uri':
                    ('gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0'
                    ),
                'command': ['echo'],
                'args': ['foo'],
            },
            'machine_spec': {
                'machine_type': machine_type,
                'accelerator_type': accelerator_type,
                'accelerator_count': accelerator_count,
            },
            'replica_count': 1,
        }],
        project=project,
        location=location,
        encryption_spec_key_name=encryption_spec_key_name,
    )
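
With the pipeline-input variant, the machine parameters can be supplied at submission time. Below is a hedged sketch using the Vertex AI SDK; the project, region, and template path are placeholders rather than values from this PR:

from google.cloud import aiplatform

aiplatform.init(project='my-project', location='us-central1')

# The runtime parameter values flow into machine_spec through the pipeline
# inputs declared above.
job = aiplatform.PipelineJob(
    display_name='dynamic-machine-spec',
    template_path='dynamic_machine_spec_pipeline.yaml',
    parameter_values={
        'project': 'my-project',
        'location': 'us-central1',
        'machine_type': 'n1-standard-4',
        'accelerator_type': 'NVIDIA_TESLA_P4',
        'accelerator_count': 1,
    },
)
job.submit()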

Checklist:

@KevinGrantLee (Contributor Author)

/retest

1 similar comment
@KevinGrantLee (Contributor Author)

/retest

@connor-mccarthy (Member)

@KevinGrantLee, can you please address the required presubmit checks DCO and kubeflow-pipelines-sdk-yapf?

@connor-mccarthy (Member)

/assign @chensun

@KevinGrantLee (Contributor Author)

/retest

@KevinGrantLee (Contributor Author)

/retest


@dsl.component
def accelerator_count() -> int:
    # This can either be int or int string
Member:

Is this true? The type hint doesn't indicate the same.

Contributor Author (@KevinGrantLee), Jun 11, 2024:

Yes, I also tried compiling and submitting a pipeline with def accelerator_count() -> str: returning '1', and that pipeline succeeded.

Contributor Author:

I'll remove the comment to avoid confusion, because leaving the return annotation as int and returning '1' causes errors.

        worker_pool_specs=[{
            'container_spec': {
                # doesn't need to be the container under test
                # just need an image within the VPC-SC perimeter
Member:

I don't quite follow this comment. Is it necessary?

Contributor Author:

I copied this comment from another test pipeline, will remove.

            'machine_spec': {
                'machine_type': machine_type_task.output,
                'accelerator_type': accelerator_type_task.output,
                'accelerator_count': accelerator_count_task.output,
Member:

Can you make a case where the dynamic value comes from a pipeline input parameter? That would make sure we cover all the dynamic value paths.

Contributor Author (@KevinGrantLee), Jun 11, 2024:

Done. Tested compiling and submitting the pipeline.

    elif isinstance(data, list):
        return [recursive_replace(i, old_value, new_value) for i in data]
    else:
        if isinstance(data, pipeline_channel.PipelineChannel):
Member:

This method seems to be explicitly replacing placeholders from one representation to another. It's not for replacing arbitrary values, so the method name should reflect its purpose.

Contributor Author:

Although here I'm just using this method for replacing placeholders, it can be used for arbitrary values as well, so I left the method name general.

Contributor Author:

Renamed the method to recursive_replace_placeholders.
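
For readers following the thread, a rough sketch of what a helper like recursive_replace_placeholders could look like, inferred from the snippets quoted here rather than the exact code merged in this PR:

def recursive_replace_placeholders(data, old_value, new_value):
    """Recursively swaps old_value for new_value in a nested dict/list."""
    if isinstance(data, dict):
        return {
            key: recursive_replace_placeholders(value, old_value, new_value)
            for key, value in data.items()
        }
    elif isinstance(data, list):
        return [
            recursive_replace_placeholders(item, old_value, new_value)
            for item in data
        ]
    else:
        # Leaf values: swap the PipelineChannel (old_value) for the
        # input-parameter placeholder string (new_value) when they match.
        return new_value if data == old_value else data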

3: ['d']
}],
'old_value': 'd',
'new_value': 'dd',
Member:

The old and new values don't seem to be testing a real-case scenario?

Contributor Author (@KevinGrantLee), Jun 11, 2024:

That test case used some simple dummy values to verify the behavior of recursive_replace(); I can remove it if you think it's redundant.

Contributor Author:

Removed the test case.

@@ -239,70 +327,18 @@ def build_task_spec_for_task(
component_input_parameter)

elif isinstance(input_value, str):
# Handle extra input due to string concat
Member:

IIRC, this chunk of code is only applicable to string-typed inputs; why merge the code and expand it to other input types? Also, it's a bit hard to read the diff between the deleted code and the extracted code. Can you try making the changes in place without refactoring, and see if it's actually necessary to expand the logic to non-string-typed inputs?

Contributor Author:

I found that this block of code could be reused for handling PipelineChannels inside worker_pool_specs, in addition to handling string-typed inputs. Instead of copying the ~50 lines of code, I thought it'd be better to refactor the logic into a separate function, replace_and_inject_placeholders().

I could un-refactor and duplicate the logic; I do have a slight preference for this refactoring but can go either way. WDYT?

Member:

The name of the extracted method isn't accurate: the code does more than placeholder manipulation; it also performs component input expansion.
The branch logic now reads like this:

if isinstance(input_value, str):
    # shared code
    pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant.string_value = input_value
elif isinstance(input_value, (int, float, bool, dict, list)):
    if isinstance(input_value, (dict, list)):
          # shared code
    pipeline_task_spec.inputs.parameters[
                input_name].runtime_value.constant.CopyFrom(
                    to_protobuf_value(input_value))
else:
     raise

You can achieve the same goal, and even more code reuse, without extracting a shared method by:

if not isinstance(input_value, (str, dict, list, int, float, bool)):
    raise

if isinstance(input_value, (str, dict, list)):
    # shared code

pipeline_task_spec.inputs.parameters[
            input_name].runtime_value.constant.CopyFrom(
                to_protobuf_value(input_value))

Member:

Aside from the refactoring part, I wonder what's the use case for dict and list? In case CustomTrainingJobOp is used, what's the input_value here?

Contributor Author (@KevinGrantLee), Jun 12, 2024:

If CustomTrainingJobOp is used, then worker_pool_specs is passed in as input_value.

It looks like this with PipelineChannel objects:

input_value = [{'container_spec': {'image_uri': 'gcr.io/ml-pipeline/google-cloud-pipeline-components:2.5.0', 'command': ['echo'], 'args': ['foo']}, 'machine_spec': {'machine_type': {{channel:task=machine-type;name=Output;type=String;}}, 'accelerator_type': {{channel:task=accelerator-type;name=Output;type=String;}}, 'accelerator_count': {{channel:task=accelerator-count;name=Output;type=Integer;}}}, 'replica_count': 1}]

Member (@chensun), Jun 12, 2024:

Thanks for the explanation. So input_value would be of type list in this case. Including dict in the same code path is just for future use cases, not entirely necessary at this moment, right? I'm fine to include dict now.

@KevinGrantLee (Contributor Author)

@chensun I verified that nested DAGs with dsl.Condition() and custom jobs with dynamic machine parameters (from task outputs and pipeline inputs) compile and run successfully.


"""Recursively replaces values in a nested dict/list object.

This method is used to replace PipelineChannel objects with pipeine channel
placeholders in a nested object like worker_pool_specs for custom jobs.
Member:

pipeline channel placeholders -> input parameter placeholder

Contributor Author:

Done.


@chensun (Member) commented Jun 12, 2024

> @chensun I verified that nested DAGs with dsl.Condition() and custom jobs with dynamic machine parameters (from task outputs and pipeline inputs) compile and run successfully.

Can you add a test case?

@KevinGrantLee (Contributor Author)

/retest

@KevinGrantLee removed the request for review from connor-mccarthy, June 12, 2024 18:30
additional_input_name].task_output_parameter.output_parameter_key = (
channel.name)
elif isinstance(input_value, (str, int, float, bool, dict, list)):
if isinstance(input_value, (str, dict, list)):
Member:

You can remove this if isinstance(input_value, (str, dict, list)): check; extract_pipeline_channels_from_any would return an empty list in the float, int, and bool cases.

Contributor Author:

Is it fine to remove the inner if check and expand the type annotations?

It would simplify the code, but I'm not sure if it makes sense to update the type annotations for extract_pipeline_channels_from_any, since ints, floats, and bools can't contain pipeline channels.

payload: Union[PipelineChannel, str, list, tuple, dict]

Member:

Yes, you can update the payload annotation type.

Contributor Author:

Done.


if isinstance(input_value, str):
input_value = input_value.replace(
channel.pattern, additional_input_placeholder)
Member:

This would be covered by recursive_replace_placeholders, right?

Contributor Author:

The string case is not covered by recursive_replace_placeholders as is; I would need to embed string.replace() logic in recursive_replace_placeholders if we wanted to get rid of this if/else block.

I suppose the question is whether we want to expose this logic in pipeline_spec_builder:build_task_spec_for_task or compiler_utils:recursive_replace_placeholders.

Member:

I see. It's up to you whether you want to keep it as-is or remove the if/else block. I don't have a strong preference.

I'm not sure how this is related to your question about exposing the logic.

Contributor Author:

Alright, kept as-is.
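
As an illustration of the string branch discussed above, the snippet below shows the kind of in-place replacement being kept; the placeholder strings are illustrative examples, not compiler output:

# A string input that embeds a pipeline channel pattern, e.g. built through
# string concatenation inside the pipeline definition.
input_value = 'gs://{{channel:task=bucket;name=Output;type=String;}}/data.csv'

channel_pattern = '{{channel:task=bucket;name=Output;type=String;}}'
additional_input_placeholder = (
    "{{$.inputs.parameters['pipelinechannel--bucket-Output']}}")

# The channel pattern is rewritten to the additional-input placeholder, which
# recursive_replace_placeholders does not handle for plain strings as-is.
input_value = input_value.replace(channel_pattern, additional_input_placeholder)
print(input_value)
# gs://{{$.inputs.parameters['pipelinechannel--bucket-Output']}}/data.csv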

@KevinGrantLee: The following test failed, say /retest to rerun all failed tests or /retest-required to rerun all mandatory failed tests:

Test name                      | Commit  | Details | Required | Rerun command
kfp-kubernetes-execution-tests | 1c70801 | link    | false    | /test kfp-kubernetes-execution-tests

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository. I understand the commands that are listed here.

Member (@chensun) left a comment

/lgtm
/approve

Thanks, @KevinGrantLee!

google-oss-prow bot added the lgtm label on Jun 13, 2024

[APPROVALNOTIFIER] This PR is APPROVED

This pull-request has been approved by: chensun

The full list of commands accepted by this bot can be found here.

The pull request process is described here

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

google-oss-prow bot merged commit b57f9e8 into master on Jun 13, 2024
28 of 29 checks passed
google-oss-prow bot deleted the dynamic-customtrainingjobop branch on June 13, 2024 07:09
@pthieu commented Jun 14, 2024

@KevinGrantLee is it possible to do the same for worker_pool_specs.container_spec.env? I have a use case where I'm trying to keep a consistent timestamp for file naming across my custom training op and the testing op that runs afterwards (the testing op needs the freshly trained model), so I pass it in as an environment variable.

#10902

@KevinGrantLee (Contributor Author)

Hi @pthieu, this PR should also enable that use case. I did some local tests, but can you confirm on your end?

@pthieu commented Jun 17, 2024

@KevinGrantLee I think I'll need to wait for the next release, as we use this in our CI/CD pipeline. Do you know what the release schedule is? I can test and confirm once it's out.

@KevinGrantLee (Contributor Author)

@pthieu, we're planning to do a release later this week. cc @chensun

@alexredplanet

I believe the same issue is occurring for ModelBatchPredictOp, @pthieu @KevinGrantLee; here is a minimal example:

# Imports assumed for this example:
from kfp import dsl
from google_cloud_pipeline_components.v1.batch_predict_job import ModelBatchPredictOp
from google_cloud_pipeline_components.v1.model import ModelGetOp

@dsl.component
def gcs_jsonl_uri(bucket_name: str, file_name: str) -> str:
    return f"gs://{bucket_name}/{file_name}"

@dsl.pipeline
def pipeline(project: str, location: str, model_name: str):

    model = ModelGetOp(project=project, model_name=model_name, location=location)
    gcs_jsonl = gcs_jsonl_uri(bucket_name="example", file_name="example.jsonl")

    batch_predict_job = ModelBatchPredictOp(
        model=model.outputs["model"],
        job_display_name="example",
        gcs_source_uris=[gcs_jsonl.output],
        location=location
    )

producing the same error as your issue, @pthieu:

ValueError: Value must be one of the following types: str, int, float, bool, dict, and list. Got: "{{channel:task=gcs-jsonl-uri;name=Output;type=String;}}" of type "<class 'kfp.dsl.pipeline_channel.PipelineParameterChannel'>"

@KevinGrantLee (Contributor Author)

Hi @alexredplanet, I'm reasonably confident that this PR should also fix your case. Once the next KFP release is out, can you retry?

@alexredplanet

Thanks @KevinGrantLee, it did indeed fix that issue!

@KevinGrantLee (Contributor Author) commented Jun 25, 2024

Hi @pthieu, KFP SDK 2.8.0 has been released; you should be able to test #10902 again.

@pthieu commented Jun 26, 2024

@KevinGrantLee looks like getting the latest version worked (at least no errors thrown), thanks.

Just waiting on a dependency requirement change in the google-cloud-pipeline-components package:

The conflict is caused by:
    The user requested kfp==2.8.0
    google-cloud-pipeline-components 2.14.1 depends on kfp<=2.7.0 and >=2.6.0
