Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def _apply_nested(self, d: typing.Dict[str, Any]) -> typing.Dict[str, Any]:
current_dict[k] = self.substitute(v)
elif isinstance(v, list):
for i in range(len(v)):
if isinstance(v[i], str):
if isinstance(v[i], dict):
stack.append(v[i])
elif isinstance(v[i], str):
v[i] = self.substitute(v[i])
return d

Expand Down
120 changes: 120 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,3 +945,123 @@ def test_apply(self) -> None:
self.assertNotEqual(newrole, role)
self.assertEqual(newrole.args, ["img_root"])
self.assertEqual(newrole.env, {"FOO": "app_id"})

def test_apply_nested_with_list_of_dicts(self) -> None:
"""Test that _apply_nested correctly handles dictionaries nested inside lists."""
role = Role(
name="test",
image="test_image",
entrypoint="foo.py",
metadata={
"nested_list": [
{"key1": macros.app_id, "key2": "static"},
{"key3": macros.img_root},
]
},
)
v = macros.Values(
img_root="img_root_value",
app_id="app_id_value",
replica_id="replica_id_value",
base_img_root="base_img_root_value",
rank0_env="rank0_env_value",
)
newrole = v.apply(role)
self.assertEqual(newrole.metadata["nested_list"][0]["key1"], "app_id_value")
self.assertEqual(newrole.metadata["nested_list"][0]["key2"], "static")
self.assertEqual(newrole.metadata["nested_list"][1]["key3"], "img_root_value")

def test_apply_nested_with_deeply_nested_structures(self) -> None:
"""Test that _apply_nested handles deeply nested structures with mixed types."""
role = Role(
name="test",
image="test_image",
entrypoint="foo.py",
metadata={
"level1": {
"level2": {
"list_with_dicts": [
{
"nested_key": macros.replica_id,
"nested_list": [macros.app_id, "static_value"],
},
{"another_key": macros.img_root},
],
"simple_string": macros.rank0_env,
}
}
},
)
v = macros.Values(
img_root="img_root_value",
app_id="app_id_value",
replica_id="replica_id_value",
base_img_root="base_img_root_value",
rank0_env="rank0_env_value",
)
newrole = v.apply(role)

# Check deeply nested dict in list
nested_dict = newrole.metadata["level1"]["level2"]["list_with_dicts"][0]
self.assertEqual(nested_dict["nested_key"], "replica_id_value")
self.assertEqual(nested_dict["nested_list"][0], "app_id_value")
self.assertEqual(nested_dict["nested_list"][1], "static_value")

# Check second dict in list
second_dict = newrole.metadata["level1"]["level2"]["list_with_dicts"][1]
self.assertEqual(second_dict["another_key"], "img_root_value")

# Check simple string at nested level
self.assertEqual(
newrole.metadata["level1"]["level2"]["simple_string"], "rank0_env_value"
)

def test_apply_nested_with_list_of_strings(self) -> None:
"""Test that _apply_nested still works correctly with lists of strings."""
role = Role(
name="test",
image="test_image",
entrypoint="foo.py",
metadata={
"string_list": [macros.app_id, macros.img_root, "static"],
},
)
v = macros.Values(
img_root="img_root_value",
app_id="app_id_value",
replica_id="replica_id_value",
base_img_root="base_img_root_value",
rank0_env="rank0_env_value",
)
newrole = v.apply(role)
self.assertEqual(newrole.metadata["string_list"][0], "app_id_value")
self.assertEqual(newrole.metadata["string_list"][1], "img_root_value")
self.assertEqual(newrole.metadata["string_list"][2], "static")

def test_apply_nested_with_mixed_list_types(self) -> None:
"""Test that _apply_nested handles lists with mixed types (strings, dicts, other)."""
role = Role(
name="test",
image="test_image",
entrypoint="foo.py",
metadata={
"mixed_list": [
macros.app_id,
{"nested": macros.img_root},
42, # non-string, non-dict value
"static_string",
],
},
)
v = macros.Values(
img_root="img_root_value",
app_id="app_id_value",
replica_id="replica_id_value",
base_img_root="base_img_root_value",
rank0_env="rank0_env_value",
)
newrole = v.apply(role)
self.assertEqual(newrole.metadata["mixed_list"][0], "app_id_value")
self.assertEqual(newrole.metadata["mixed_list"][1]["nested"], "img_root_value")
self.assertEqual(newrole.metadata["mixed_list"][2], 42)
self.assertEqual(newrole.metadata["mixed_list"][3], "static_string")