diff --git a/xls/dev_tools/BUILD b/xls/dev_tools/BUILD index 84ebe40fd5..616f23f4f6 100644 --- a/xls/dev_tools/BUILD +++ b/xls/dev_tools/BUILD @@ -554,6 +554,7 @@ cc_binary( "//xls/estimators/delay_model/models", "//xls/interpreter:ir_interpreter", "//xls/ir", + "//xls/ir:bits", "//xls/ir:channel", "//xls/ir:channel_ops", "//xls/ir:events", diff --git a/xls/dev_tools/ir_minimizer_main.cc b/xls/dev_tools/ir_minimizer_main.cc index e3d971fe36..c5be226519 100644 --- a/xls/dev_tools/ir_minimizer_main.cc +++ b/xls/dev_tools/ir_minimizer_main.cc @@ -56,6 +56,7 @@ #include "xls/dev_tools/extract_segment.h" #include "xls/dev_tools/extract_state_element.h" #include "xls/interpreter/function_interpreter.h" +#include "xls/ir/bits.h" #include "xls/ir/call_graph.h" #include "xls/ir/channel.h" #include "xls/ir/channel_ops.h" @@ -70,6 +71,7 @@ #include "xls/ir/package.h" #include "xls/ir/source_location.h" #include "xls/ir/state_element.h" +#include "xls/ir/topo_sort.h" #include "xls/ir/type.h" #include "xls/ir/value.h" #include "xls/ir/value_utils.h" @@ -254,6 +256,8 @@ ABSL_FLAG(std::string, output_path, "-", "Path where the minimized IR will be written. Will print to stdout " "if not set or set to '-'. If stopped early, will contain the latest " "known-failing IR."); +ABSL_FLAG(bool, can_minimize_bitwidth, true, + "Whether to minimize bitwidth of nodes."); namespace xls { namespace { @@ -832,9 +836,238 @@ bool IsInterfaceChannelRef(ChannelRef channel_ref, Package* package) { std::get(channel_ref)) != interface.end(); } +struct Slice { + int64_t start; + int64_t new_width; +}; + +// Returns the smallest bit-slice that would encompass all bit-slices of users of +// the node. Returns nullopt if there are no users, if any user isn't a +// bitslice, or if the encompassing slice width is no smaller than the node. 
+absl::StatusOr> GetSmallerEncompassingBitSlice( + Node* node, int64_t orig_width) { + if (node->users().empty()) { + return std::nullopt; + } + + int64_t min_start = orig_width; + int64_t max_end = 0; + for (Node* user : node->users()) { + if (user->op() != Op::kBitSlice) { + return std::nullopt; + } + BitSlice* slice = user->As(); + min_start = std::min(min_start, slice->start()); + max_end = std::max(max_end, slice->start() + slice->width()); + } + + int64_t new_width = max_end - min_start; + XLS_RET_CHECK(new_width <= orig_width); + if (new_width == orig_width) { + return std::nullopt; + } + return Slice{.start = min_start, .new_width = new_width}; +} + +// Replaces all users of `old_node` with `new_node`, where new node is assumed +// to start at the bit `new_start` of `old_node` and have width `new_width`. +absl::Status ReplaceBitSliceUsers(Node* old_node, Node* new_node, + int64_t new_start, int64_t new_width) { + FunctionBase* f = old_node->function_base(); + std::vector orig_users(old_node->users().begin(), + old_node->users().end()); + for (Node* user : orig_users) { + XLS_RET_CHECK(user->Is()); + BitSlice* user_slice = user->As(); + Node* new_node_or_slice = new_node; + if (user_slice->width() != new_width || user_slice->start() != new_start) { + int64_t new_user_start = user_slice->start() - new_start; + XLS_RET_CHECK(user_slice->width() + new_user_start <= new_width); + XLS_ASSIGN_OR_RETURN( + new_node_or_slice, + f->MakeNode(user_slice->loc(), new_node, new_user_start, + user_slice->width())); + } + + XLS_RETURN_IF_ERROR(user_slice->ReplaceUsesWith(new_node_or_slice)); + XLS_RETURN_IF_ERROR(f->RemoveNode(user_slice)); + } + return absl::OkStatus(); +} + +// Replaces `old_node` with `new_node` preserving ID and param index if the +// node is a function parameter. 
+absl::Status ReplaceNodeInPlace(Node* old_node, Node* new_node, + int64_t orig_width, int64_t new_width, + std::optional param_index, + std::string_view transform_name) { + FunctionBase* f = old_node->function_base(); + int64_t id = old_node->id(); + XLS_RETURN_IF_ERROR(old_node->ReplaceImplicitUsesWith(new_node).status()); + XLS_RETURN_IF_ERROR(f->RemoveNode(old_node)); + new_node->SetId(id); + if (param_index) { + XLS_RETURN_IF_ERROR( + f->MoveParamToIndex(new_node->As(), *param_index)); + } + LOG(INFO) << absl::StrFormat("Trying %s: %s (%d -> %d)", transform_name, + new_node->GetName(), orig_width, new_width); + return absl::OkStatus(); +} + +struct ClonedNodeInfo { + Node* cloned_node; + std::optional param_index; +}; + +// Slices node operands and clones the node to use the reduced width operands. +absl::StatusOr SliceOperandsAndCloneNode(Node* node, + int64_t start, + int64_t new_width) { + // Assert that all operands are bits and have sufficient width. + for (Node* operand : node->operands()) { + XLS_RET_CHECK(operand->GetType()->IsBits()); + XLS_RET_CHECK(operand->BitCountOrDie() >= start + new_width); + } + + FunctionBase* f = node->function_base(); + std::vector sliced_operands; + sliced_operands.reserve(node->operand_count()); + for (Node* operand : node->operands()) { + XLS_ASSIGN_OR_RETURN( + Node * slice, + f->MakeNode(node->loc(), operand, start, new_width)); + sliced_operands.push_back(slice); + } + + Node* cloned_node = nullptr; + std::optional param_index; + if (node->Is()) { + XLS_ASSIGN_OR_RETURN(param_index, f->GetParamIndex(node->As())); + XLS_ASSIGN_OR_RETURN(cloned_node, + f->MakeNodeWithName( + node->loc(), f->package()->GetBitsType(new_width), + node->GetName())); + } else if (node->Is()) { + Bits sliced_bits = + node->As()->value().bits().Slice(start, new_width); + XLS_ASSIGN_OR_RETURN(cloned_node, + f->MakeNode(node->loc(), Value(sliced_bits))); + } else if (node->Is()) { + // NOTE: this is necessary because Clone() assumes the width of 
arith ops + // will not change. + XLS_ASSIGN_OR_RETURN( + cloned_node, f->MakeNodeWithName( + node->loc(), sliced_operands[0], sliced_operands[1], + new_width, node->op(), node->GetName())); + } else if (node->Is()) { + XLS_ASSIGN_OR_RETURN( + cloned_node, f->MakeNodeWithName( + node->loc(), sliced_operands[0], sliced_operands[1], + new_width, node->op(), node->GetName())); + } else { + XLS_ASSIGN_OR_RETURN(cloned_node, node->Clone(sliced_operands)); + } + + return ClonedNodeInfo{cloned_node, param_index}; +} + +absl::StatusOr TrimBitsOfNode(Node* node, + absl::BitGenRef rng) { + if (!node->GetType()->IsBits() && !node->Is()) { + return SimplificationResult::kDidNotChange; + } + + Op op = node->op(); + if (op == Op::kParam && !absl::GetFlag(FLAGS_can_remove_params)) { + return SimplificationResult::kDidNotChange; + } + if (!node->Is() && !node->Is() && !node->Is() && + !node->Is() && !node->Is() && !node->Is() && + !node->Is() && !node->Is()) { + return SimplificationResult::kDidNotChange; + } + + int64_t orig_width; + if (node->GetType()->IsBits()) { + orig_width = node->BitCountOrDie(); + } else if (node->Is()) { + orig_width = 0; + for (Node* operand : node->operands()) { + orig_width = std::max(orig_width, operand->BitCountOrDie()); + } + } else { + return SimplificationResult::kDidNotChange; + } + + if (orig_width <= 1) { + return SimplificationResult::kDidNotChange; + } + + // If there are no users, we trim to some randomly chosen smaller width. 
+ if (node->users().empty()) { + int64_t new_width = absl::Uniform(rng, 1, orig_width); + int64_t start = absl::Uniform(rng, 0, orig_width - new_width + 1); + XLS_ASSIGN_OR_RETURN(ClonedNodeInfo clone_info, + SliceOperandsAndCloneNode(node, start, new_width)); + XLS_RETURN_IF_ERROR(ReplaceNodeInPlace( + node, clone_info.cloned_node, orig_width, new_width, + clone_info.param_index, "trim node bitwidth (0 users)")); + return SimplificationResult::kDidChange; + } + + // If all users are bitslices where they do not span the entire width of the + // node, then we can trim the node to the encompassing slice. + XLS_ASSIGN_OR_RETURN(std::optional slice, + GetSmallerEncompassingBitSlice(node, orig_width)); + if (!slice) { + return SimplificationResult::kDidNotChange; + } + + XLS_ASSIGN_OR_RETURN( + ClonedNodeInfo clone_info, + SliceOperandsAndCloneNode(node, slice->start, slice->new_width)); + XLS_RETURN_IF_ERROR(ReplaceBitSliceUsers(node, clone_info.cloned_node, + slice->start, slice->new_width)); + XLS_RETURN_IF_ERROR(ReplaceNodeInPlace( + node, clone_info.cloned_node, orig_width, slice->new_width, + clone_info.param_index, "trim node bitwidth (bitslice users)")); + return SimplificationResult::kDidChange; +} + +bool AreAllUsersBitslices(Node* node) { + for (Node* user : node->users()) { + if (!user->Is()) { + return false; + } + } + return true; +} + +absl::StatusOr TrimBitsOfNodes(FunctionBase* f, + absl::BitGenRef rng) { + XLS_ASSIGN_OR_RETURN(std::vector reverse_topo_nodes, + ReverseTopoSort(f, rng)); + bool any_changes = false; + for (Node* node : reverse_topo_nodes) { + if (AreAllUsersBitslices(node)) { + XLS_ASSIGN_OR_RETURN(SimplificationResult res, TrimBitsOfNode(node, rng)); + if (res == SimplificationResult::kDidChange) { + any_changes = true; + } + } + } + + return any_changes ? 
SimplificationResult::kDidChange + : SimplificationResult::kDidNotChange; +} + +constexpr double kSmallishChance = 0.3; + absl::StatusOr SimplifyNode( Node* n, absl::BitGenRef rng, std::string* which_transform) { FunctionBase* f = n->function_base(); + if (n->Is() && absl::GetFlag(FLAGS_can_remove_params) && absl::GetFlag(FLAGS_can_extract_segments) && n->function_base()->IsFunction() && @@ -868,7 +1101,7 @@ absl::StatusOr SimplifyNode( } if (((n->Is() && absl::GetFlag(FLAGS_can_remove_receives)) || (n->Is() && absl::GetFlag(FLAGS_can_remove_sends))) && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { XLS_ASSIGN_OR_RETURN(ChannelRef c, n->As()->GetChannelRef()); absl::flat_hash_set preserved_channels; for (const std::string& chan : absl::GetFlag(FLAGS_preserve_channels)) { @@ -931,7 +1164,7 @@ absl::StatusOr SimplifyNode( } if (n->Is() && absl::GetFlag(FLAGS_can_remove_asserts) && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { *which_transform = "remove assert: " + n->GetName(); XLS_RETURN_IF_ERROR(n->ReplaceUsesWith(n->As()->token())); XLS_RETURN_IF_ERROR(f->RemoveNode(n)); @@ -943,13 +1176,13 @@ absl::StatusOr SimplifyNode( } if (OpIsSideEffecting(n->op()) && !n->Is() && n->IsDead() && - absl::Bernoulli(rng, 0.3)) { + absl::Bernoulli(rng, kSmallishChance)) { *which_transform = "remove userless side-effecting node: " + n->GetName(); XLS_RETURN_IF_ERROR(f->RemoveNode(n)); return SimplificationResult::kDidChange; } - if (!n->operands().empty() && absl::Bernoulli(rng, 0.3)) { + if (!n->operands().empty() && absl::Bernoulli(rng, kSmallishChance)) { // Try to replace a node with one of its (potentially truncated/extended) // operands. int64_t operand_no = absl::Uniform(rng, 0, n->operand_count()); @@ -1288,6 +1521,16 @@ absl::StatusOr Simplify(FunctionBase* f, } } + // We do not support modifying the invoke site for all callers of `f`, so + // we require the function is unused (i.e. the top-level function). 
+ if (f->IsTop() && absl::GetFlag(FLAGS_can_minimize_bitwidth) && + absl::Bernoulli(rng, kSmallishChance)) { + XLS_ASSIGN_OR_RETURN(SimplificationResult result, TrimBitsOfNodes(f, rng)); + if (result == SimplificationResult::kDidChange) { + return in_place(result); + } + } + // Pick a random node and try to do something with it. int64_t i = absl::Uniform(rng, 0, f->node_count()); Node* n = *std::next(f->nodes().begin(), i); diff --git a/xls/dev_tools/ir_minimizer_main_test.py b/xls/dev_tools/ir_minimizer_main_test.py index 3abdfffabc..5b1ed4564a 100644 --- a/xls/dev_tools/ir_minimizer_main_test.py +++ b/xls/dev_tools/ir_minimizer_main_test.py @@ -17,6 +17,7 @@ import re import stat import subprocess +import textwrap from typing import Optional from absl.testing import absltest @@ -858,6 +859,7 @@ def test_can_unwrap_map(self): [ IR_MINIMIZER_MAIN_PATH, '--can_remove_params', + '--can_minimize_bitwidth=false', '--can_inline_everything=false', '--test_executable=' + test_sh_file.full_path, ir_file.full_path, @@ -879,6 +881,163 @@ def test_can_unwrap_map(self): self.assertIn('ret literal', minimized_ir) self.assertNotIn('ret invoke', minimized_ir) + def test_minimize_bitwidth(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_bits(x: bits[8] id=2, y: bits[8] id=3, z: bits[8] id=4) -> bits[8] { + add_5: bits[8] = add(x, y, id=5) + ret mul_6: bits[8] = umul(add_5, z, id=6) + }""") + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script checks that add and umul are preserved, and rejects any + # bitwidth narrower than 4 (bits[1], bits[2], bits[3]) or the params being + # dropped outright, forcing the minimizer to reach exactly bits[4]. 
+ self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'add(' $1", + "/usr/bin/env grep 'umul(' $1", + r"/usr/bin/env grep -E 'x\w*:.+y\w*:.+z\w*:' $1", + r"/usr/bin/env grep -q 'bits\[[123]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + '--can_remove_params=true', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 2) + self.assertRegex( + minimized_ir, + r'fn minimize_bits\(x\w*: bits\[4\] id=2, y\w*: bits\[4\] id=3, z\w*:' + r' bits\[4\] id=4\) -> bits\[4\] \{\s*' + r'add\w*: bits\[4\] = add\(x\w*, y\w*.*\s*' + r'ret mul\w*: bits\[4\] = umul\(add\w*, z\w*', + ) + + def test_minimize_bitwidth_no_param_changes(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_bits(x: bits[8] id=2, y: bits[8] id=3, z: bits[8] id=4) -> bits[8] { + add_5: bits[8] = add(x, y, id=5) + ret mul_6: bits[8] = umul(add_5, z, id=6) + }""") + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script checks that add and umul are preserved, and rejects any + # bitwidth narrower than 4 (bits[1], bits[2], bits[3]) or literals replacing + # the parameters usage, forcing the minimizer to reach exactly bits[4]. 
+ self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'add(' $1", + "/usr/bin/env grep 'umul(' $1", + r"/usr/bin/env grep -q 'literal(' $1 && exit 1 || true", + r"/usr/bin/env grep -q 'bits\[[123]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + '--can_remove_params=false', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 5) + self.assertRegex( + minimized_ir, + r'fn minimize_bits\(x\w*: bits\[8\] id=2, y\w*: bits\[8\] id=3, z\w*:' + r' bits\[8\] id=4\) -> bits\[4\] \{\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(x\w*, start=4, width=4.*\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(y\w*, start=4, width=4.*\s*' + r'add\w*: bits\[4\] = add\(bit_slice\.\w*, bit_slice\.\w*.*\s*' + r'bit_slice\.\w*: bits\[4\] = bit_slice\(z\w*, start=4, width=4.*\s*' + r'ret mul\w*: bits\[4\] = umul\(add\w*, bit_slice\.\w*', + ) + + def test_minimize_tuple_of_bits(self): + input_ir = textwrap.dedent(""" + package foo + top fn minimize_tuple_out() -> (bits[1000], bits[1000]) { + literal.34: bits[1000] = literal(value=0, id=34) + ret smulp.5: (bits[1000], bits[1000]) = smulp(literal.34, literal.34, id=5) + } + """) + ir_file = self.create_tempfile(content=input_ir) + test_sh_file = self.create_tempfile() + # The test script rejects bitwidths less than 8 (bits[1] through bits[7]), + # forcing the minimizer to reach exactly bits[8]. 
+ self._write_sh_script( + test_sh_file.full_path, + [ + "/usr/bin/env grep 'smulp(' $1", + "/usr/bin/env grep 'literal' $1", + r"/usr/bin/env grep -q 'bits\[[1-7]\]' $1 && exit 1 || true", + ], + ) + output = subprocess.run( + [ + IR_MINIMIZER_MAIN_PATH, + '--can_minimize_bitwidth=true', + f'--test_executable={test_sh_file.full_path}', + ir_file.full_path, + ], + encoding='utf-8', + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + self.assertEqual( + output.returncode, + 0, + f'Non zero return: stderr {output.stderr!r}, stdout: {output.stdout!r}', + ) + minimized_ir = output.stdout + self._maybe_record_property('output', minimized_ir) + + self.assertEqual(function_count(minimized_ir), 1) + self.assertEqual(node_count(minimized_ir), 2) + self.assertRegex( + minimized_ir, + r'top fn minimize_tuple_out\(\) -> \(bits\[8\], bits\[8\]\) \{\s*' + r'literal[\w\.]*: bits\[8\] = literal\(value=0.*\s*' + r'ret smulp[\w\.]*: \(bits\[8\], bits\[8\]\) = smulp\(literal[\w\.]*,' + r' literal[\w\.]*', + ) + if __name__ == '__main__': absltest.main()