Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions compiler/back_end/cpp/header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,9 @@ def _extract_switch_arms(expression, ir):
return None, None


def _generate_optimized_ok_method_body(fields, ir, subexpressions):
def _generate_optimized_ok_method_body(
fields, ir, subexpressions, allow_tail_form=False
):
"""Generates optimized C++ code for the Ok() method body.

This function optimizes validation logic for structures with conditional
Expand Down Expand Up @@ -1698,6 +1700,19 @@ def _generate_optimized_ok_method_body(fields, ir, subexpressions):
if total_entries < 2:
group["type"] = "demoted_to_if"

# Find the last surviving switch group; only that one is eligible for
# tail-form case rewrites, since tail-form turns the switch into the
# function's exit and any later block would be unreachable.
last_switch_key = None
if allow_tail_form:
for key in reversed(ordered_keys):
if groups[key]["type"] == "switch":
last_switch_key = key
break
# Any non-switch block after a switch means that switch is
# not the last statement in Ok(); abandon tail-form.
break

blocks = []
for key in ordered_keys:
group = groups[key]
Expand All @@ -1711,6 +1726,7 @@ def _generate_optimized_ok_method_body(fields, ir, subexpressions):
group["known_check_required"] = not _is_discriminant_provably_known(
group["discrim_expr"], fields
)
group["tail_form"] = key == last_switch_key
blocks.append(_emit_switch_block(group))
elif group["type"] == "demoted_to_if":
for field in group["encounter_order"]:
Expand Down Expand Up @@ -1765,7 +1781,8 @@ def _render_case_body(entries):
def _emit_switch_block(group):
"""Emits a complete switch block from a collected switch group.

Performs case-label sorting and identical-body coalescing:
Performs case-label sorting, identical-body coalescing, and (when the
switch is the last statement in Ok()) tail-form arm rewriting:

* Case labels within an arm are sorted by underlying numeric value, so
the C++ compiler is presented with monotonic case sequences and is
Expand All @@ -1775,31 +1792,63 @@ def _emit_switch_block(group):
multiple fields share an existence condition) are merged into a
single multi-label arm, so the compiler emits one body for all of
them rather than duplicating per case.
* When the switch is the function's tail and an arm validates a single
field with no residual, the arm body becomes `return field().Ok();`
instead of `if (!field().Ok()) return false; break;`. Trims one
conditional branch per qualifying arm — adds up across large
tagged-union schemas on Thumb-2 / MicroBlaze.
"""
body_to_labels = {}
body_first_seen = {}
tail_form_enabled = group.get("tail_form", False)

# Group cases by rendered body (for identical-body coalescing). Track
# the original entries so we can decide tail-form eligibility per arm.
body_to_arm = {}
for case_str, case_entry in group["cases_by_label"].items():
body = _render_case_body(case_entry["entries"])
body_to_labels.setdefault(body, []).append((case_entry["sort_key"], case_str))
if body not in body_first_seen:
body_first_seen[body] = case_entry["sort_key"]
arm = body_to_arm.setdefault(
body,
{
"labels": [],
"entries": case_entry["entries"],
"first_sort_key": case_entry["sort_key"],
},
)
arm["labels"].append((case_entry["sort_key"], case_str))
if case_entry["sort_key"] < arm["first_sort_key"]:
arm["first_sort_key"] = case_entry["sort_key"]

arms = []
for body, labels in body_to_labels.items():
labels.sort() # by (sort_key, case_str)
arms.append((body_first_seen[body], labels, body))
arms.sort(key=lambda arm: arm[0])
sorted_arms = sorted(body_to_arm.values(), key=lambda a: a["first_sort_key"])
for arm in sorted_arms:
arm["labels"].sort()

rendered_arms = []
for _, labels, body in arms:
case_labels = "".join(" case {}:\n".format(cs) for _, cs in labels)
rendered_arms.append(
code_template.format_template(
_TEMPLATES.ok_method_switch_arm,
case_labels=case_labels,
case_body=body,
)
for arm in sorted_arms:
case_labels = "".join(
" case {}:\n".format(cs) for _, cs in arm["labels"]
)
# Tail-form is eligible when the arm has a single bare-equality
# entry (one field, no residual). In that case the case body
# collapses to a single Ok() call we can return directly.
if (
tail_form_enabled
and len(arm["entries"]) == 1
and not arm["entries"][0][1] # no residual
):
field = arm["entries"][0][0]
rendered_arms.append(
"{0} return {1}().Ok();\n".format(
case_labels, _cpp_field_name(field.name.name.text)
)
)
else:
body = _render_case_body(arm["entries"])
rendered_arms.append(
code_template.format_template(
_TEMPLATES.ok_method_switch_arm,
case_labels=case_labels,
case_body=body,
)
)

if group.get("known_check_required", True):
known_check = (
Expand Down Expand Up @@ -1979,6 +2028,11 @@ def _generate_structure_definition(type_ir, ir, config: Config):
],
ir,
ok_subexpressions,
# When the final group emitted is a switch and there is nothing
# after it (no [requires] clause), each case body's
# `if (!X().Ok()) return false; break;` can be rewritten as
# `return X().Ok();` — one fewer branch per case.
allow_tail_form=(requires_check == ""),
)

class_forward_declarations = code_template.format_template(
Expand Down
12 changes: 2 additions & 10 deletions testdata/golden_cpp/condition.emb.h
Original file line number Diff line number Diff line change
Expand Up @@ -16544,17 +16544,9 @@ class GenericConditionalInlineView final {
const auto emboss_reserved_switch_discrim = emboss_reserved_local_ok_subexpr_2;
switch (emboss_reserved_switch_discrim.ValueOrDefault()) {
case static_cast</**/::std::int32_t>(0LL):
if (!type_0().Ok()) return false;
break;



return type_0().Ok();
case static_cast</**/::std::int32_t>(1LL):
if (!type_1().Ok()) return false;
break;



return type_1().Ok();

}
}
Expand Down
Loading
Loading