From 74832d3cb1289f28d530d380abbaf68c4d9ee915 Mon Sep 17 00:00:00 2001 From: Angelo Matni Date: Thu, 21 May 2026 10:00:51 -0700 Subject: [PATCH] IR minimizer bitwidth minimization in reverse topo order on bit-typed ops This reverse-topo sorted pass on the IR takes bitslices of the operands of a node to reduce the node's bitwidth. This approach does not add extends by relying on the node's users having introduced bitslices on their operands, meaning we can shrink the node by dropping or modifying bitslices. PiperOrigin-RevId: 919110993 --- xls/dev_tools/BUILD | 1 + xls/dev_tools/ir_minimizer_main.cc | 251 +++++++++++++++++++++++- xls/dev_tools/ir_minimizer_main_test.py | 159 +++++++++++++++ 3 files changed, 407 insertions(+), 4 deletions(-) diff --git a/xls/dev_tools/BUILD b/xls/dev_tools/BUILD index 84ebe40fd5..616f23f4f6 100644 --- a/xls/dev_tools/BUILD +++ b/xls/dev_tools/BUILD @@ -554,6 +554,7 @@ cc_binary( "//xls/estimators/delay_model/models", "//xls/interpreter:ir_interpreter", "//xls/ir", + "//xls/ir:bits", "//xls/ir:channel", "//xls/ir:channel_ops", "//xls/ir:events", diff --git a/xls/dev_tools/ir_minimizer_main.cc b/xls/dev_tools/ir_minimizer_main.cc index e3d971fe36..c5be226519 100644 --- a/xls/dev_tools/ir_minimizer_main.cc +++ b/xls/dev_tools/ir_minimizer_main.cc @@ -56,6 +56,7 @@ #include "xls/dev_tools/extract_segment.h" #include "xls/dev_tools/extract_state_element.h" #include "xls/interpreter/function_interpreter.h" +#include "xls/ir/bits.h" #include "xls/ir/call_graph.h" #include "xls/ir/channel.h" #include "xls/ir/channel_ops.h" @@ -70,6 +71,7 @@ #include "xls/ir/package.h" #include "xls/ir/source_location.h" #include "xls/ir/state_element.h" +#include "xls/ir/topo_sort.h" #include "xls/ir/type.h" #include "xls/ir/value.h" #include "xls/ir/value_utils.h" @@ -254,6 +256,8 @@ ABSL_FLAG(std::string, output_path, "-", "Path where the minimized IR will be written. Will print to stdout " "if not set or set to '-'. If stopped early, will contain the latest " "known-failing IR."); +ABSL_FLAG(bool, can_minimize_bitwidth, true, + "Whether to minimize bitwidth of nodes."); namespace xls { namespace { @@ -832,9 +836,238 @@ bool IsInterfaceChannelRef(ChannelRef channel_ref, Package* package) { std::get(channel_ref)) != interface.end(); } +struct Slice { + int64_t start; + int64_t new_width; +}; + +// Returns the largest bit-slice that would encompass all bit-slices of users of +// the node. Returns nullopt if there are no users, if any user isn't a +// bitslice, or if the encompassing slice width is no smaller than the node. +absl::StatusOr> GetSmallerEncompassingBitSlice( + Node* node, int64_t orig_width) { + if (node->users().empty()) { + return std::nullopt; + } + + int64_t min_start = orig_width; + int64_t max_end = 0; + for (Node* user : node->users()) { + if (user->op() != Op::kBitSlice) { + return std::nullopt; + } + BitSlice* slice = user->As(); + min_start = std::min(min_start, slice->start()); + max_end = std::max(max_end, slice->start() + slice->width()); + } + + int64_t new_width = max_end - min_start; + XLS_RET_CHECK(new_width <= orig_width); + if (new_width == orig_width) { + return std::nullopt; + } + return Slice{.start = min_start, .new_width = new_width}; +} + +// Replaces all users of `old_node` with `new_node`, where new node is assumed +// to start at the bit `new_start` of `old_node` and have width `new_width`. +absl::Status ReplaceBitSliceUsers(Node* old_node, Node* new_node, + int64_t new_start, int64_t new_width) { + FunctionBase* f = old_node->function_base(); + std::vector orig_users(old_node->users().begin(), + old_node->users().end()); + for (Node* user : orig_users) { + XLS_RET_CHECK(user->Is()); + BitSlice* user_slice = user->As(); + Node* new_node_or_slice = new_node; + if (user_slice->width() != new_width || user_slice->start() != new_start) { + int64_t new_user_start = user_slice->start() - new_start; + XLS_RET_CHECK(user_slice->width() + new_user_start <= new_width); + XLS_ASSIGN_OR_RETURN( + new_node_or_slice, + f->MakeNode(user_slice->loc(), new_node, new_user_start, + user_slice->width())); + } + + XLS_RETURN_IF_ERROR(user_slice->ReplaceUsesWith(new_node_or_slice)); + XLS_RETURN_IF_ERROR(f->RemoveNode(user_slice)); + } + return absl::OkStatus(); +} + +// Replaces `old_node` with `new_node` preserving ID and param index if the +// node is a function parameter. +absl::Status ReplaceNodeInPlace(Node* old_node, Node* new_node, + int64_t orig_width, int64_t new_width, + std::optional param_index, + std::string_view transform_name) { + FunctionBase* f = old_node->function_base(); + int64_t id = old_node->id(); + XLS_RETURN_IF_ERROR(old_node->ReplaceImplicitUsesWith(new_node).status()); + XLS_RETURN_IF_ERROR(f->RemoveNode(old_node)); + new_node->SetId(id); + if (param_index) { + XLS_RETURN_IF_ERROR( + f->MoveParamToIndex(new_node->As(), *param_index)); + } + LOG(INFO) << absl::StrFormat("Trying %s: %s (%d -> %d)", transform_name, + new_node->GetName(), orig_width, new_width); + return absl::OkStatus(); +} + +struct ClonedNodeInfo { + Node* cloned_node; + std::optional param_index; +}; + +// Slices node operands and clones the node to use the reduced width operands. +absl::StatusOr SliceOperandsAndCloneNode(Node* node, + int64_t start, + int64_t new_width) { + // Assert that all operands are bits and have sufficient width. + for (Node* operand : node->operands()) { + XLS_RET_CHECK(operand->GetType()->IsBits()); + XLS_RET_CHECK(operand->BitCountOrDie() >= start + new_width); + } + + FunctionBase* f = node->function_base(); + std::vector sliced_operands; + sliced_operands.reserve(node->operand_count()); + for (Node* operand : node->operands()) { + XLS_ASSIGN_OR_RETURN( + Node * slice, + f->MakeNode(node->loc(), operand, start, new_width)); + sliced_operands.push_back(slice); + } + + Node* cloned_node = nullptr; + std::optional param_index; + if (node->Is()) { + XLS_ASSIGN_OR_RETURN(param_index, f->GetParamIndex(node->As())); + XLS_ASSIGN_OR_RETURN(cloned_node, + f->MakeNodeWithName( + node->loc(), f->package()->GetBitsType(new_width), + node->GetName())); + } else if (node->Is()) { + Bits sliced_bits = + node->As()->value().bits().Slice(start, new_width); + XLS_ASSIGN_OR_RETURN(cloned_node, + f->MakeNode(node->loc(), Value(sliced_bits))); + } else if (node->Is()) { + // NOTE: this is necessary because Clone() assumes the width of arith ops + // will not change. + XLS_ASSIGN_OR_RETURN( + cloned_node, f->MakeNodeWithName( + node->loc(), sliced_operands[0], sliced_operands[1], + new_width, node->op(), node->GetName())); + } else if (node->Is()) { + XLS_ASSIGN_OR_RETURN( + cloned_node, f->MakeNodeWithName( + node->loc(), sliced_operands[0], sliced_operands[1], + new_width, node->op(), node->GetName())); + } else { + XLS_ASSIGN_OR_RETURN(cloned_node, node->Clone(sliced_operands)); + } + + return ClonedNodeInfo{cloned_node, param_index}; +} + +absl::StatusOr TrimBitsOfNode(Node* node, + absl::BitGenRef rng) { + if (!node->GetType()->IsBits() && !node->Is()) { + return SimplificationResult::kDidNotChange; + } + + Op op = node->op(); + if (op == Op::kParam && !absl::GetFlag(FLAGS_can_remove_params)) { + return SimplificationResult::kDidNotChange; + } + if (!node->Is() && !node->Is() && !node->Is() && + !node->Is() && !node->Is() && !node->Is() && + !node->Is() && !node->Is()) { + return SimplificationResult::kDidNotChange; + } + + int64_t orig_width; + if (node->GetType()->IsBits()) { + orig_width = node->BitCountOrDie(); + } else if (node->Is()) { + orig_width = 0; + for (Node* operand : node->operands()) { + orig_width = std::max(orig_width, operand->BitCountOrDie()); + } + } else { + return SimplificationResult::kDidNotChange; + } + + if (orig_width <= 1) { + return SimplificationResult::kDidNotChange; + } + + // If there are no users, we trim to some randomly chosen smaller width. + if (node->users().empty()) { + int64_t new_width = absl::Uniform(rng, 1, orig_width); + int64_t start = absl::Uniform(rng, 0, orig_width - new_width + 1); + XLS_ASSIGN_OR_RETURN(ClonedNodeInfo clone_info, + SliceOperandsAndCloneNode(node, start, new_width)); + XLS_RETURN_IF_ERROR(ReplaceNodeInPlace( + node, clone_info.cloned_node, orig_width, new_width, + clone_info.param_index, "trim node bitwidth (0 users)")); + return SimplificationResult::kDidChange; + } + + // If all users are bitslices where they do not span the entire width of the + // literal, then we can trim the literal to the encompassing slice. + XLS_ASSIGN_OR_RETURN(std::optional slice, + GetSmallerEncompassingBitSlice(node, orig_width)); + if (!slice) { + return SimplificationResult::kDidNotChange; + } + + XLS_ASSIGN_OR_RETURN( + ClonedNodeInfo clone_info, + SliceOperandsAndCloneNode(node, slice->start, slice->new_width)); + XLS_RETURN_IF_ERROR(ReplaceBitSliceUsers(node, clone_info.cloned_node, + slice->start, slice->new_width)); + XLS_RETURN_IF_ERROR(ReplaceNodeInPlace( + node, clone_info.cloned_node, orig_width, slice->new_width, + clone_info.param_index, "trim node bitwidth (bitslice users)")); + return SimplificationResult::kDidChange; +} + +bool AreAllUsersBitslices(Node* node) { + for (Node* user : node->users()) { + if (!user->Is()) { + return false; + } + } + return true; +} + +absl::StatusOr TrimBitsOfNodes(FunctionBase* f, + absl::BitGenRef rng) { + XLS_ASSIGN_OR_RETURN(std::vector reverse_topo_nodes, + ReverseTopoSort(f, rng)); + bool any_changes = false; + for (Node* node : reverse_topo_nodes) { + if (AreAllUsersBitslices(node)) { + XLS_ASSIGN_OR_RETURN(SimplificationResult res, TrimBitsOfNode(node, rng)); + if (res == SimplificationResult::kDidChange) { + any_changes = true; + } + } + } + + return any_changes ? SimplificationResult::kDidChange + : SimplificationResult::kDidNotChange; +} + +constexpr double kSmallishChance = 0.3; + absl::StatusOr SimplifyNode( Node* n, absl::BitGenRef rng, std::string* which_transform) { FunctionBase* f = n->function_base(); + if (n->Is() && absl::GetFlag(FLAGS_can_remove_params) && absl::GetFlag(FLAGS_can_extract_segments) && n->function_base()->IsFunction() && @@ -868,7 +1101,7 @@ absl::StatusOr SimplifyNode( } if (((n->Is() && absl::GetFlag(FLAGS_can_remove_receives)) || (n->Is() && absl::GetFlag(FLAGS_can_remove_sends))) && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { XLS_ASSIGN_OR_RETURN(ChannelRef c, n->As()->GetChannelRef()); absl::flat_hash_set preserved_channels; for (const std::string& chan : absl::GetFlag(FLAGS_preserve_channels)) { @@ -931,7 +1164,7 @@ absl::StatusOr SimplifyNode( } if (n->Is() && absl::GetFlag(FLAGS_can_remove_asserts) && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { *which_transform = "remove assert: " + n->GetName(); XLS_RETURN_IF_ERROR(n->ReplaceUsesWith(n->As()->token())); XLS_RETURN_IF_ERROR(f->RemoveNode(n)); @@ -943,13 +1176,13 @@ absl::StatusOr SimplifyNode( } if (OpIsSideEffecting(n->op()) && !n->Is() && n->IsDead() && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { *which_transform = "remove userless side-effecting node: " + n->GetName(); XLS_RETURN_IF_ERROR(f->RemoveNode(n)); return SimplificationResult::kDidChange; } - if (!n->operands().empty() && absl::Bernoulli(rng, 0.3)) { + if (!n->operands().empty() && absl::Bernoulli(rng, kSmallishChance)) { // Try to replace a node with one of its (potentially truncated/extended) // operands. int64_t operand_no = absl::Uniform(rng, 0, n->operand_count()); @@ -1288,6 +1521,16 @@ absl::StatusOr Simplify(FunctionBase* f, } } + // We do not support modifying the invoke sight for all callers of `f`, so + // we require the function is unused (i.e. the top-level function). + if (f->IsTop() && absl::GetFlag(FLAGS_can_minimize_bitwidth) && + absl::Bernoulli(rng, kSmallishChance)) { + XLS_ASSIGN_OR_RETURN(SimplificationResult result, TrimBitsOfNodes(f, rng)); + if (result == SimplificationResult::kDidChange) { + return in_place(result); + } + } + // Pick a random node and try to do something with it. int64_t i = absl::Uniform(rng, 0, f->node_count()); Node* n = *std::next(f->nodes().begin(), i); diff --git a/xls/dev_tools/ir_minimizer_main_test.py b/xls/dev_tools/ir_minimizer_main_test.py index 3abdfffabc..5b1ed4564a 100644 --- a/xls/dev_tools/ir_minimizer_main_test.py +++ b/xls/dev_tools/ir_minimizer_main_test.py @@ -17,6 +17,7 @@ import re import stat import subprocess +import textwrap from typing import Optional from absl.testing import absltest @@ -858,6 +859,7 @@ def test_can_unwrap_map(self): [ IR_MINIMIZER_MAIN_PATH, '--can_remove_params', + '--can_minimize_bitwidth=false', '--can_inline_everything=false', '--test_executable=' + test_sh_file.full_path, ir_file.full_path, @@ -879,6 +881,163 @@ def test_can_unwrap_map(self): self.assertIn('ret literal', minimized_ir) self.assertNotIn('ret invoke', minimized_ir) + def test_minimize_bitwidth(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_bits(x: bits[8] id=2, y: bits[8] id=3, z: bits[8] id=4) -> bits[8] { + add_5: bits[8] = add(x, y, id=5) + ret mul_6: bits[8] = umul(add_5, z, id=6) + }""") + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script checks that add and umul are preserved, and rejects any + # bitwidth narrower than 4 (bits[1], bits[2], bits[3]) or the params being + # dropped outright, forcing the minimizer to reach exactly bits[4]. + self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'add(' $1", + "/usr/bin/env grep 'umul(' $1", + r"/usr/bin/env grep -E 'x\w*:.+y\w*:.+z\w*:' $1", + r"/usr/bin/env grep -q 'bits\[[123]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + '--can_remove_params=true', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 2) + self.assertRegex( + minimized_ir, + r'fn minimize_bits\(x\w*: bits\[4\] id=2, y\w*: bits\[4\] id=3, z\w*:' + r' bits\[4\] id=4\) -> bits\[4\] \{\s*' + r'add\w*: bits\[4\] = add\(x\w*, y\w*.*\s*' + r'ret mul\w*: bits\[4\] = umul\(add\w*, z\w*', + ) + + def test_minimize_bitwidth_no_param_changes(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_bits(x: bits[8] id=2, y: bits[8] id=3, z: bits[8] id=4) -> bits[8] { + add_5: bits[8] = add(x, y, id=5) + ret mul_6: bits[8] = umul(add_5, z, id=6) + }""") + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script checks that add and umul are preserved, and rejects any + # bitwidth narrower than 4 (bits[1], bits[2], bits[3]) or literals replacing + # the parameters usage, forcing the minimizer to reach exactly bits[4]. + self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'add(' $1", + "/usr/bin/env grep 'umul(' $1", + r"/usr/bin/env grep -q 'literal(' $1 && exit 1 || true", + r"/usr/bin/env grep -q 'bits\[[123]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + '--can_remove_params=false', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 5) + self.assertRegex( + minimized_ir, + r'fn minimize_bits\(x\w*: bits\[8\] id=2, y\w*: bits\[8\] id=3, z\w*:' + r' bits\[8\] id=4\) -> bits\[4\] \{\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(x\w*, start=4, width=4.*\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(y\w*, start=4, width=4.*\s*' + r'add\w*: bits\[4\] = add\(bit_slice\.\w*, bit_slice\.\w*.*\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(z\w*, start=4, width=4.*\s*' + r'ret mul\w*: bits\[4\] = umul\(add\w*, bit_slice\.\w*', + ) + + def test_minimize_tuple_of_bits(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_tuple_out() -> (bits[1000], bits[1000]) { + literal.34: bits[1000] = literal(value=0, id=34) + ret smulp.5: (bits[1000], bits[1000]) = smulp(literal.34, literal.34, id=5) + } + """) + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script rejects bitwidths less than 8 (bits[1] through bits[7]), + # forcing the minimizer to reach exactly bits[8]. + self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'smulp(' $1", + "/usr/bin/env grep 'literal' $1", + r"/usr/bin/env grep -q 'bits\[[1-7]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 2) + self.assertRegex( + minimized_ir, + r'top fn minimize_tuple_out\(\) -> \(bits\[8\], bits\[8\]\) \{\s*' + r'literal[\w\.]*: bits\[8\] = literal\(value=0.*\s*' + r'ret smulp[\w\.]*: \(bits\[8\], bits\[8\]\) = smulp\(literal[\w\.]*,' + r' literal[\w\.]*', + ) + if __name__ == '__main__': absltest.main()