Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly disallow iteration on Promises #2337

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,9 +513,18 @@ def wf():
The attribute keys are appended on the promise and a new promise is returned with the updated attribute path.
We don't modify the original promise because it might be used in other places as well.
"""

return self._append_attr(key)

def __iter__(self):
"""
Flyte/kit (as of https://github.com/flyteorg/flyte/issues/3864) supports indexing into a list of promises.
But it still doesn't make sense to
"""
raise ValueError(
"Promise objects are not iterable - can't range() over a promise."
" But you can use [index] or the still stabilizing @eager"
)

def __getattr__(self, key) -> Promise:
"""
When we use . to access the attribute on the promise, for example
Expand Down
36 changes: 36 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing
from collections import OrderedDict
from typing import List

import pytest

Expand All @@ -15,6 +16,7 @@
from flytekit.core.workflow import workflow
from flytekit.models.literals import LiteralMap
from flytekit.tools.translator import get_serializable_task
from flytekit.types.file import FlyteFile

settings = flytekit.configuration.SerializationSettings(
project="test_proj",
Expand Down Expand Up @@ -290,3 +292,37 @@ def dt(mode: int) -> int:
serialised_entities_iterator = iter(entity_mapping.values())
assert "t1" in next(serialised_entities_iterator).template.id.name
assert "t2" in next(serialised_entities_iterator).template.id.name


def test_iter():
@task(requests=Resources(mem="5Gi"))
def ff_list_task() -> List[FlyteFile]:
return [FlyteFile(path=__file__, remote_path=False), FlyteFile(path=__file__, remote_path=False)]

@workflow
def sub_wf(input_file: FlyteFile) -> FlyteFile:
return input_file

@dynamic(requests=Resources(mem="5Gi"))
def dynamic_task() -> List[FlyteFile]:
batched_input_files = ff_list_task()
result_files: List[FlyteFile] = []

for _ in batched_input_files:
...

return result_files

with context_manager.FlyteContextManager.with_context(
context_manager.FlyteContextManager.current_context().with_serialization_settings(settings)
) as ctx:
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(
mode=ExecutionState.Mode.TASK_EXECUTION,
)
)
) as ctx:
input_literal_map = TypeEngine.dict_to_literal_map(ctx, {})
with pytest.raises(ValueError):
dynamic_task.dispatch_execute(ctx, input_literal_map)
Loading