Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Add ParallelFor compile test over single artifact. #10531

Merged
merged 2 commits into from Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Expand Up @@ -838,6 +838,20 @@ def my_pipeline():
with dsl.ParallelFor(items=single_param_task.output) as item:
print_and_return(text=item)

def test_cannot_compile_parallel_for_with_single_artifact(self):

with self.assertRaisesRegex(
ValueError,
r'Cannot iterate over a single artifact using `dsl\.ParallelFor`\. Expected a list of artifacts as argument to `items`\.'
):

@dsl.pipeline
def my_pipeline():
single_artifact_task = print_and_return_as_artifact(
text='string')
with dsl.ParallelFor(items=single_artifact_task.output) as item:
print_artifact(a=item)

def test_pipeline_in_pipeline(self):

@dsl.pipeline(name='graph-component')
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/for_loop.py
Expand Up @@ -286,7 +286,7 @@ def from_pipeline_channel(
object."""
if not channel.is_artifact_list:
raise ValueError(
'Cannot iterate over a single Artifact using `dsl.ParallelFor`. Expected a list of Artifacts as argument to `items`.'
'Cannot iterate over a single artifact using `dsl.ParallelFor`. Expected a list of artifacts as argument to `items`.'
)
return LoopArtifactArgument(
items=channel,
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/for_loop_test.py
Expand Up @@ -207,7 +207,7 @@ def test_loop_artifact_argument_from_single_pipeline_channel_raises_error(
self, channel):
with self.assertRaisesRegex(
ValueError,
r'Cannot iterate over a single Artifact using `dsl\.ParallelFor`\. Expected a list of Artifacts as argument to `items`\.'
r'Cannot iterate over a single artifact using `dsl\.ParallelFor`\. Expected a list of artifacts as argument to `items`\.'
):
loop_argument = for_loop.LoopArtifactArgument.from_pipeline_channel(
channel)
Expand Down