Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,44 @@ def get_serializable_array_node_map_task(
# TODO Add support for other flyte entities
entity = node.flyte_entity
task_spec = get_serializable(entity_mapping, settings, entity, options)

override_pod_spec = {}
if node._pod_template is not None:
# get_container (not _get_container) goes through prepare_target() so the
# container args carry the map-task command rather than pyflyte-execute.
# When the underlying task has its own pod_template, get_container returns
# None; fall back to the inner python_function_task._get_container under
# prepare_target() so we still have an image/command to merge into the
# override pod spec.
container = entity.get_container(settings)
if container is None and isinstance(entity, (MapPythonTask, ArrayNodeMapTask)):
inner = getattr(entity, "python_function_task", None) or getattr(entity, "_run_task", None)
if inner is not None and hasattr(inner, "_get_container"):
with entity.prepare_target():
container = inner._get_container(settings)
if settings.should_fast_serialize() and container is not None:
# Mirror get_serializable_task: prefix args post-build rather than
# swapping command_fn, since _fast_serialize_command_fn would wrap
# the inherited pyflyte-execute default (wrong for map_task).
container._args = prefix_with_fast_execute(settings, container.args)
override_pod_spec = _serialize_pod_spec(node._pod_template, container, settings)

task_node = workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(
resources=node._resources,
extended_resources=node._extended_resources,
container_image=node._container_image,
pod_template=PodTemplate(
pod_spec=override_pod_spec,
labels=node._pod_template.labels if node._pod_template.labels else None,
annotations=node._pod_template.annotations if node._pod_template.annotations else None,
primary_container_name=node._pod_template.primary_container_name
if node._pod_template.primary_container_name
else None,
)
if node._pod_template
else None,
),
)
node = workflow_model.Node(
Expand Down
125 changes: 121 additions & 4 deletions tests/flytekit/unit/test_translator.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import typing
from collections import OrderedDict

import pytest
from kubernetes import client
from kubernetes.client import V1Container, V1PodSpec

import flytekit.configuration
from flytekit import ContainerTask, Resources, PodTemplate
from flytekit import ContainerTask, PodTemplate, Resources, map_task
from flytekit.configuration import FastSerializationSettings, Image, ImageConfig
from flytekit.core.base_task import kwtypes
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
Expand All @@ -13,9 +17,6 @@
from flytekit.models.core import identifier as identifier_models
from flytekit.models.task import Resources as resource_model
from flytekit.tools.translator import get_serializable
from kubernetes import client
from kubernetes.client import V1PodSpec, V1Container
import pytest

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -234,3 +235,119 @@ def wf():
assert len(pod_template_override.pod_spec['containers']) == 1
container = pod_template_override.pod_spec['containers'][0]
assert container['env'] == [{'name': 'MY_KEY', 'value': 'MY_VALUE'}]


@pytest.mark.parametrize(
"fast_registration_enabled",
[
pytest.param(True, id="fast registration enabled"),
pytest.param(False, id="fast registration disabled"),
],
)
def test_map_task_with_pod_template_override(fast_registration_enabled: bool):
# Regression test for https://github.com/flyteorg/flyte/issues/7076
# map_task(...).with_overrides(pod_template=...) was silently dropped at serialization.
custom_pod_template = PodTemplate(
primary_container_name="primary",
labels={"lKeyA": "lValA"},
annotations={"aKeyA": "aValA"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary",
env=[client.V1EnvVar(name="MY_KEY", value="MY_VALUE")],
)
]
),
)

@task
def t(a: int) -> str:
Comment on lines +264 to +265
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also add another test that t has a podTemplate config and we use with_overrides() to override it?

 @task(pod_template=base_pod_template)
    def t(a: int) -> str:
        return str(a)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 9dde146 — added test_map_task_with_pod_template_override_replaces_task_pod_template covering the case where @task(pod_template=base_pod_template) is then overridden via map_task(t)(...).with_overrides(pod_template=override_pod_template).

While writing the test I caught a real bug: when the underlying task already has pod_template set, entity.get_container(settings) returns None (it intentionally defers to get_k8s_pod), and _serialize_pod_spec would dereference primary_container.image on None. Fixed in the same commit by falling back to the inner task's _get_container under prepare_target() so the override pod spec still resolves image/command/args. Test asserts the override's OVERRIDE_KEY env wins over the base.

return str(a)

@workflow
def wf(xs: typing.List[int]):
map_task(t)(a=xs).with_overrides(pod_template=custom_pod_template)

settings = (
serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled))
.build()
)

wf_spec = get_serializable(OrderedDict(), settings, wf)
assert len(wf_spec.template.nodes) == 1
node = wf_spec.template.nodes[0]
# map_task is serialized as an array_node wrapping an inner task node
assert node.array_node is not None
inner_task_node = node.array_node.node.task_node
assert inner_task_node is not None
assert inner_task_node.overrides.pod_template is not None
pod_template_override = inner_task_node.overrides.pod_template
assert pod_template_override.primary_container_name == "primary"
assert pod_template_override.labels == {"lKeyA": "lValA"}
assert pod_template_override.annotations == {"aKeyA": "aValA"}
assert pod_template_override.pod_spec # validate not empty
assert len(pod_template_override.pod_spec["containers"]) == 1
container = pod_template_override.pod_spec["containers"][0]
assert {"name": "MY_KEY", "value": "MY_VALUE"} in container["env"]


@pytest.mark.parametrize(
"fast_registration_enabled",
[
pytest.param(True, id="fast registration enabled"),
pytest.param(False, id="fast registration disabled"),
],
)
def test_map_task_with_pod_template_override_replaces_task_pod_template(fast_registration_enabled: bool):
# Ensure with_overrides(pod_template=...) on a map_task whose underlying task
# already declares a pod_template still surfaces the override at the array_node.
base_pod_template = PodTemplate(
primary_container_name="primary",
labels={"lKeyBase": "lValBase"},
annotations={"aKeyBase": "aValBase"},
)
override_pod_template = PodTemplate(
primary_container_name="primary",
labels={"lKeyOverride": "lValOverride"},
annotations={"aKeyOverride": "aValOverride"},
pod_spec=V1PodSpec(
containers=[
V1Container(
name="primary",
env=[client.V1EnvVar(name="OVERRIDE_KEY", value="OVERRIDE_VALUE")],
)
]
),
)

@task(pod_template=base_pod_template)
def t(a: int) -> str:
return str(a)

@workflow
def wf(xs: typing.List[int]):
map_task(t)(a=xs).with_overrides(pod_template=override_pod_template)

settings = (
serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=fast_registration_enabled))
.build()
)

wf_spec = get_serializable(OrderedDict(), settings, wf)
assert len(wf_spec.template.nodes) == 1
node = wf_spec.template.nodes[0]
assert node.array_node is not None
inner_task_node = node.array_node.node.task_node
assert inner_task_node is not None
assert inner_task_node.overrides.pod_template is not None
pod_template_override = inner_task_node.overrides.pod_template
assert pod_template_override.primary_container_name == "primary"
assert pod_template_override.labels == {"lKeyOverride": "lValOverride"}
assert pod_template_override.annotations == {"aKeyOverride": "aValOverride"}
assert pod_template_override.pod_spec # validate not empty
assert len(pod_template_override.pod_spec["containers"]) == 1
container = pod_template_override.pod_spec["containers"][0]
assert {"name": "OVERRIDE_KEY", "value": "OVERRIDE_VALUE"} in container["env"]