Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL][LOA] Handle inner symbols on open aggs hw components. #5709

Merged
merged 3 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 110 additions & 45 deletions lib/Dialect/FIRRTL/Transforms/LowerOpenAggs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct PortMappingInfo {
/// HW-only projection is empty, and not leaf.
SmallVector<uint64_t, 0> mapToNullInteriors;

hw::InnerSymAttr newSym = {};

/// Determine number of types this argument maps to.
size_t count(bool includeErased = false) const {
if (identity)
Expand Down Expand Up @@ -124,7 +126,12 @@ void PortMappingInfo::print(llvm::raw_ostream &os) const {
llvm::interleaveComma(fields, os);
os << ">, mappedToNull: <";
llvm::interleaveComma(mapToNullInteriors, os);
os << ">]]";
os << ">, sym: ";
if (newSym)
os << newSym;
else
os << "()";
os << " ]]";
}

template <typename Range>
Expand Down Expand Up @@ -176,10 +183,11 @@ class Visitor : public FIRRTLVisitor<Visitor, LogicalResult> {
LogicalResult visitInvalidOp(Operation *op) { return visitUnhandledOp(op); }

private:
/// Convert a type to its HW-only projection.,
/// Convert a type to its HW-only projection, adjusting symbols.
/// Gather non-hw elements encountered and their names / positions.
/// Returns a PortMappingInfo with its findings.
PortMappingInfo mapPortType(Type type);
FailureOr<PortMappingInfo> mapPortType(Type type, Location errorLoc,
hw::InnerSymAttr sym = {});

MLIRContext *context;

Expand All @@ -199,8 +207,13 @@ class Visitor : public FIRRTLVisitor<Visitor, LogicalResult> {
LogicalResult Visitor::visit(FModuleLike mod) {
auto ports = mod.getPorts();

SmallVector<PortMappingInfo, 16> portMappings(
llvm::map_range(ports, [&](auto &p) { return mapPortType(p.type); }));
SmallVector<PortMappingInfo, 16> portMappings;
for (auto &port : ports) {
auto pmi = mapPortType(port.type, port.loc, port.sym);
if (failed(pmi))
return failure();
portMappings.push_back(*pmi);
}

/// Total number of types mapped to.
/// Include erased ports.
Expand Down Expand Up @@ -239,22 +252,18 @@ LogicalResult Visitor::visit(FModuleLike mod) {
if (pmi.hwType) {
auto newPort = port;
newPort.type = pmi.hwType;
newPort.sym = pmi.newSym;
newPorts.emplace_back(idxOfInsertPoint, newPort);

if (port.sym && llvm::any_of(port.sym, [&](auto &prop) {
return prop.getFieldID() != 0;
}))
return mlir::emitError(port.loc)
<< "symbols on fields of open aggregates not handled yet";
assert(!port.sym ||
(pmi.newSym && port.sym.size() == pmi.newSym.size()));

// If want to run this pass later, need to fixup annotations.
if (!port.annotations.empty())
return mlir::emitError(port.loc)
<< "annotations on open aggregates not handled yet";
} else {
if (port.sym)
return mlir::emitError(port.loc)
<< "symbol found on aggregate with no HW";
assert(!port.sym && !pmi.newSym);
if (!port.annotations.empty())
return mlir::emitError(port.loc)
<< "annotations found on aggregate with no HW";
Expand Down Expand Up @@ -448,8 +457,14 @@ LogicalResult Visitor::visitExpr(OpenSubindexOp op) {
LogicalResult Visitor::visitDecl(InstanceOp op) {
// Rewrite ports same strategy as for modules.

SmallVector<PortMappingInfo, 16> portMappings(llvm::map_range(
op.getResultTypes(), [&](auto type) { return mapPortType(type); }));
SmallVector<PortMappingInfo, 16> portMappings;

for (auto type : op.getResultTypes()) {
auto pmi = mapPortType(type, op.getLoc());
if (failed(pmi))
return failure();
portMappings.push_back(*pmi);
}

/// Total number of types mapped to.
size_t countWithErased = 0;
Expand Down Expand Up @@ -570,7 +585,8 @@ LogicalResult Visitor::visitDecl(InstanceOp op) {
// Type Conversion
//===----------------------------------------------------------------------===//

PortMappingInfo Visitor::mapPortType(Type type) {
FailureOr<PortMappingInfo> Visitor::mapPortType(Type type, Location errorLoc,
hw::InnerSymAttr sym) {
PortMappingInfo pi{false, {}, {}, {}};
auto ftype = type_dyn_cast<FIRRTLType>(type);
// Ports that aren't open aggregates are left alone.
Expand All @@ -579,22 +595,34 @@ PortMappingInfo Visitor::mapPortType(Type type) {
return pi;
}

SmallVector<hw::InnerSymPropertiesAttr> newProps;

// NOLINTBEGIN(misc-no-recursion)
auto recurse = [&](auto &&f, FIRRTLType type, const Twine &suffix = "",
bool flip = false,
uint64_t fieldID = 0) -> FIRRTLBaseType {
return TypeSwitch<FIRRTLType, FIRRTLBaseType>(type)
.Case<FIRRTLBaseType>([](auto base) { return base; })
.template Case<OpenBundleType>(
[&](OpenBundleType obTy) -> FIRRTLBaseType {
bool flip = false, uint64_t fieldID = 0,
uint64_t newFieldID = 0) -> FailureOr<FIRRTLBaseType> {
auto newType =
TypeSwitch<FIRRTLType, FailureOr<FIRRTLBaseType>>(type)
.Case<FIRRTLBaseType>([](auto base) { return base; })
.template Case<OpenBundleType>([&](OpenBundleType obTy)
-> FailureOr<FIRRTLBaseType> {
SmallVector<BundleType::BundleElement> hwElements;
uint64_t id = 0;
for (const auto &[index, element] :
llvm::enumerate(obTy.getElements()))
if (auto base =
f(f, element.type, suffix + "_" + element.name.strref(),
flip ^ element.isFlip,
fieldID + obTy.getFieldID(index)))
hwElements.emplace_back(element.name, element.isFlip, base);
llvm::enumerate(obTy.getElements())) {
auto base =
f(f, element.type, suffix + "_" + element.name.strref(),
flip ^ element.isFlip, fieldID + obTy.getFieldID(index),
newFieldID + id + 1);
if (failed(base))
return failure();
if (*base) {
hwElements.emplace_back(element.name, element.isFlip, *base);
id += type_cast<hw::FieldIDTypeInterface>(*base)
.getMaxFieldID() +
1;
}
}

if (hwElements.empty()) {
pi.mapToNullInteriors.push_back(fieldID);
Expand All @@ -603,18 +631,25 @@ PortMappingInfo Visitor::mapPortType(Type type) {

return BundleType::get(context, hwElements, obTy.isConst());
})
.template Case<OpenVectorType>(
[&](OpenVectorType ovTy) -> FIRRTLBaseType {
.template Case<OpenVectorType>([&](OpenVectorType ovTy)
-> FailureOr<FIRRTLBaseType> {
uint64_t id = 0;
FIRRTLBaseType convert;
// Walk for each index to extract each leaf separately, but expect
// same hw-only type for all.
for (auto idx : llvm::seq<size_t>(0U, ovTy.getNumElements())) {
auto hwElementType =
f(f, ovTy.getElementType(), suffix + "_" + Twine(idx), flip,
fieldID + ovTy.getFieldID(idx));
assert((!convert || convert == hwElementType) &&
fieldID + ovTy.getFieldID(idx), newFieldID + id + 1);
if (failed(hwElementType))
return failure();
assert((!convert || convert == *hwElementType) &&
"expected same hw type for all elements");
convert = hwElementType;
convert = *hwElementType;
if (convert)
id += type_cast<hw::FieldIDTypeInterface>(convert)
dtzSiFive marked this conversation as resolved.
Show resolved Hide resolved
.getMaxFieldID() +
1;
}

if (!convert) {
Expand All @@ -625,23 +660,53 @@ PortMappingInfo Visitor::mapPortType(Type type) {
return FVectorType::get(convert, ovTy.getNumElements(),
ovTy.isConst());
})
.template Case<RefType>([&](auto ref) {
// Do this better, don't re-serialize so much?
auto f = NonHWField{ref, fieldID, flip, {}};
suffix.toVector(f.suffix);
pi.fields.emplace_back(std::move(f));
return FIRRTLBaseType{};
})
.Default([&](auto _) {
pi.mapToNullInteriors.push_back(fieldID);
return FIRRTLBaseType{};
});
.template Case<RefType>([&](auto ref) {
// Do this better, don't re-serialize so much?
auto f = NonHWField{ref, fieldID, flip, {}};
suffix.toVector(f.suffix);
pi.fields.emplace_back(std::move(f));
return FIRRTLBaseType{};
})
.Default([&](auto _) {
pi.mapToNullInteriors.push_back(fieldID);
return FIRRTLBaseType{};
});
if (failed(newType))
return failure();

// If there's a symbol on this, add it with adjusted fieldID.
if (sym)
if (auto symOnThis = sym.getSymIfExists(fieldID)) {
if (!*newType)
return mlir::emitError(errorLoc, "inner symbol ")
<< symOnThis << " mapped to non-HW type";
newProps.push_back(hw::InnerSymPropertiesAttr::get(
context, symOnThis, newFieldID,
StringAttr::get(context, "public")));
}
return newType;
};

pi.hwType = recurse(recurse, ftype);
auto hwType = recurse(recurse, ftype);
if (failed(hwType))
return failure();
pi.hwType = *hwType;

assert(pi.hwType != type);
// NOLINTEND(misc-no-recursion)

if (sym) {
assert(sym.size() == newProps.size());

if (!pi.hwType && !newProps.empty())
return mlir::emitError(errorLoc, "inner symbol on non-HW type");

llvm::sort(newProps, [](auto &p, auto &q) {
return p.getFieldID() < q.getFieldID();
});
pi.newSym = hw::InnerSymAttr::get(context, newProps);
}

return pi;
}

Expand Down
6 changes: 3 additions & 3 deletions test/Dialect/FIRRTL/lower-open-aggs-errors.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: circt-opt --pass-pipeline="builtin.module(firrtl.circuit(firrtl-lower-open-aggs))" %s --split-input-file --verify-diagnostics

firrtl.circuit "Symbol" {
// expected-error @below {{symbol found on aggregate with no HW}}
// expected-error @below {{inner symbol "bad" mapped to non-HW type}}
firrtl.module @Symbol(out %r : !firrtl.openbundle<p: probe<uint<1>>> sym @bad) {
%zero = firrtl.constant 0 : !firrtl.uint<1>
%ref = firrtl.ref.send %zero : !firrtl.uint<1>
Expand All @@ -13,8 +13,8 @@ firrtl.circuit "Symbol" {
// -----

firrtl.circuit "SymbolOnField" {
// expected-error @below {{symbols on fields of open aggregates not handled yet}}
firrtl.extmodule @SymbolOnField(out r : !firrtl.openbundle<p: probe<uint<1>>, x: uint<1>> sym [<@bad,2,public>])
// expected-error @below {{inner symbol "bad" mapped to non-HW type}}
firrtl.extmodule @SymbolOnField(out r : !firrtl.openbundle<p: probe<uint<1>>, x: uint<1>> sym [<@bad,1,public>])
}

// -----
Expand Down
10 changes: 10 additions & 0 deletions test/Dialect/FIRRTL/lower-open-aggs.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,13 @@ firrtl.circuit "RefsOnlyAggFirstLevel" {
firrtl.ref.define %0, %3 : !firrtl.probe<uint<1>>
}
}

// -----

// CHECK-LABEL: circuit "SymbolOnField"
firrtl.circuit "SymbolOnField" {
// CHECK: @SymbolOnField
// CHECK-SAME: (out r: !firrtl.bundle<x: uint<1>> sym [<@sym,1,public>],
// CHECK-SAME: out r_p: !firrtl.probe<uint<1>>)
firrtl.extmodule @SymbolOnField(out r : !firrtl.openbundle<p: probe<uint<1>>, x: uint<1>> sym [<@sym,2,public>])
}
8 changes: 7 additions & 1 deletion test/firtool/refs-in-aggs.fir
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ circuit Bundle :
module Probe :
input in : {a : UInt<1>, b : UInt<1>[2]}
output r : {a : Probe<{a : UInt<1>, b : UInt<1>[2]}>, b : Probe<{a : UInt<1>, b : UInt<1>[2]}>} ; bundle of probes of bundles (of UInt, vec)
output mixed : {a : UInt<1>, flip x : {flip p: Probe<{a : UInt<1>, b : UInt<1>[2]}>, flip data: UInt<1>}[2], b : UInt<1>[2]} ; mixed
output mixed : {a : UInt<1>,
flip x : {flip p: Probe<{a : UInt<1>,
b : UInt<1>[2]
}>,
flip data: UInt<1>
}[2],
b : UInt<1>[2]} ; mixed
output nohw : {x : {p: Probe<{a : UInt<1>, b : UInt<1>[2]}>}[2]} ; non-hw-only

inst c1 of Child
Expand Down