From 5e1894c4bf0e9819601761e56e99aa4a360c0146 Mon Sep 17 00:00:00 2001 From: Aaron Webster Date: Tue, 19 May 2026 15:44:40 -0700 Subject: [PATCH] Tail-form return in terminal switches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the optimized Ok() switch is the function's tail — last group emitted, no [requires] clause after — arms that validate exactly one field with no residual collapse from case K: if (!field().Ok()) return false; break; to case K: return field().Ok(); Saves one conditional branch per qualifying arm. On Thumb-2 / MicroBlaze (where -Os is conservative about merging cmp+beq sequences) and on older compilers that don't see through the if/return-false/break shape, this trims a few bytes per case across large tagged-union schemas. Per-arm decision: any arm that doesn't qualify (multi-field, residual, or the switch isn't terminal) keeps the break-form body. Mixing forms within one switch is safe — the default fallthrough lands on the function's \`return true;\` either way. Golden churn: many_conditionals.emb.h shrinks substantially (the 99 single-field cases all collapse to tail-form); condition.emb.h and parameters.emb.h see smaller similar collapses. --- compiler/back_end/cpp/header_generator.py | 94 ++- testdata/golden_cpp/condition.emb.h | 12 +- testdata/golden_cpp/many_conditionals.emb.h | 612 ++++---------------- testdata/golden_cpp/parameters.emb.h | 18 +- 4 files changed, 181 insertions(+), 555 deletions(-) diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py index 0342f92..6d40516 100644 --- a/compiler/back_end/cpp/header_generator.py +++ b/compiler/back_end/cpp/header_generator.py @@ -1554,7 +1554,9 @@ def _extract_switch_arms(expression, ir): return None, None -def _generate_optimized_ok_method_body(fields, ir, subexpressions): +def _generate_optimized_ok_method_body( + fields, ir, subexpressions, allow_tail_form=False +): """Generates optimized C++ code for the Ok() method body. This function optimizes validation logic for structures with conditional @@ -1698,6 +1700,19 @@ def _generate_optimized_ok_method_body(fields, ir, subexpressions): if total_entries < 2: group["type"] = "demoted_to_if" + # Find the last surviving switch group; only that one is eligible for + # tail-form case rewrites, since tail-form turns the switch into the + # function's exit and any later block would be unreachable. + last_switch_key = None + if allow_tail_form: + for key in reversed(ordered_keys): + if groups[key]["type"] == "switch": + last_switch_key = key + break + # Any non-switch block after a switch means that switch is + # not the last statement in Ok(); abandon tail-form. + break + blocks = [] for key in ordered_keys: group = groups[key] @@ -1711,6 +1726,7 @@ def _generate_optimized_ok_method_body(fields, ir, subexpressions): group["known_check_required"] = not _is_discriminant_provably_known( group["discrim_expr"], fields ) + group["tail_form"] = key == last_switch_key blocks.append(_emit_switch_block(group)) elif group["type"] == "demoted_to_if": for field in group["encounter_order"]: @@ -1765,7 +1781,8 @@ def _render_case_body(entries): def _emit_switch_block(group): """Emits a complete switch block from a collected switch group. - Performs case-label sorting and identical-body coalescing: + Performs case-label sorting, identical-body coalescing, and (when the + switch is the last statement in Ok()) tail-form arm rewriting: * Case labels within an arm are sorted by underlying numeric value, so the C++ compiler is presented with monotonic case sequences and is @@ -1775,31 +1792,63 @@ def _emit_switch_block(group): multiple fields share an existence condition) are merged into a single multi-label arm, so the compiler emits one body for all of them rather than duplicating per case. + * When the switch is the function's tail and an arm validates a single + field with no residual, the arm body becomes `return field().Ok();` + instead of `if (!field().Ok()) return false; break;`. Trims one + conditional branch per qualifying arm — adds up across large + tagged-union schemas on Thumb-2 / MicroBlaze. """ - body_to_labels = {} - body_first_seen = {} + tail_form_enabled = group.get("tail_form", False) + + # Group cases by rendered body (for identical-body coalescing). Track + # the original entries so we can decide tail-form eligibility per arm. + body_to_arm = {} for case_str, case_entry in group["cases_by_label"].items(): body = _render_case_body(case_entry["entries"]) - body_to_labels.setdefault(body, []).append((case_entry["sort_key"], case_str)) - if body not in body_first_seen: - body_first_seen[body] = case_entry["sort_key"] + arm = body_to_arm.setdefault( + body, + { + "labels": [], + "entries": case_entry["entries"], + "first_sort_key": case_entry["sort_key"], + }, + ) + arm["labels"].append((case_entry["sort_key"], case_str)) + if case_entry["sort_key"] < arm["first_sort_key"]: + arm["first_sort_key"] = case_entry["sort_key"] - arms = [] - for body, labels in body_to_labels.items(): - labels.sort() # by (sort_key, case_str) - arms.append((body_first_seen[body], labels, body)) - arms.sort(key=lambda arm: arm[0]) + sorted_arms = sorted(body_to_arm.values(), key=lambda a: a["first_sort_key"]) + for arm in sorted_arms: + arm["labels"].sort() rendered_arms = [] - for _, labels, body in arms: - case_labels = "".join(" case {}:\n".format(cs) for _, cs in labels) - rendered_arms.append( - code_template.format_template( - _TEMPLATES.ok_method_switch_arm, - case_labels=case_labels, - case_body=body, - ) + for arm in sorted_arms: + case_labels = "".join( + " case {}:\n".format(cs) for _, cs in arm["labels"] ) + # Tail-form is eligible when the arm has a single bare-equality + # entry (one field, no residual). In that case the case body + # collapses to a single Ok() call we can return directly. + if ( + tail_form_enabled + and len(arm["entries"]) == 1 + and not arm["entries"][0][1] # no residual + ): + field = arm["entries"][0][0] + rendered_arms.append( + "{0} return {1}().Ok();\n".format( + case_labels, _cpp_field_name(field.name.name.text) + ) + ) + else: + body = _render_case_body(arm["entries"]) + rendered_arms.append( + code_template.format_template( + _TEMPLATES.ok_method_switch_arm, + case_labels=case_labels, + case_body=body, + ) + ) if group.get("known_check_required", True): known_check = ( @@ -1979,6 +2028,11 @@ def _generate_structure_definition(type_ir, ir, config: Config): ], ir, ok_subexpressions, + # When the final group emitted is a switch and there is nothing + # after it (no [requires] clause), each case body's + # `if (!X().Ok()) return false; break;` can be rewritten as + # `return X().Ok();` — one fewer branch per case. + allow_tail_form=(requires_check == ""), ) class_forward_declarations = code_template.format_template( diff --git a/testdata/golden_cpp/condition.emb.h b/testdata/golden_cpp/condition.emb.h index a409c97..e8c0c35 100644 --- a/testdata/golden_cpp/condition.emb.h +++ b/testdata/golden_cpp/condition.emb.h @@ -16544,17 +16544,9 @@ class GenericConditionalInlineView final { const auto emboss_reserved_switch_discrim = emboss_reserved_local_ok_subexpr_2; switch (emboss_reserved_switch_discrim.ValueOrDefault()) { case static_cast(0LL): - if (!type_0().Ok()) return false; - break; - - - + return type_0().Ok(); case static_cast(1LL): - if (!type_1().Ok()) return false; - break; - - - + return type_1().Ok(); } } diff --git a/testdata/golden_cpp/many_conditionals.emb.h b/testdata/golden_cpp/many_conditionals.emb.h index 657348c..007fa54 100644 --- a/testdata/golden_cpp/many_conditionals.emb.h +++ b/testdata/golden_cpp/many_conditionals.emb.h @@ -225,599 +225,203 @@ class GenericLargeConditionalsView final { case static_cast(1LL): - if (!f1().Ok()) return false; - break; - - - + return f1().Ok(); case static_cast(2LL): - if (!f2().Ok()) return false; - break; - - - + return f2().Ok(); case static_cast(3LL): - if (!f3().Ok()) return false; - break; - - - + return f3().Ok(); case static_cast(4LL): - if (!f4().Ok()) return false; - break; - - - + return f4().Ok(); case static_cast(5LL): - if (!f5().Ok()) return false; - break; - - - + return f5().Ok(); case static_cast(6LL): - if (!f6().Ok()) return false; - break; - - - + return f6().Ok(); case static_cast(7LL): - if (!f7().Ok()) return false; - break; - - - + return f7().Ok(); case static_cast(8LL): - if (!f8().Ok()) return false; - break; - - - + return f8().Ok(); case static_cast(9LL): - if (!f9().Ok()) return false; - break; - - - + return f9().Ok(); case static_cast(10LL): - if (!f10().Ok()) return false; - break; - - - + return f10().Ok(); case static_cast(11LL): - if (!f11().Ok()) return false; - break; - - - + return f11().Ok(); case static_cast(12LL): - if (!f12().Ok()) return false; - break; - - - + return f12().Ok(); case static_cast(13LL): - if (!f13().Ok()) return false; - break; - - - + return f13().Ok(); case static_cast(14LL): - if (!f14().Ok()) return false; - break; - - - + return f14().Ok(); case static_cast(15LL): - if (!f15().Ok()) return false; - break; - - - + return f15().Ok(); case static_cast(16LL): - if (!f16().Ok()) return false; - break; - - - + return f16().Ok(); case static_cast(17LL): - if (!f17().Ok()) return false; - break; - - - + return f17().Ok(); case static_cast(18LL): - if (!f18().Ok()) return false; - break; - - - + return f18().Ok(); case static_cast(19LL): - if (!f19().Ok()) return false; - break; - - - + return f19().Ok(); case static_cast(20LL): - if (!f20().Ok()) return false; - break; - - - + return f20().Ok(); case static_cast(21LL): - if (!f21().Ok()) return false; - break; - - - + return f21().Ok(); case static_cast(22LL): - if (!f22().Ok()) return false; - break; - - - + return f22().Ok(); case static_cast(23LL): - if (!f23().Ok()) return false; - break; - - - + return f23().Ok(); case static_cast(24LL): - if (!f24().Ok()) return false; - break; - - - + return f24().Ok(); case static_cast(25LL): - if (!f25().Ok()) return false; - break; - - - + return f25().Ok(); case static_cast(26LL): - if (!f26().Ok()) return false; - break; - - - + return f26().Ok(); case static_cast(27LL): - if (!f27().Ok()) return false; - break; - - - + return f27().Ok(); case static_cast(28LL): - if (!f28().Ok()) return false; - break; - - - + return f28().Ok(); case static_cast(29LL): - if (!f29().Ok()) return false; - break; - - - + return f29().Ok(); case static_cast(30LL): - if (!f30().Ok()) return false; - break; - - - + return f30().Ok(); case static_cast(31LL): - if (!f31().Ok()) return false; - break; - - - + return f31().Ok(); case static_cast(32LL): - if (!f32().Ok()) return false; - break; - - - + return f32().Ok(); case static_cast(33LL): - if (!f33().Ok()) return false; - break; - - - + return f33().Ok(); case static_cast(34LL): - if (!f34().Ok()) return false; - break; - - - + return f34().Ok(); case static_cast(35LL): - if (!f35().Ok()) return false; - break; - - - + return f35().Ok(); case static_cast(36LL): - if (!f36().Ok()) return false; - break; - - - + return f36().Ok(); case static_cast(37LL): - if (!f37().Ok()) return false; - break; - - - + return f37().Ok(); case static_cast(38LL): - if (!f38().Ok()) return false; - break; - - - + return f38().Ok(); case static_cast(39LL): - if (!f39().Ok()) return false; - break; - - - + return f39().Ok(); case static_cast(40LL): - if (!f40().Ok()) return false; - break; - - - + return f40().Ok(); case static_cast(41LL): - if (!f41().Ok()) return false; - break; - - - + return f41().Ok(); case static_cast(42LL): - if (!f42().Ok()) return false; - break; - - - + return f42().Ok(); case static_cast(43LL): - if (!f43().Ok()) return false; - break; - - - + return f43().Ok(); case static_cast(44LL): - if (!f44().Ok()) return false; - break; - - - + return f44().Ok(); case static_cast(45LL): - if (!f45().Ok()) return false; - break; - - - + return f45().Ok(); case static_cast(46LL): - if (!f46().Ok()) return false; - break; - - - + return f46().Ok(); case static_cast(47LL): - if (!f47().Ok()) return false; - break; - - - + return f47().Ok(); case static_cast(48LL): - if (!f48().Ok()) return false; - break; - - - + return f48().Ok(); case static_cast(49LL): - if (!f49().Ok()) return false; - break; - - - + return f49().Ok(); case static_cast(50LL): - if (!f50().Ok()) return false; - break; - - - + return f50().Ok(); case static_cast(51LL): - if (!f51().Ok()) return false; - break; - - - + return f51().Ok(); case static_cast(52LL): - if (!f52().Ok()) return false; - break; - - - + return f52().Ok(); case static_cast(53LL): - if (!f53().Ok()) return false; - break; - - - + return f53().Ok(); case static_cast(54LL): - if (!f54().Ok()) return false; - break; - - - + return f54().Ok(); case static_cast(55LL): - if (!f55().Ok()) return false; - break; - - - + return f55().Ok(); case static_cast(56LL): - if (!f56().Ok()) return false; - break; - - - + return f56().Ok(); case static_cast(57LL): - if (!f57().Ok()) return false; - break; - - - + return f57().Ok(); case static_cast(58LL): - if (!f58().Ok()) return false; - break; - - - + return f58().Ok(); case static_cast(59LL): - if (!f59().Ok()) return false; - break; - - - + return f59().Ok(); case static_cast(60LL): - if (!f60().Ok()) return false; - break; - - - + return f60().Ok(); case static_cast(61LL): - if (!f61().Ok()) return false; - break; - - - + return f61().Ok(); case static_cast(62LL): - if (!f62().Ok()) return false; - break; - - - + return f62().Ok(); case static_cast(63LL): - if (!f63().Ok()) return false; - break; - - - + return f63().Ok(); case static_cast(64LL): - if (!f64().Ok()) return false; - break; - - - + return f64().Ok(); case static_cast(65LL): - if (!f65().Ok()) return false; - break; - - - + return f65().Ok(); case static_cast(66LL): - if (!f66().Ok()) return false; - break; - - - + return f66().Ok(); case static_cast(67LL): - if (!f67().Ok()) return false; - break; - - - + return f67().Ok(); case static_cast(68LL): - if (!f68().Ok()) return false; - break; - - - + return f68().Ok(); case static_cast(69LL): - if (!f69().Ok()) return false; - break; - - - + return f69().Ok(); case static_cast(70LL): - if (!f70().Ok()) return false; - break; - - - + return f70().Ok(); case static_cast(71LL): - if (!f71().Ok()) return false; - break; - - - + return f71().Ok(); case static_cast(72LL): - if (!f72().Ok()) return false; - break; - - - + return f72().Ok(); case static_cast(73LL): - if (!f73().Ok()) return false; - break; - - - + return f73().Ok(); case static_cast(74LL): - if (!f74().Ok()) return false; - break; - - - + return f74().Ok(); case static_cast(75LL): - if (!f75().Ok()) return false; - break; - - - + return f75().Ok(); case static_cast(76LL): - if (!f76().Ok()) return false; - break; - - - + return f76().Ok(); case static_cast(77LL): - if (!f77().Ok()) return false; - break; - - - + return f77().Ok(); case static_cast(78LL): - if (!f78().Ok()) return false; - break; - - - + return f78().Ok(); case static_cast(79LL): - if (!f79().Ok()) return false; - break; - - - + return f79().Ok(); case static_cast(80LL): - if (!f80().Ok()) return false; - break; - - - + return f80().Ok(); case static_cast(81LL): - if (!f81().Ok()) return false; - break; - - - + return f81().Ok(); case static_cast(82LL): - if (!f82().Ok()) return false; - break; - - - + return f82().Ok(); case static_cast(83LL): - if (!f83().Ok()) return false; - break; - - - + return f83().Ok(); case static_cast(84LL): - if (!f84().Ok()) return false; - break; - - - + return f84().Ok(); case static_cast(85LL): - if (!f85().Ok()) return false; - break; - - - + return f85().Ok(); case static_cast(86LL): - if (!f86().Ok()) return false; - break; - - - + return f86().Ok(); case static_cast(87LL): - if (!f87().Ok()) return false; - break; - - - + return f87().Ok(); case static_cast(88LL): - if (!f88().Ok()) return false; - break; - - - + return f88().Ok(); case static_cast(89LL): - if (!f89().Ok()) return false; - break; - - - + return f89().Ok(); case static_cast(90LL): - if (!f90().Ok()) return false; - break; - - - + return f90().Ok(); case static_cast(91LL): - if (!f91().Ok()) return false; - break; - - - + return f91().Ok(); case static_cast(92LL): - if (!f92().Ok()) return false; - break; - - - + return f92().Ok(); case static_cast(93LL): - if (!f93().Ok()) return false; - break; - - - + return f93().Ok(); case static_cast(94LL): - if (!f94().Ok()) return false; - break; - - - + return f94().Ok(); case static_cast(95LL): - if (!f95().Ok()) return false; - break; - - - + return f95().Ok(); case static_cast(96LL): - if (!f96().Ok()) return false; - break; - - - + return f96().Ok(); case static_cast(97LL): - if (!f97().Ok()) return false; - break; - - - + return f97().Ok(); case static_cast(98LL): - if (!f98().Ok()) return false; - break; - - - + return f98().Ok(); case static_cast(99LL): - if (!f99().Ok()) return false; - break; - - - + return f99().Ok(); } } @@ -9255,26 +8859,14 @@ class GenericDisjunctionConditionalsView final { case static_cast(0LL): case static_cast(1LL): case static_cast(2LL): - if (!shared_low().Ok()) return false; - break; - - - + return shared_low().Ok(); case static_cast(10LL): case static_cast(11LL): - if (!shared_high().Ok()) return false; - break; - - - + return shared_high().Ok(); case static_cast(100LL): case static_cast(200LL): case static_cast(300LL): - if (!shared_far().Ok()) return false; - break; - - - + return shared_far().Ok(); } } diff --git a/testdata/golden_cpp/parameters.emb.h b/testdata/golden_cpp/parameters.emb.h index 4be3ea8..a11178c 100644 --- a/testdata/golden_cpp/parameters.emb.h +++ b/testdata/golden_cpp/parameters.emb.h @@ -3375,23 +3375,11 @@ if (!parameters_initialized_) return false; if (!emboss_reserved_switch_discrim.Known()) return false; switch (emboss_reserved_switch_discrim.ValueOrDefault()) { case static_cast(1): - if (!x().Ok()) return false; - break; - - - + return x().Ok(); case static_cast(2): - if (!y().Ok()) return false; - break; - - - + return y().Ok(); case static_cast(3): - if (!z().Ok()) return false; - break; - - - + return z().Ok(); } }