Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Handshake] Remove handshake.select #5045

Merged
merged 2 commits into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions include/circt/Dialect/Handshake/HandshakeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -481,39 +481,6 @@ def ConditionalBranchOp : Handshake_Op<"cond_br", [
}];
}

def SelectOp : Handshake_Op<"select", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
DeclareOpInterfaceMethods<ControlInterface, ["isControl"]>,
DeclareOpInterfaceMethods<NamedIOInterface, ["getOperandName"]>,
TypesMatchWith<"data operand type matches true branch result type",
"trueOperand", "falseOperand", "$_self">,
TypesMatchWith<"data operand type matches false branch result type",
"falseOperand", "result", "$_self">
]> {
let summary = "Select operation";
let description = [{
The select operation will select between two inputs based on an input
conditional. The select operation differs from a mux in that
1. All operands must be valid before the operation can transact
2. All operands will be transacted at simultaneously

The 'select' operation is intended to handle 'std.select' and other
ternary-like operators, which considers strictly dataflow. The 'mux' operator
considers control+dataflow between blocks.

Example:
```mlir
%res = select %cond, %true, %false : i32
```
}];

let arguments = (ins I1 : $condOperand,
AnyType : $trueOperand, AnyType : $falseOperand);
let results = (outs AnyType : $result);
let hasCustomAssemblyFormat = 1;
}

def SinkOp : Handshake_Op<"sink", [
SOSTInterface, DeclareOpInterfaceMethods<ExecutableOpInterface>
]> {
Expand Down
17 changes: 8 additions & 9 deletions include/circt/Dialect/Handshake/Visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,10 @@ class HandshakeVisitor {
// Handshake nodes.
BranchOp, BufferOp, ConditionalBranchOp, ConstantOp, ControlMergeOp,
ForkOp, FuncOp, InstanceOp, JoinOp, LazyForkOp, LoadOp, MemoryOp,
ExternalMemoryOp, MergeOp, MuxOp, ReturnOp, SinkOp,
handshake::SelectOp, SourceOp, StoreOp, SyncOp, PackOp, UnpackOp>(
[&](auto opNode) -> ResultType {
return thisCast->visitHandshake(opNode, args...);
})
ExternalMemoryOp, MergeOp, MuxOp, ReturnOp, SinkOp, SourceOp,
StoreOp, SyncOp, PackOp, UnpackOp>([&](auto opNode) -> ResultType {
return thisCast->visitHandshake(opNode, args...);
})
.Default([&](auto opNode) -> ResultType {
return thisCast->visitInvalidOp(op, args...);
});
Expand Down Expand Up @@ -73,7 +72,6 @@ class HandshakeVisitor {
HANDLE(LazyForkOp);
HANDLE(LoadOp);
HANDLE(MemoryOp);
HANDLE(handshake::SelectOp);
HANDLE(ExternalMemoryOp);
HANDLE(MergeOp);
HANDLE(MuxOp);
Expand Down Expand Up @@ -106,9 +104,10 @@ class StdExprVisitor {
arith::CmpIOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
arith::DivSIOp, arith::RemSIOp, arith::DivUIOp, arith::RemUIOp,
arith::XOrIOp, arith::AndIOp, arith::OrIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp>([&](auto opNode) -> ResultType {
return thisCast->visitStdExpr(opNode, args...);
})
arith::ShRSIOp, arith::ShRUIOp, arith::SelectOp>(
[&](auto opNode) -> ResultType {
return thisCast->visitStdExpr(opNode, args...);
})
.Default([&](auto opNode) -> ResultType {
return thisCast->visitInvalidOp(op, args...);
});
Expand Down
2 changes: 1 addition & 1 deletion integration_test/Dialect/Handshake/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def checkOutputs(self, results):
assert (self.isReady())
for res in results:
await self.waitUntilValid()
assert (self.data.value == res)
assert self.data.value == res, f"Expected {res}, got {self.data.value}"
await RisingEdge(self.dut.clock)


Expand Down
14 changes: 7 additions & 7 deletions integration_test/Dialect/Handshake/max/max.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,21 @@

// Computes the maximum of all inputs
func.func @top(%in0: i64, %in1: i64, %in2: i64, %in3: i64, %in4: i64, %in5: i64, %in6: i64, %in7: i64) -> i64 {
%c0 = arith.cmpi slt, %in0, %in1 : i64
%c0 = arith.cmpi sge, %in0, %in1 : i64
%t0 = arith.select %c0, %in0, %in1 : i64
%c1 = arith.cmpi slt, %in2, %in3 : i64
%c1 = arith.cmpi sge, %in2, %in3 : i64
%t1 = arith.select %c1, %in2, %in3 : i64
%c2 = arith.cmpi slt, %in4, %in5 : i64
%c2 = arith.cmpi sge, %in4, %in5 : i64
%t2 = arith.select %c2, %in4, %in5 : i64
%c3 = arith.cmpi slt, %in6, %in7 : i64
%c3 = arith.cmpi sge, %in6, %in7 : i64
%t3 = arith.select %c3, %in6, %in7 : i64

%c4 = arith.cmpi slt, %t0, %t1 : i64
%c4 = arith.cmpi sge, %t0, %t1 : i64
%t4 = arith.select %c4, %t0, %t1 : i64
%c5 = arith.cmpi slt, %t2, %t3 : i64
%c5 = arith.cmpi sge, %t2, %t3 : i64
%t5 = arith.select %c5, %t2, %t3 : i64

%c6 = arith.cmpi slt, %t4, %t5 : i64
%c6 = arith.cmpi sge, %t4, %t5 : i64
%t6 = arith.select %c6, %t4, %t5 : i64
return %t6 : i64
}
93 changes: 40 additions & 53 deletions lib/Conversion/HandshakeToFIRRTL/HandshakeToFIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ class StdExprBuilder : public StdExprVisitor<StdExprBuilder, bool> {
bool visitStdExpr(arith::ExtSIOp op);
bool visitStdExpr(arith::TruncIOp op);
bool visitStdExpr(arith::IndexCastOp op);
bool visitStdExpr(arith::SelectOp op);

#define HANDLE(OPTYPE, FIRRTLTYPE) \
bool visitStdExpr(OPTYPE op) { return buildBinaryLogic<FIRRTLTYPE>(), true; }
Expand Down Expand Up @@ -1033,6 +1034,45 @@ void StdExprBuilder::buildBinaryLogic() {
rewriter.create<ConnectOp>(insertLoc, arg1Ready, argReadyOp);
}

bool StdExprBuilder::visitStdExpr(arith::SelectOp op) {
ValueVector sel = portList[0];
Value selValid = sel[0];
Value selReady = sel[1];
Value selData = sel[2];
ValueVector t = portList[1];
Value tValid = t[0];
Value tReady = t[1];
Value tData = t[2];
ValueVector f = portList[2];
Value fValid = f[0];
Value fReady = f[1];
Value fData = f[2];

llvm::SmallVector<ValueVector *> inputs = {&sel, &t, &f};

ValueVector result = portList[3];
Value resultValid = result[0];
Value resultReady = result[1];
Value resultData = result[2];

// Data mux.
auto mux = rewriter.create<MuxPrimOp>(insertLoc, selData, tData, fData);
rewriter.create<ConnectOp>(insertLoc, resultData, mux);

// Join logic on the in- and outputs.
auto valid = rewriter.create<AndPrimOp>(
insertLoc, tValid.getType(), selValid,
rewriter.create<AndPrimOp>(insertLoc, tValid, fValid));
auto ready = rewriter.create<AndPrimOp>(insertLoc, resultReady.getType(),
resultReady, valid);

rewriter.create<ConnectOp>(insertLoc, resultValid, valid);
rewriter.create<ConnectOp>(insertLoc, selReady, ready);
rewriter.create<ConnectOp>(insertLoc, tReady, ready);
rewriter.create<ConnectOp>(insertLoc, fReady, ready);
return true;
}

//===----------------------------------------------------------------------===//
// Handshake Builder class
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1061,7 +1101,6 @@ class HandshakeBuilder : public HandshakeVisitor<HandshakeBuilder, bool> {
bool visitHandshake(ExternalMemoryOp op);
bool visitHandshake(MergeOp op);
bool visitHandshake(MuxOp op);
bool visitHandshake(handshake::SelectOp op);
bool visitHandshake(SinkOp op);
bool visitHandshake(SourceOp op);
bool visitHandshake(SyncOp op);
Expand Down Expand Up @@ -1357,58 +1396,6 @@ bool HandshakeBuilder::visitHandshake(MuxOp op) {
return true;
}

bool HandshakeBuilder::visitHandshake(handshake::SelectOp op) {
ValueVector selectSubfields = portList[0];
Value selectValid = selectSubfields[0];
Value selectReady = selectSubfields[1];
Value selectData = selectSubfields[2];

ValueVector resultSubfields = portList[3];
Value resultValid = resultSubfields[0];
Value resultReady = resultSubfields[1];
Value resultData = resultSubfields[2];

ValueVector trueSubfields = portList[1];
Value trueValid = trueSubfields[0];
Value trueReady = trueSubfields[1];
Value trueData = trueSubfields[2];

ValueVector falseSubfields = portList[2];
Value falseValid = falseSubfields[0];
Value falseReady = falseSubfields[1];
Value falseData = falseSubfields[2];

auto bitType = UIntType::get(rewriter.getContext(), 1);

// Mux the true and false data.
auto muxedData =
createMuxTree({falseData, trueData}, selectData, insertLoc, rewriter);

// Connect the selected data signal to the result data.
rewriter.create<ConnectOp>(insertLoc, resultData, muxedData);

// 'and' the arg valids and select valid
Value allValid =
rewriter.create<WireOp>(insertLoc, bitType, "allValid").getResult();
buildReductionTree<AndPrimOp>({trueValid, falseValid, selectValid}, allValid);

// Connect that to the result valid.
rewriter.create<ConnectOp>(insertLoc, resultValid, allValid);

// 'and' the result valid with the result ready.
auto resultValidAndReady =
rewriter.create<AndPrimOp>(insertLoc, bitType, allValid, resultReady);

// Connect that to the 'ready' signal of all inputs. This implies that all
// inputs + select is transacted when all are valid (and the output is ready),
// but only the selected data is forwarded.
rewriter.create<ConnectOp>(insertLoc, selectReady, resultValidAndReady);
rewriter.create<ConnectOp>(insertLoc, trueReady, resultValidAndReady);
rewriter.create<ConnectOp>(insertLoc, falseReady, resultValidAndReady);

return true;
}

/// Please refer to test_merge.mlir test case.
/// Lowers the MergeOp into primitive FIRRTL ops. This is a simplification of
/// the ControlMergeOp lowering, since it doesn't need to wait for more than one
Expand Down
40 changes: 2 additions & 38 deletions lib/Conversion/HandshakeToHW/HandshakeToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1022,43 +1022,6 @@ class MuxConversionPattern : public HandshakeConversionPattern<MuxOp> {
};
};

class SelectConversionPattern
: public HandshakeConversionPattern<handshake::SelectOp> {
public:
using HandshakeConversionPattern<
handshake::SelectOp>::HandshakeConversionPattern;
void buildModule(handshake::SelectOp op, BackedgeBuilder &bb, RTLBuilder &s,
hw::HWModulePortAccessor &ports) const override {
auto unwrappedIO = unwrapIO(s, bb, ports);

// Extract select signal from the unwrapped IO.
auto select = unwrappedIO.inputs[0];
auto trueIn = unwrappedIO.inputs[1];
auto falseIn = unwrappedIO.inputs[2];
auto out = unwrappedIO.outputs[0];

// Mux the true and false data to the output.
auto muxedData = s.mux(select.data, {falseIn.data, trueIn.data});
out.data->setValue(muxedData);

// 'and' the arg valids and select valid
Value allValid = s.bAnd({select.valid, trueIn.valid, falseIn.valid});

// Connect that to the result valid.
out.valid->setValue(allValid);

// 'and' the result valid with the result ready.
auto resValidAndReady = s.bAnd({allValid, out.ready});

// Connect that to the 'ready' signal of all inputs. This implies that all
// inputs + select is transacted when all are valid (and the output is
// ready), but only the selected data is forwarded.
select.ready->setValue(resValidAndReady);
trueIn.ready->setValue(resValidAndReady);
falseIn.ready->setValue(resValidAndReady);
};
};

class ReturnConversionPattern
: public OpConversionPattern<handshake::ReturnOp> {
public:
Expand Down Expand Up @@ -1901,11 +1864,12 @@ static LogicalResult convertFuncOp(ESITypeConverter &typeConverter,
UnitRateConversionPattern<arith::ShLIOp, comb::OrOp>,
UnitRateConversionPattern<arith::ShRUIOp, comb::ShrUOp>,
UnitRateConversionPattern<arith::ShRSIOp, comb::ShrSOp>,
UnitRateConversionPattern<arith::SelectOp, comb::MuxOp>,
// HW operations.
StructCreateConversionPattern,
// Handshake operations.
ConditionalBranchConversionPattern, MuxConversionPattern,
SelectConversionPattern, PackConversionPattern, UnpackConversionPattern,
PackConversionPattern, UnpackConversionPattern,
ComparisonConversionPattern, BufferConversionPattern,
SourceConversionPattern, SinkConversionPattern, ConstantConversionPattern,
MergeConversionPattern, ControlMergeConversionPattern,
Expand Down
28 changes: 1 addition & 27 deletions lib/Conversion/StandardToHandshake/StandardToHandshake.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1715,29 +1715,6 @@ static LogicalResult lowerFuncOp(func::FuncOp funcOp, MLIRContext *ctx,
return success();
}

namespace {
struct ConvertSelectOps : public OpConversionPattern<mlir::arith::SelectOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(mlir::arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<handshake::SelectOp>(op, adaptor.getCondition(),
adaptor.getFalseValue(),
adaptor.getTrueValue());
return success();
};
};
} // namespace

LogicalResult handshake::postDataflowConvert(Operation *op) {
MLIRContext *context = op->getContext();
ConversionTarget target(*context);
target.addLegalDialect<handshake::HandshakeDialect>();
target.addIllegalOp<mlir::arith::SelectOp>();
RewritePatternSet patterns(context);
patterns.insert<ConvertSelectOps>(context);
return applyPartialConversion(op, target, std::move(patterns));
}

namespace {

Expand Down Expand Up @@ -1765,11 +1742,8 @@ struct StandardToHandshakePass

// Legalize the resulting regions, removing basic blocks and performing
// any simple conversions.
for (auto func : m.getOps<handshake::FuncOp>()) {
for (auto func : m.getOps<handshake::FuncOp>())
removeBasicBlocks(func);
if (failed(postDataflowConvert(func)))
return signalPassFailure();
}
}
};

Expand Down
48 changes: 0 additions & 48 deletions lib/Dialect/Handshake/HandshakeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,54 +832,6 @@ bool ConditionalBranchOp::isControl() {
getDataOperand());
}

ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type dataType;
SmallVector<Type> operandTypes;
llvm::SMLoc allOperandLoc = parser.getCurrentLocation();
if (parser.parseOperandList(allOperands) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(dataType))
return failure();

if (allOperands.size() != 3)
return parser.emitError(parser.getCurrentLocation(),
"Expected exactly 3 operands");

result.addTypes({dataType});
operandTypes.push_back(IntegerType::get(parser.getContext(), 1));
operandTypes.push_back(dataType);
operandTypes.push_back(dataType);
if (parser.resolveOperands(allOperands, operandTypes, allOperandLoc,
result.operands))
return failure();
return success();
}

void SelectOp::print(OpAsmPrinter &p) {
Type type = getTrueOperand().getType();
p << " " << getOperands();
p.printOptionalAttrDict((*this)->getAttrs());
p << " : " << type;
}

std::string handshake::SelectOp::getOperandName(unsigned int idx) {
switch (idx) {
case 0:
return "sel";
case 1:
return "true";
case 2:
return "false";
default:
llvm_unreachable("Expected exactly 3 operands");
}
}

bool SelectOp::isControl() {
return getTrueOperand().getType().isa<NoneType>();
}

ParseResult SinkOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand, 4> allOperands;
Type type;
Expand Down
Loading