Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions xls/dslx/ir_convert/get_conversion_records.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
records_(records),
processed_invocations_(processed_invocations) {}

absl::Status VisitProcFunctionsWithSeparateTypeInfos(
Module* owner, const Function* config_function,
TypeInfo* config_type_info, const Function* next_function,
TypeInfo* next_type_info) {
// Get conversion records from invocations in this proc's "config"
// "next" functions and add to our list of records. Don't use Accept
// because that will run HandleFunction, which ignores "config" and
// "next" functions.
ConversionRecordVisitor config_visitor(
owner, config_type_info,
include_tests_, proc_id_factory_, top_,
resolved_proc_alias_, records_,
processed_invocations_);
XLS_RETURN_IF_ERROR(config_visitor.DefaultHandler(config_function));

ConversionRecordVisitor next_visitor(
owner, next_type_info,
include_tests_, proc_id_factory_, top_,
resolved_proc_alias_, records_,
processed_invocations_);
return next_visitor.DefaultHandler(next_function);
}

absl::StatusOr<ConversionRecord> SpawnDataToConversionRecord(
const SpawnData& spawn, ProcId proc_id) {
VLOG(5) << "Making conversion record for SpawnData with proc: "
Expand All @@ -83,14 +106,10 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
<< "; config TI: " << std::hex << spawn.config_type_info
<< "; next TI: " << spawn.next_type_info;

ConversionRecordVisitor visitor(spawn.proc->owner(), spawn.next_type_info,
include_tests_, proc_id_factory_, top_,
resolved_proc_alias_, records_,
processed_invocations_);
// Get additional conversion records from invocations in this proc's "next"
// function and add to our list of records. Don't use Accept because that
// will run HandleFunction, which ignores "next" functions.
XLS_RETURN_IF_ERROR(visitor.DefaultHandler(&spawn.proc->next()));
XLS_RETURN_IF_ERROR(VisitProcFunctionsWithSeparateTypeInfos(
spawn.proc->owner(), &spawn.proc->config(),
spawn.config_type_info, &spawn.proc->next(),
spawn.next_type_info));

XLS_ASSIGN_OR_RETURN(
ConversionRecord config_record,
Expand All @@ -114,6 +133,21 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
: *type_info_->GetImportedTypeInfo(node->owner());
}

absl::Status HandleConditional(const Conditional* expr) override {
if (expr->IsConst()) {
XLS_ASSIGN_OR_RETURN(InterpValue test_value,
GetTypeInfo(expr)->GetConstExpr(expr->test()));
if (test_value.IsTrue()) {
XLS_RETURN_IF_ERROR(DefaultHandler(expr->consequent()));
} else {
XLS_RETURN_IF_ERROR(DefaultHandler(ToExprNode(expr->alternate())));
}
return absl::OkStatus();
}

return DefaultHandler(expr);
}

// Generates a conversion record for the given function if it is a real
// function (not parametric or compiler-derived) that has no incoming calls
// known to `type_info_`. Also traverses such functions to ensure that
Expand Down Expand Up @@ -270,17 +304,26 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
absl::Status HandleProc(const Proc* p) override {
VLOG(5) << "HandleProc " << p->ToString();
const Function* next_fn = &p->next();
// Handle any calls inside function bodies.
XLS_RETURN_IF_ERROR(DefaultHandler(next_fn));
// This is required in order to process cross-module spawns; otherwise it
// will never add procs from imported modules to the list of functions to
// convert.
XLS_RETURN_IF_ERROR(DefaultHandler(&p->config()));
// Traversing parametric procs is done later with proper
// type infos for proc's "config" and "next" functions separately.
if (!p->IsParametric()) {
// Handle any calls inside function bodies.
XLS_RETURN_IF_ERROR(DefaultHandler(next_fn));
// This is required in order to process cross-module spawns; otherwise it
// will never add procs from imported modules to the list of functions to
// convert.
XLS_RETURN_IF_ERROR(DefaultHandler(&p->config()));
}

ProcId proc_id = proc_id_factory_.CreateProcId(
/*parent=*/std::nullopt, const_cast<Proc*>(p),
/*count_as_new_instance=*/false);
if (top_ == next_fn && resolved_proc_alias_.has_value()) {
XLS_RETURN_IF_ERROR(VisitProcFunctionsWithSeparateTypeInfos(
top_->owner(), &p->config(),
resolved_proc_alias_->config_type_info, &p->next(),
resolved_proc_alias_->next_type_info));

proc_id.alias_name = resolved_proc_alias_->name;
XLS_ASSIGN_OR_RETURN(
ConversionRecord config_record,
Expand Down
115 changes: 115 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,121 @@ fn main() -> u32 {
ExpectIr(converted);
}

TEST_F(IrConverterTest, ConstConditionalProcScoped) {
constexpr std::string_view program = R"(
proc Multiply {
input: chan<u32> in;
output: chan<u32> out;

init {}

config(input: chan<u32> in, output: chan<u32> out) {
(input, output)
}

next(state: ()) {
let (tok, req) = recv(join(), input);
let data = req * u32:2;
let tok = send(tok, output, data);
}
}

proc Passthrough {
input: chan<u32> in;
output: chan<u32> out;

init {}

config(input: chan<u32> in, output: chan<u32> out) {
(input, output)
}

next(state: ()) {
let (tok, req) = recv(join(), input);
let tok = send(tok, output, req);
}
}

const CONFIG = u32:31;

proc Top {
init {}

config(req_r: chan<u32> in, resp_s: chan<u32> out) {
const if CONFIG <= u32:27 {
spawn Passthrough(req_r, resp_s);
} else {
spawn Multiply(req_r, resp_s);
};
()
}

next(state: ()) { state }
}
)";

ConvertOptions options;
options.lower_to_proc_scoped_channels = true;
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertOneFunctionForTest(program, "Top", options));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ConstConditionalProcScopedWithParams) {
constexpr std::string_view program = R"(
proc Falsy {
req_r: chan<()> in;
resp_s: chan<bool> out;

config(req_r: chan<()> in, resp_s: chan<bool> out) { (req_r, resp_s) }

init { }

next(_: ()) {
let (tok, _d) = recv(join(), req_r);
let tok = send(tok, resp_s, false);
}
}

proc Truthy {
req_r: chan<()> in;
resp_s: chan<bool> out;

config(req_r: chan<()> in, resp_s: chan<bool> out) { (req_r, resp_s) }

init { }

next(_: ()) {
let (tok, _d) = recv(join(), req_r);
let tok = send(tok, resp_s, true);
}
}

proc Foo<CONFIG: bool> {
config(req_r: chan<()> in, resp_s: chan<bool> out) {
const if CONFIG {
spawn Truthy(req_r, resp_s);
} else {
spawn Falsy(req_r, resp_s);
};
()
}

init { }

next(_: ()) { }
}

pub proc Top = Foo<true>;
)";

ConvertOptions options;
options.lower_to_proc_scoped_channels = true;
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertOneFunctionForTest(program, "Top", options));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ConstantsWithConditionalsPlusStuff) {
constexpr std::string_view program =
R"(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__Multiply_0_next<_input: bits[32] in, _output: bits[32] out>(__state: (), init={()}) {
chan_interface _input(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _output(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
after_all.5: token = after_all(id=5)
literal.3: bits[1] = literal(value=1, id=3)
receive.6: (token, bits[32]) = receive(after_all.5, predicate=literal.3, channel=_input, id=6)
req: bits[32] = tuple_index(receive.6, index=1, id=9, pos=[(0,12,16)])
literal.10: bits[32] = literal(value=2, id=10, pos=[(0,13,23)])
tok: token = tuple_index(receive.6, index=0, id=8, pos=[(0,12,11)])
data: bits[32] = umul(req, literal.10, id=11, pos=[(0,13,17)])
__state: () = state_read(state_element=__state, id=2)
tuple.13: () = tuple(id=13, pos=[(0,11,19)])
__token: token = literal(value=token, id=1)
tuple.4: () = tuple(id=4, pos=[(0,8,6)])
tuple_index.7: token = tuple_index(receive.6, index=0, id=7)
tok__1: token = send(tok, data, predicate=literal.3, channel=_output, id=12)
next_value.14: () = next_value(param=__state, value=tuple.13, id=14)
}

top proc __test_module__Top_0_next<_req_r: bits[32] in, _resp_s: bits[32] out>(__state: (), init={()}) {
chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
proc_instantiation __test_module__Multiply_0_next_inst(_req_r, _resp_s, proc=__test_module__Multiply_0_next)
__state: () = state_read(state_element=__state, id=16)
__token: token = literal(value=token, id=15)
literal.17: bits[1] = literal(value=1, id=17)
tuple.18: () = tuple(id=18, pos=[(0,42,13)])
tuple.19: () = tuple(id=19, pos=[(0,45,6)])
next_value.20: () = next_value(param=__state, value=__state, id=20)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__Truthy_0_next<_req_r: () in, _resp_s: bits[1] out>(__state: (), init={()}) {
chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
after_all.5: token = after_all(id=5)
literal.3: bits[1] = literal(value=1, id=3)
receive.6: (token, ()) = receive(after_all.5, predicate=literal.3, channel=_req_r, id=6)
tok: token = tuple_index(receive.6, index=0, id=8, pos=[(0,24,13)])
literal.10: bits[1] = literal(value=1, id=10, pos=[(0,25,36)])
__state: () = state_read(state_element=__state, id=2)
tuple.12: () = tuple(id=12, pos=[(0,23,15)])
__token: token = literal(value=token, id=1)
tuple.4: () = tuple(id=4, pos=[(0,19,57)])
tuple_index.7: token = tuple_index(receive.6, index=0, id=7)
_d: () = tuple_index(receive.6, index=1, id=9, pos=[(0,24,18)])
tok__1: token = send(tok, literal.10, predicate=literal.3, channel=_resp_s, id=11)
next_value.13: () = next_value(param=__state, value=tuple.12, id=13)
}

top proc __test_module__Top_next<_req_r: () in, _resp_s: bits[1] out>(__state: (), init={()}) {
chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
proc_instantiation __test_module__Truthy_0_next_inst(_req_r, _resp_s, proc=__test_module__Truthy_0_next)
__state: () = state_read(state_element=__state, id=15)
tuple.20: () = tuple(id=20, pos=[(0,41,17)])
__token: token = literal(value=token, id=14)
literal.16: bits[1] = literal(value=1, id=16)
CONFIG: bits[1] = literal(value=1, id=17, pos=[(0,29,11)])
tuple.18: () = tuple(id=18, pos=[(0,31,26)])
tuple.19: () = tuple(id=19, pos=[(0,36,10)])
next_value.21: () = next_value(param=__state, value=tuple.20, id=21)
}