Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add DPI call intrinsic and lowering pass #7139

Merged
merged 11 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/Dialects/FIRRTL/FIRRTLIntrinsics.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,31 @@ ifdef USE_FORMAL_ONLY_CONSTRAINTS
`endif // USE_UNR_ONLY_CONSTRAINTS
endif // USE_FORMAL_ONLY_CONSTRAINTS
```

### circt_dpi_call

Call a DPI function. `clock` is optional and if `clock` is not provided,
the callee is invoked when input values are changed.
If provided, the dpi function is called at clock's posedge. The result values behave
like registers and the DPI function is used as a state transfer function of them.

`enable` operand is used to conditionally call the DPI since DPI call could be quite
more expensive than native constructs. When `enable` is low, results of unclocked
calls are undefined and evaluated into `X`. Users are expected to gate result values
by another `enable` to model a default value of results.

For clocked calls, a low enable means that its register state transfer function is
not called. Hence their values will not be modify in that clock.

| Parameter | Type | Description |
| ------------- | ------ | -------------------------------- |
| isClocked | int | Set 1 if the dpi call is clocked |
| functionName | string | Specify the function name |


| Port | Direction | Type | Description |
| ----------------- | --------- | -------- | ------------------------------- |
| clock (optional) | input | Clock | Optional clock operand |
| enable | input | UInt<1> | Enable signal |
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

enable is not optional at Chisel level to follow the same design as verification intrinsic.

| ... | input | Signals | Arguments to DPI function call |
| result (optional) | output | Signal | Optional result of the dpi call |
19 changes: 19 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,23 @@ def HasBeenResetIntrinsicOp : FIRRTLOp<"int.has_been_reset", [Pure]> {
}


def DPICallIntrinsicOp : FIRRTLOp<"int.dpi.call",
[AttrSizedOperandSegments]> {
let summary = "Import and call DPI function";
let description = [{
The `int.dpi.call` intrinsic calls an external function.
See Sim dialect DPI call op.
}];

let arguments = (ins StrAttr:$functionName,
Optional<NonConstClockType>:$clock,
Optional<NonConstUInt1Type>:$enable,
Variadic<PassiveType>:$inputs);
let results = (outs Optional<PassiveType>:$result);
let assemblyFormat = [{
$functionName `(` $inputs `)` (`clock` $clock^)? (`enable` $enable^)?
attr-dict `:` functional-type($inputs, results)
}];
}

#endif // CIRCT_DIALECT_FIRRTL_FIRRTLINTRINSICS_TD
2 changes: 2 additions & 0 deletions include/circt/Dialect/FIRRTL/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ std::unique_ptr<mlir::Pass> createCreateCompanionAssume();

std::unique_ptr<mlir::Pass> createModuleSummaryPass();

std::unique_ptr<mlir::Pass> createLowerDPIPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "circt/Dialect/FIRRTL/Passes.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions include/circt/Dialect/FIRRTL/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,4 +908,10 @@ def ModuleSummary :
let constructor = "circt::firrtl::createModuleSummaryPass()";
}

def LowerDPI : Pass<"firrtl-lower-dpi", "firrtl::CircuitOp"> {
let summary = "Lower DPI intrinsic into Sim DPI operations";
let constructor = "circt::firrtl::createLowerDPIPass()";
let dependentDialects = ["hw::HWDialect", "seq::SeqDialect", "sim::SimDialect"];
}

#endif // CIRCT_DIALECT_FIRRTL_PASSES_TD
2 changes: 1 addition & 1 deletion include/circt/Dialect/Sim/SimOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def FatalOp : SimOp<"fatal"> {

def DPIFuncOp : SimOp<"func.dpi",
[IsolatedFromAbove, Symbol, OpAsmOpInterface,
FunctionOpInterface, HasParent<"mlir::ModuleOp">]> {
FunctionOpInterface]> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is similar change to 09df7f5. It was necessary to put DPI func op into circuit op to make the lowering composable.

let summary = "A System Verilog function";
let description = [{
`sim.func.dpi` models an external function in a core dialect.
Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,45 @@ class CirctUnclockedAssumeConverter : public IntrinsicConverter {
}
};

class CirctDPICallConverter : public IntrinsicConverter {
static bool getIsClocked(GenericIntrinsic gi) {
return !gi.getParamValue<IntegerAttr>("isClocked").getValue().isZero();
}

public:
using IntrinsicConverter::IntrinsicConverter;

bool check(GenericIntrinsic gi) override {
if (gi.hasNParam(2) || gi.namedIntParam("isClocked") ||
gi.namedParam("functionName"))
return true;
auto isClocked = getIsClocked(gi);
// If clocked, the first operand must be a clock.
if (isClocked && gi.typedInput<ClockType>(0))
return true;
// Enable must be UInt<1>.
if (gi.sizedInput<UIntType>(isClocked, 1))
return true;

return false;
}

void convert(GenericIntrinsic gi, GenericIntrinsicOpAdaptor adaptor,
PatternRewriter &rewriter) override {
auto isClocked = getIsClocked(gi);
auto functionName = gi.getParamValue<StringAttr>("functionName");
// Clock and enable are optional.
Value clock = isClocked ? adaptor.getOperands()[0] : Value();
Value enable = adaptor.getOperands()[static_cast<size_t>(isClocked)];

auto inputs =
adaptor.getOperands().drop_front(static_cast<size_t>(isClocked) + 1);

rewriter.replaceOpWithNewOp<DPICallIntrinsicOp>(
gi.op, gi.op.getResultTypes(), functionName, clock, enable, inputs);
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -704,4 +743,5 @@ void FIRRTLIntrinsicLoweringDialectInterface::populateIntrinsicLowerings(
lowering.add<CirctCoverConverter>("circt.chisel_cover", "circt_chisel_cover");
lowering.add<CirctUnclockedAssumeConverter>("circt.unclocked_assume",
"circt_unclocked_assume");
lowering.add<CirctDPICallConverter>("circt.dpi_call", "circt_dpi_call");
}
3 changes: 3 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
LowerAnnotations.cpp
LowerCHIRRTL.cpp
LowerClasses.cpp
LowerDPI.cpp
LowerIntmodules.cpp
LowerIntrinsics.cpp
LowerLayers.cpp
Expand Down Expand Up @@ -69,6 +70,8 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
CIRCTEmit
CIRCTHW
CIRCTOM
CIRCTSeq
CIRCTSim
CIRCTSV
CIRCTSupport
MLIRIR
Expand Down
158 changes: 158 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/LowerDPI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//===- LowerDPI.cpp - Lower to DPI to Sim dialects ------------------------===//
uenoku marked this conversation as resolved.
Show resolved Hide resolved
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the LowerDPI pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetails.h"
#include "circt/Dialect/FIRRTL/FIRRTLDialect.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Namespace.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Threading.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/MapVector.h"

using namespace mlir;
using namespace llvm;
using namespace circt;
using namespace circt::firrtl;

struct LowerDPIPass : public LowerDPIBase<LowerDPIPass> {
void runOnOperation() override;
};

void LowerDPIPass::runOnOperation() {
auto circuitOp = getOperation();

CircuitNamespace nameSpace(circuitOp);
MapVector<StringAttr, SmallVector<DPICallIntrinsicOp>> funcNameToCallSites;
{
// A helper struct to collect DPI calls in the circuit.
struct DpiCallCollections {
FModuleOp module;
SmallVector<DPICallIntrinsicOp> dpiOps;
};

SmallVector<DpiCallCollections, 0> collections;
collections.reserve(64);

for (auto module : circuitOp.getOps<FModuleOp>())
collections.push_back(DpiCallCollections{module, {}});

parallelForEach(&getContext(), collections, [](auto &result) {
result.module.walk(
[&](DPICallIntrinsicOp dpi) { result.dpiOps.push_back(dpi); });
});

for (auto &collection : collections)
for (auto dpi : collection.dpiOps)
funcNameToCallSites[dpi.getFunctionNameAttr()].push_back(dpi);
}

for (auto [name, calls] : funcNameToCallSites) {
auto firstDPICallop = calls.front();
// Construct DPI func op.
auto inputTypes = firstDPICallop.getInputs().getTypes();
auto outputTypes = firstDPICallop.getResultTypes();
SmallVector<hw::ModulePort> ports;
ImplicitLocOpBuilder builder(firstDPICallop.getLoc(),
circuitOp.getOperation());
ports.reserve(inputTypes.size() + outputTypes.size());

// Add input arguments.
for (auto [idx, inType] : llvm::enumerate(inputTypes)) {
hw::ModulePort port;
port.dir = hw::ModulePort::Direction::Input;
port.name = builder.getStringAttr(Twine("in_") + Twine(idx));
port.type = lowerType(inType);
ports.push_back(port);
}

// Add output arguments.
for (auto [idx, outType] : llvm::enumerate(outputTypes)) {
hw::ModulePort port;
port.dir = hw::ModulePort::Direction::Output;
port.name = builder.getStringAttr(Twine("out_") + Twine(idx));
port.type = lowerType(outType);
ports.push_back(port);
}

auto modType = hw::ModuleType::get(&getContext(), ports);
auto funcSymbol =
nameSpace.newName(firstDPICallop.getFunctionNameAttr().getValue());
builder.setInsertionPointToStart(circuitOp.getBodyBlock());
auto sim = builder.create<sim::DPIFuncOp>(
funcSymbol, modType, ArrayAttr(), ArrayAttr(),
firstDPICallop.getFunctionNameAttr());
sim.setPrivate();

auto lowerCall = [&builder, funcSymbol](DPICallIntrinsicOp dpiOp) {
auto getLowered = [&](Value value) -> Value {
// Insert an unrealized conversion to cast FIRRTL type to HW type.
if (!value)
return value;
auto type = lowerType(value.getType());
return builder.create<mlir::UnrealizedConversionCastOp>(type, value)
->getResult(0);
};
builder.setInsertionPoint(dpiOp);
auto clock = getLowered(dpiOp.getClock());
auto enable = getLowered(dpiOp.getEnable());
SmallVector<Value, 4> inputs;
inputs.reserve(dpiOp.getInputs().size());
for (auto input : dpiOp.getInputs())
inputs.push_back(getLowered(input));

SmallVector<Type> outputTypes;
if (dpiOp.getResult())
outputTypes.push_back(lowerType(dpiOp.getResult().getType()));

auto call = builder.create<sim::DPICallOp>(outputTypes, funcSymbol, clock,
enable, inputs);
if (!call.getResults().empty()) {
// Insert unrealized conversion cast HW type to FIRRTL type.
auto result = builder
.create<mlir::UnrealizedConversionCastOp>(
dpiOp.getResult().getType(), call.getResult(0))
->getResult(0);
dpiOp.getResult().replaceAllUsesWith(result);
}
dpiOp.erase();
};

lowerCall(firstDPICallop);
for (auto dpiOp : llvm::ArrayRef(calls).drop_front()) {
// Check that all DPI declaration match.
// TODO: This should be implemented as a verifier once function is added
// to FIRRTL.
if (dpiOp.getInputs().getTypes() != inputTypes) {
auto diag = firstDPICallop.emitOpError()
<< "DPI function " << firstDPICallop.getFunctionNameAttr()
<< " input types don't match ";
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
return signalPassFailure();
}
if (dpiOp.getResultTypes() != outputTypes) {
auto diag = firstDPICallop.emitOpError()
<< "DPI function " << firstDPICallop.getFunctionNameAttr()
<< " output types don't match";
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
return signalPassFailure();
}
lowerCall(dpiOp);
}
}
}

std::unique_ptr<mlir::Pass> circt::firrtl::createLowerDPIPass() {
return std::make_unique<LowerDPIPass>();
}
8 changes: 8 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/PassDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ namespace emit {
class EmitDialect;
} // namespace emit

namespace seq {
class SeqDialect;
} // namespace seq

namespace sim {
class SimDialect;
} // namespace sim

namespace sv {
class SVDialect;
} // namespace sv
Expand Down
1 change: 1 addition & 0 deletions lib/Firtool/Firtool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ LogicalResult firtool::populateLowFIRRTLToHW(mlir::PassManager &pm,
// RefType ports and ops.
pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerXMRPass());

pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerDPIPass());
pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerClassesPass());
pm.nest<firrtl::CircuitOp>().addPass(om::createVerifyObjectFieldsPass());

Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/FIRRTL/lower-dpi-error.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: circt-opt -firrtl-lower-dpi %s -verify-diagnostics --split-input-file

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
firrtl.module @DPI(in %in_0: !firrtl.uint<8>, in %in_1: !firrtl.uint<16>) attributes {convention = #firrtl<convention scalarized>} {
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" input types don't match}}
firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> ()
// expected-note @below {{mismatched caller is here}}
firrtl.int.dpi.call "foo"(%in_1) : (!firrtl.uint<16>) -> ()
}
}

// -----

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
firrtl.module @DPI(in %in_0: !firrtl.uint<8>) attributes {convention = #firrtl<convention scalarized>} {
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" output types don't match}}
%0 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<16>)
// expected-note @below {{mismatched caller is here}}
%1 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<8>)
}
}

34 changes: 34 additions & 0 deletions test/Dialect/FIRRTL/lower-dpi.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: circt-opt -firrtl-lower-dpi %s | FileCheck %s

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
// CHECK-NEXT: sim.func.dpi private @unclocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "unclocked_result"}
// CHECK-NEXT: sim.func.dpi private @clocked_void(in %in_0 : i2, in %in_1 : i2) attributes {verilogName = "clocked_void"}
// CHECK-NEXT: sim.func.dpi private @clocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "clocked_result"}
// CHECK-LABEL: firrtl.module @DPI
firrtl.module @DPI(in %clock: !firrtl.clock, in %enable: !firrtl.uint<1>, in %in_0: !firrtl.uint<2>, in %in_1: !firrtl.uint<2>, out %out_0: !firrtl.uint<2>, out %out_1: !firrtl.uint<2>) attributes {convention = #firrtl<convention scalarized>} {
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %2 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %3 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: %4 = sim.func.dpi.call @clocked_result(%2, %3) clock %0 enable %1 : (i2, i2) -> i2
// CHECK-NEXT: %5 = builtin.unrealized_conversion_cast %4 : i2 to !firrtl.uint<2>
// CHECK-NEXT: %6 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock
// CHECK-NEXT: %7 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %8 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: sim.func.dpi.call @clocked_void(%8, %9) clock %6 enable %7 : (i2, i2) -> ()
// CHECK-NEXT: %10 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %11 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %12 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: %13 = sim.func.dpi.call @unclocked_result(%11, %12) enable %10 : (i2, i2) -> i2
// CHECK-NEXT: %14 = builtin.unrealized_conversion_cast %13 : i2 to !firrtl.uint<2>
// CHECK-NEXT:firrtl.matchingconnect %out_0, %5 : !firrtl.uint<2>
// CHECK-NEXT:firrtl.matchingconnect %out_1, %14 : !firrtl.uint<2>
%0 = firrtl.int.dpi.call "clocked_result"(%in_0, %in_1) clock %clock enable %enable {name = "result1"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
firrtl.int.dpi.call "clocked_void"(%in_0, %in_1) clock %clock enable %enable : (!firrtl.uint<2>, !firrtl.uint<2>) -> ()
%1 = firrtl.int.dpi.call "unclocked_result"(%in_0, %in_1) enable %enable {name = "result2"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
firrtl.matchingconnect %out_0, %0 : !firrtl.uint<2>
firrtl.matchingconnect %out_1, %1 : !firrtl.uint<2>
}
}
Loading
Loading