Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xls/dev_tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ cc_binary(
"//xls/estimators/delay_model/models",
"//xls/interpreter:ir_interpreter",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:channel",
"//xls/ir:channel_ops",
"//xls/ir:events",
Expand Down
251 changes: 247 additions & 4 deletions xls/dev_tools/ir_minimizer_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "xls/dev_tools/extract_segment.h"
#include "xls/dev_tools/extract_state_element.h"
#include "xls/interpreter/function_interpreter.h"
#include "xls/ir/bits.h"
#include "xls/ir/call_graph.h"
#include "xls/ir/channel.h"
#include "xls/ir/channel_ops.h"
Expand All @@ -70,6 +71,7 @@
#include "xls/ir/package.h"
#include "xls/ir/source_location.h"
#include "xls/ir/state_element.h"
#include "xls/ir/topo_sort.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"
#include "xls/ir/value_utils.h"
Expand Down Expand Up @@ -254,6 +256,8 @@ ABSL_FLAG(std::string, output_path, "-",
"Path where the minimized IR will be written. Will print to stdout "
"if not set or set to '-'. If stopped early, will contain the latest "
"known-failing IR.");
ABSL_FLAG(bool, can_minimize_bitwidth, true,
"Whether to minimize bitwidth of nodes.");

namespace xls {
namespace {
Expand Down Expand Up @@ -832,9 +836,238 @@ bool IsInterfaceChannelRef(ChannelRef channel_ref, Package* package) {
std::get<ChannelInterface*>(channel_ref)) != interface.end();
}

struct Slice {
int64_t start;
int64_t new_width;
};

// Returns the largest bit-slice that would encompass all bit-slices of users of
// the node. Returns nullopt if there are no users, if any user isn't a
// bitslice, or if the encompassing slice width is no smaller than the node.
absl::StatusOr<std::optional<Slice>> GetSmallerEncompassingBitSlice(
Node* node, int64_t orig_width) {
if (node->users().empty()) {
return std::nullopt;
}

int64_t min_start = orig_width;
int64_t max_end = 0;
for (Node* user : node->users()) {
if (user->op() != Op::kBitSlice) {
return std::nullopt;
}
BitSlice* slice = user->As<BitSlice>();
min_start = std::min(min_start, slice->start());
max_end = std::max(max_end, slice->start() + slice->width());
}

int64_t new_width = max_end - min_start;
XLS_RET_CHECK(new_width <= orig_width);
if (new_width == orig_width) {
return std::nullopt;
}
return Slice{.start = min_start, .new_width = new_width};
}

// Replaces all users of `old_node` with `new_node`, where new node is assumed
// to start at the bit `new_start` of `old_node` and have width `new_width`.
absl::Status ReplaceBitSliceUsers(Node* old_node, Node* new_node,
int64_t new_start, int64_t new_width) {
FunctionBase* f = old_node->function_base();
std::vector<Node*> orig_users(old_node->users().begin(),
old_node->users().end());
for (Node* user : orig_users) {
XLS_RET_CHECK(user->Is<BitSlice>());
BitSlice* user_slice = user->As<BitSlice>();
Node* new_node_or_slice = new_node;
if (user_slice->width() != new_width || user_slice->start() != new_start) {
int64_t new_user_start = user_slice->start() - new_start;
XLS_RET_CHECK(user_slice->width() + new_user_start <= new_width);
XLS_ASSIGN_OR_RETURN(
new_node_or_slice,
f->MakeNode<BitSlice>(user_slice->loc(), new_node, new_user_start,
user_slice->width()));
}

XLS_RETURN_IF_ERROR(user_slice->ReplaceUsesWith(new_node_or_slice));
XLS_RETURN_IF_ERROR(f->RemoveNode(user_slice));
}
return absl::OkStatus();
}

// Replaces `old_node` with `new_node` preserving ID and param index if the
// node is a function parameter.
absl::Status ReplaceNodeInPlace(Node* old_node, Node* new_node,
int64_t orig_width, int64_t new_width,
std::optional<int64_t> param_index,
std::string_view transform_name) {
FunctionBase* f = old_node->function_base();
int64_t id = old_node->id();
XLS_RETURN_IF_ERROR(old_node->ReplaceImplicitUsesWith(new_node).status());
XLS_RETURN_IF_ERROR(f->RemoveNode(old_node));
new_node->SetId(id);
if (param_index) {
XLS_RETURN_IF_ERROR(
f->MoveParamToIndex(new_node->As<Param>(), *param_index));
}
LOG(INFO) << absl::StrFormat("Trying %s: %s (%d -> %d)", transform_name,
new_node->GetName(), orig_width, new_width);
return absl::OkStatus();
}

struct ClonedNodeInfo {
Node* cloned_node;
std::optional<int64_t> param_index;
};

// Slices node operands and clones the node to use the reduced width operands.
absl::StatusOr<ClonedNodeInfo> SliceOperandsAndCloneNode(Node* node,
int64_t start,
int64_t new_width) {
// Assert that all operands are bits and have sufficient width.
for (Node* operand : node->operands()) {
XLS_RET_CHECK(operand->GetType()->IsBits());
XLS_RET_CHECK(operand->BitCountOrDie() >= start + new_width);
}

FunctionBase* f = node->function_base();
std::vector<Node*> sliced_operands;
sliced_operands.reserve(node->operand_count());
for (Node* operand : node->operands()) {
XLS_ASSIGN_OR_RETURN(
Node * slice,
f->MakeNode<BitSlice>(node->loc(), operand, start, new_width));
sliced_operands.push_back(slice);
}

Node* cloned_node = nullptr;
std::optional<int64_t> param_index;
if (node->Is<Param>()) {
XLS_ASSIGN_OR_RETURN(param_index, f->GetParamIndex(node->As<Param>()));
XLS_ASSIGN_OR_RETURN(cloned_node,
f->MakeNodeWithName<Param>(
node->loc(), f->package()->GetBitsType(new_width),
node->GetName()));
} else if (node->Is<Literal>()) {
Bits sliced_bits =
node->As<Literal>()->value().bits().Slice(start, new_width);
XLS_ASSIGN_OR_RETURN(cloned_node,
f->MakeNode<Literal>(node->loc(), Value(sliced_bits)));
} else if (node->Is<ArithOp>()) {
// NOTE: this is necessary because Clone() assumes the width of arith ops
// will not change.
XLS_ASSIGN_OR_RETURN(
cloned_node, f->MakeNodeWithName<ArithOp>(
node->loc(), sliced_operands[0], sliced_operands[1],
new_width, node->op(), node->GetName()));
} else if (node->Is<PartialProductOp>()) {
XLS_ASSIGN_OR_RETURN(
cloned_node, f->MakeNodeWithName<PartialProductOp>(
node->loc(), sliced_operands[0], sliced_operands[1],
new_width, node->op(), node->GetName()));
} else {
XLS_ASSIGN_OR_RETURN(cloned_node, node->Clone(sliced_operands));
}

return ClonedNodeInfo{cloned_node, param_index};
}

absl::StatusOr<SimplificationResult> TrimBitsOfNode(Node* node,
absl::BitGenRef rng) {
if (!node->GetType()->IsBits() && !node->Is<PartialProductOp>()) {
return SimplificationResult::kDidNotChange;
}

Op op = node->op();
if (op == Op::kParam && !absl::GetFlag(FLAGS_can_remove_params)) {
return SimplificationResult::kDidNotChange;
}
if (!node->Is<Literal>() && !node->Is<Param>() && !node->Is<ArithOp>() &&
!node->Is<CompareOp>() && !node->Is<BinOp>() && !node->Is<NaryOp>() &&
!node->Is<UnOp>() && !node->Is<PartialProductOp>()) {
return SimplificationResult::kDidNotChange;
}

int64_t orig_width;
if (node->GetType()->IsBits()) {
orig_width = node->BitCountOrDie();
} else if (node->Is<PartialProductOp>()) {
orig_width = 0;
for (Node* operand : node->operands()) {
orig_width = std::max(orig_width, operand->BitCountOrDie());
}
} else {
return SimplificationResult::kDidNotChange;
}

if (orig_width <= 1) {
return SimplificationResult::kDidNotChange;
}

// If there are no users, we trim to some randomly chosen smaller width.
if (node->users().empty()) {
int64_t new_width = absl::Uniform<int64_t>(rng, 1, orig_width);
int64_t start = absl::Uniform<int64_t>(rng, 0, orig_width - new_width + 1);
XLS_ASSIGN_OR_RETURN(ClonedNodeInfo clone_info,
SliceOperandsAndCloneNode(node, start, new_width));
XLS_RETURN_IF_ERROR(ReplaceNodeInPlace(
node, clone_info.cloned_node, orig_width, new_width,
clone_info.param_index, "trim node bitwidth (0 users)"));
return SimplificationResult::kDidChange;
}

// If all users are bitslices where they do not span the entire width of the
// literal, then we can trim the literal to the encompassing slice.
XLS_ASSIGN_OR_RETURN(std::optional<Slice> slice,
GetSmallerEncompassingBitSlice(node, orig_width));
if (!slice) {
return SimplificationResult::kDidNotChange;
}

XLS_ASSIGN_OR_RETURN(
ClonedNodeInfo clone_info,
SliceOperandsAndCloneNode(node, slice->start, slice->new_width));
XLS_RETURN_IF_ERROR(ReplaceBitSliceUsers(node, clone_info.cloned_node,
slice->start, slice->new_width));
XLS_RETURN_IF_ERROR(ReplaceNodeInPlace(
node, clone_info.cloned_node, orig_width, slice->new_width,
clone_info.param_index, "trim node bitwidth (bitslice users)"));
return SimplificationResult::kDidChange;
}

bool AreAllUsersBitslices(Node* node) {
for (Node* user : node->users()) {
if (!user->Is<BitSlice>()) {
return false;
}
}
return true;
}

absl::StatusOr<SimplificationResult> TrimBitsOfNodes(FunctionBase* f,
absl::BitGenRef rng) {
XLS_ASSIGN_OR_RETURN(std::vector<Node*> reverse_topo_nodes,
ReverseTopoSort(f, rng));
bool any_changes = false;
for (Node* node : reverse_topo_nodes) {
if (AreAllUsersBitslices(node)) {
XLS_ASSIGN_OR_RETURN(SimplificationResult res, TrimBitsOfNode(node, rng));
if (res == SimplificationResult::kDidChange) {
any_changes = true;
}
}
}

return any_changes ? SimplificationResult::kDidChange
: SimplificationResult::kDidNotChange;
}

constexpr double kSmallishChance = 0.3;

absl::StatusOr<SimplificationResult> SimplifyNode(
Node* n, absl::BitGenRef rng, std::string* which_transform) {
FunctionBase* f = n->function_base();

if (n->Is<Param>() && absl::GetFlag(FLAGS_can_remove_params) &&
absl::GetFlag(FLAGS_can_extract_segments) &&
n->function_base()->IsFunction() &&
Expand Down Expand Up @@ -868,7 +1101,7 @@ absl::StatusOr<SimplificationResult> SimplifyNode(
}
if (((n->Is<Receive>() && absl::GetFlag(FLAGS_can_remove_receives)) ||
(n->Is<Send>() && absl::GetFlag(FLAGS_can_remove_sends))) &&
absl::Bernoulli(rng, 0.3)) {
absl::Bernoulli(rng, kSmallishChance)) {
XLS_ASSIGN_OR_RETURN(ChannelRef c, n->As<ChannelNode>()->GetChannelRef());
absl::flat_hash_set<std::string> preserved_channels;
for (const std::string& chan : absl::GetFlag(FLAGS_preserve_channels)) {
Expand Down Expand Up @@ -931,7 +1164,7 @@ absl::StatusOr<SimplificationResult> SimplifyNode(
}

if (n->Is<Assert>() && absl::GetFlag(FLAGS_can_remove_asserts) &&
absl::Bernoulli(rng, 0.3)) {
absl::Bernoulli(rng, kSmallishChance)) {
*which_transform = "remove assert: " + n->GetName();
XLS_RETURN_IF_ERROR(n->ReplaceUsesWith(n->As<Assert>()->token()));
XLS_RETURN_IF_ERROR(f->RemoveNode(n));
Expand All @@ -943,13 +1176,13 @@ absl::StatusOr<SimplificationResult> SimplifyNode(
}

if (OpIsSideEffecting(n->op()) && !n->Is<Param>() && n->IsDead() &&
absl::Bernoulli(rng, 0.3)) {
absl::Bernoulli(rng, kSmallishChance)) {
*which_transform = "remove userless side-effecting node: " + n->GetName();
XLS_RETURN_IF_ERROR(f->RemoveNode(n));
return SimplificationResult::kDidChange;
}

if (!n->operands().empty() && absl::Bernoulli(rng, 0.3)) {
if (!n->operands().empty() && absl::Bernoulli(rng, kSmallishChance)) {
// Try to replace a node with one of its (potentially truncated/extended)
// operands.
int64_t operand_no = absl::Uniform<int64_t>(rng, 0, n->operand_count());
Expand Down Expand Up @@ -1288,6 +1521,16 @@ absl::StatusOr<SimplifiedIr> Simplify(FunctionBase* f,
}
}

// We do not support modifying the invoke sight for all callers of `f`, so
// we require the function is unused (i.e. the top-level function).
if (f->IsTop() && absl::GetFlag(FLAGS_can_minimize_bitwidth) &&
absl::Bernoulli(rng, kSmallishChance)) {
XLS_ASSIGN_OR_RETURN(SimplificationResult result, TrimBitsOfNodes(f, rng));
if (result == SimplificationResult::kDidChange) {
return in_place(result);
}
}

// Pick a random node and try to do something with it.
int64_t i = absl::Uniform<int64_t>(rng, 0, f->node_count());
Node* n = *std::next(f->nodes().begin(), i);
Expand Down
Loading
Loading