diff --git a/xls/passes/strength_reduction_pass.cc b/xls/passes/strength_reduction_pass.cc index cb1e44c9af..4755116888 100644 --- a/xls/passes/strength_reduction_pass.cc +++ b/xls/passes/strength_reduction_pass.cc @@ -50,9 +50,10 @@ namespace { absl::StatusOr MaybeSinkOperationIntoSelect( Node* node, const QueryEngine& query_engine, Select* select_val) { - if (OpIsSideEffecting(node->op())) { + if (OpIsSideEffecting(node->op()) || node->Is()) { // Side-effecting operations are not always safe to duplicate so don't - // bother. + // bother. Invokes are also excluded because the invoked functions may + // contain side-effecting operations. return false; } DCHECK(!query_engine.IsFullyKnown(select_val)); diff --git a/xls/passes/strength_reduction_pass_test.cc b/xls/passes/strength_reduction_pass_test.cc index 583c8e8292..c8f7bc8832 100644 --- a/xls/passes/strength_reduction_pass_test.cc +++ b/xls/passes/strength_reduction_pass_test.cc @@ -503,6 +503,32 @@ TEST_F(StrengthReductionPassTest, DoNotPushDownCheapExtendingOps) { ASSERT_THAT(Run(f), IsOkAndHolds(false)) << f->DumpIr(); } +TEST_F(StrengthReductionPassTest, DoNotPushDownInvokeThroughSelect) { + auto p = CreatePackage(); + Function* callee; + { + FunctionBuilder fb("callee", p.get()); + BValue lhs = fb.Param("lhs", p->GetBitsType(32)); + BValue rhs = fb.Param("rhs", p->GetBitsType(32)); + XLS_ASSERT_OK_AND_ASSIGN(callee, fb.BuildWithReturnValue(fb.Add(lhs, rhs))); + } + + FunctionBuilder fb(TestName(), p.get()); + BValue selector = fb.Param("selector", p->GetBitsType(1)); + BValue rhs = fb.Select(selector, {fb.Literal(UBits(1, 32)), + fb.Literal(UBits(2, 32))}); + fb.Invoke({fb.Literal(UBits(3, 32)), rhs}, callee); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); + ASSERT_THAT(Run(f), IsOkAndHolds(false)) << f->DumpIr(); + + EXPECT_THAT(f->return_value(), + m::Invoke(m::Literal(UBits(3, 32)), + m::Select(m::Param(), + {m::Literal(UBits(1, 32)), + m::Literal(UBits(2, 32))}))) + << f->DumpIr(); +} + // This is something we might want to support at some point. TEST_F(StrengthReductionPassTest, DoNotPushDownMultipleSelects) { auto p = CreatePackage();