Skip to content

Commit

Permalink
[Ibis] Call preparation pass -- part 1 (#5683)
Browse files Browse the repository at this point in the history
Lowering the Ibis control flow starts with converting all calls to the DC dialect's values with one struct representing the packed arguments.

In the interest of incremental development, this only implements narrowing the arguments to one struct. Part 2 will then wrap them in DC values.
  • Loading branch information
teqdruid committed Jul 27, 2023
1 parent b49e82a commit d4b591e
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 5 deletions.
9 changes: 4 additions & 5 deletions include/circt/Dialect/Ibis/IbisOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def MethodOp : IbisOp<"method", [
SingleBlockImplicitTerminator<"ReturnOp">,
HasParent<"ClassOp">]> {

let summary = "Ibis function";
let summary = "Ibis method";
let description = [{
Ibis functions are a lot like software functions: a list of named arguments
and one unnamed return value.
Expand Down Expand Up @@ -112,17 +112,16 @@ def ReturnOp : IbisOp<"return", [
}

def CallOp : IbisOp<"call", [CallOpInterface]> {

let summary = "Ibis function call";
let summary = "Ibis method call";
let description = [{
Dispatch a call to an Ibis function.
Dispatch a call to an Ibis method.
}];

let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);

let extraClassDeclaration = [{
/// Get the argument operands to the called function.
/// Get the argument operands to the called method.
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/Ibis/IbisPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
namespace circt {
namespace ibis {

std::unique_ptr<Pass> createCallPrepPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "circt/Dialect/Ibis/IbisPasses.h.inc"
Expand Down
8 changes: 8 additions & 0 deletions include/circt/Dialect/Ibis/IbisPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,12 @@

include "mlir/Pass/PassBase.td"

def IbisCallPrep : Pass<"ibis-call-prep", "ModuleOp"> {
let summary = "Convert ibis method calls to use `dc.value`";

let constructor = "circt::ibis::createCallPrepPass()";
let dependentDialects = [
"circt::hw::HWDialect", "circt::dc::DCDialect"];
}

#endif // CIRCT_DIALECT_IBIS_PASSES_TD
2 changes: 2 additions & 0 deletions include/circt/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "circt/Dialect/FSM/FSMPasses.h"
#include "circt/Dialect/HW/HWPasses.h"
#include "circt/Dialect/Handshake/HandshakePasses.h"
#include "circt/Dialect/Ibis/IbisPasses.h"
#include "circt/Dialect/LLHD/Transforms/Passes.h"
#include "circt/Dialect/MSFT/MSFTPasses.h"
#include "circt/Dialect/Pipeline/PipelinePasses.h"
Expand Down Expand Up @@ -56,6 +57,7 @@ inline void registerAllPasses() {
seq::registerPasses();
sv::registerPasses();
handshake::registerPasses();
ibis::registerPasses();
hw::registerPasses();
pipeline::registerPasses();
ssp::registerPasses();
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Ibis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ add_circt_dialect_library(CIRCTIbis
MLIRIR
CIRCTHW
)

add_subdirectory(Transforms)
15 changes: 15 additions & 0 deletions lib/Dialect/Ibis/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
add_circt_dialect_library(CIRCTIbisTransforms
IbisCallPrep.cpp

DEPENDS
CIRCTIbisTransformsIncGen

LINK_LIBS PUBLIC
CIRCTDC
CIRCTIbis
CIRCTHW
CIRCTSupport
MLIRIR
MLIRPass
MLIRTransformUtils
)
270 changes: 270 additions & 0 deletions lib/Dialect/Ibis/Transforms/IbisCallPrep.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
//===- IbisCallPrep.cpp - Implementation of call prep lowering ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetails.h"

#include "circt/Dialect/Ibis/IbisDialect.h"
#include "circt/Dialect/Ibis/IbisOps.h"
#include "circt/Dialect/Ibis/IbisPasses.h"
#include "circt/Dialect/Ibis/IbisTypes.h"

#include "circt/Dialect/HW/HWTypes.h"
#include "circt/Support/BackedgeBuilder.h"
#include "circt/Support/ConversionPatterns.h"

#include "mlir/Transforms/DialectConversion.h"

using namespace circt;
using namespace ibis;

/// Build indexes to make lookups faster. Create the new argument types as well.
struct CallPrepPrecomputed {
CallPrepPrecomputed(ModuleOp mod);

// Lookup a class from its symbol.
DenseMap<StringAttr, ClassOp> classSymbols;

// Mapping of method to argument type.
DenseMap<SymbolRefAttr, std::pair<hw::StructType, Location>> argTypes;

// Lookup the class to which a particular instance (in a particular class) is
// referring.
DenseMap<std::pair<ClassOp, StringAttr>, ClassOp> instanceMap;

// Lookup an entry in instanceMap. If not found, return null.
ClassOp lookupNext(ClassOp scope, StringAttr instSym) const {
auto entry = instanceMap.find(std::make_pair(scope, instSym));
if (entry == instanceMap.end())
return {};
return entry->second;
}

// Given an instance path, get the class::func symbolref for it.
SymbolRefAttr resolveInstancePath(Operation *scope, SymbolRefAttr path) const;

// Utility function to create a symbolref to a method.
static SymbolRefAttr getSymbol(MethodOp method) {
ClassOp cls = method.getParentOp();
return SymbolRefAttr::get(cls.getSymNameAttr(),
{FlatSymbolRefAttr::get(method)});
}
};

CallPrepPrecomputed::CallPrepPrecomputed(ModuleOp mod) {
auto *ctxt = mod.getContext();

// Populate the class-symbol lookup table.
for (auto cls : mod.getOps<ClassOp>())
classSymbols[cls.getSymNameAttr()] = cls;

for (auto cls : mod.getOps<ClassOp>()) {
// Compute new argument types for each method.
for (auto method : cls.getOps<MethodOp>()) {

// Create the struct type.
SmallVector<hw::StructType::FieldInfo> argFields;
for (auto [argName, argType] :
llvm::zip(method.getArgNamesAttr().getAsRange<StringAttr>(),
method.getArgumentTypes()))
argFields.push_back({argName, argType});
auto argStruct = hw::StructType::get(ctxt, argFields);

// Later we're gonna want the block locations, so compute a fused location
// and store it.
Location argLoc = UnknownLoc::get(ctxt);
if (method->getNumRegions() > 0) {
SmallVector<Location> argLocs;
Block *body = &method.getBody().front();
for (auto arg : body->getArguments())
argLocs.push_back(arg.getLoc());
argLoc = FusedLoc::get(ctxt, argLocs);
}

// Add both to the lookup table.
argTypes.insert(
std::make_pair(getSymbol(method), std::make_pair(argStruct, argLoc)));
}

// Populate the instances table.
for (auto inst : cls.getOps<InstanceOp>()) {
auto clsEntry = classSymbols.find(inst.getClassNameAttr().getAttr());
assert(clsEntry != classSymbols.end() &&
"class being instantiated doesn't exist");
instanceMap[std::make_pair(cls, inst.getSymNameAttr())] =
clsEntry->second;
}
}
}

SymbolRefAttr
CallPrepPrecomputed::resolveInstancePath(Operation *scope,
SymbolRefAttr path) const {
auto cls = scope->getParentOfType<ClassOp>();
assert(cls && "scope outside of ibis class");

// SymbolRefAttr is rather silly. The start of the path is root reference...
cls = lookupNext(cls, path.getRootReference());
if (!cls)
return {};

// ... then the rest are the nested references. The last one is the function
// name rather than an instance.
for (auto instSym : path.getNestedReferences().drop_back()) {
cls = lookupNext(cls, instSym.getAttr());
if (!cls)
return {};
}

// The last one is the function symbol.
return SymbolRefAttr::get(cls.getSymNameAttr(),
{FlatSymbolRefAttr::get(path.getLeafReference())});
}

namespace {
/// For each CallOp, the corresponding method signature will have changed. Pack
/// all the operands into a struct.
struct MergeCallArgs : public OpConversionPattern<CallOp> {
MergeCallArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
: OpConversionPattern(ctxt), info(info) {}

void rewrite(CallOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
LogicalResult match(CallOp) const override { return success(); }

private:
const CallPrepPrecomputed &info;
};
} // anonymous namespace

void MergeCallArgs::rewrite(CallOp call, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = call.getLoc();
rewriter.setInsertionPoint(call);
auto func = call->getParentOfType<mlir::FunctionOpInterface>();

// Use the 'info' accelerator structures to find the argument type.
SymbolRefAttr calleeSym =
info.resolveInstancePath(func, adaptor.getCalleeAttr());
auto argStructEntry = info.argTypes.find(calleeSym);
assert(argStructEntry != info.argTypes.end() && "Method symref not found!");
auto [argStruct, argLoc] = argStructEntry->second;

// Pack all of the operands into it.
auto newArg = rewriter.create<hw::StructCreateOp>(loc, argStruct,
adaptor.getOperands());
newArg->setAttr("sv.namehint",
rewriter.getStringAttr(
call.getCalleeAttr().getLeafReference().getValue() +
"_args_called_from_" + func.getName()));

// Update the call to use just the new struct.
rewriter.updateRootInPlace(call, [&]() {
call.getOperandsMutable().clear();
call.getOperandsMutable().append(newArg.getResult());
});
}

namespace {
/// Change the method signatures to only have one argument: a struct capturing
/// all of the original arguments.
struct MergeMethodArgs : public OpConversionPattern<MethodOp> {
MergeMethodArgs(MLIRContext *ctxt, const CallPrepPrecomputed &info)
: OpConversionPattern(ctxt), info(info) {}

void rewrite(MethodOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final;
LogicalResult match(MethodOp) const override { return success(); }

private:
const CallPrepPrecomputed &info;
};
} // anonymous namespace

void MergeMethodArgs::rewrite(MethodOp func, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = func.getLoc();
auto *ctxt = getContext();

// Find the pre-computed arg struct for this method.
auto argStructEntry =
info.argTypes.find(CallPrepPrecomputed::getSymbol(func));
assert(argStructEntry != info.argTypes.end() && "Cannot find symref!");
auto [argStruct, argLoc] = argStructEntry->second;

// Create a new method with the new signature.
FunctionType funcType = func.getFunctionType();
FunctionType newFuncType =
FunctionType::get(ctxt, {argStruct}, funcType.getResults());
auto newArgNames = ArrayAttr::get(ctxt, {StringAttr::get(ctxt, "arg")});
auto newMethod =
rewriter.create<MethodOp>(loc, func.getSymNameAttr(), newFuncType,
newArgNames, ArrayAttr(), ArrayAttr());

if (func->getNumRegions() > 0) {
// Create a body block with a struct explode to the arg struct into the
// original arguments.
Block *b = rewriter.createBlock(&newMethod.getBodyRegion(), {}, {argStruct},
{argLoc});
rewriter.setInsertionPointToStart(b);
auto replacementArgs =
rewriter.create<hw::StructExplodeOp>(loc, b->getArgument(0));

// Merge the original method body, rewiring the args.
Block *funcBody = &func.getBody().front();
rewriter.mergeBlocks(funcBody, b, replacementArgs.getResults());
}

rewriter.eraseOp(func);
}

namespace {
/// Run all the physical lowerings.
struct CallPrepPass : public IbisCallPrepBase<CallPrepPass> {
void runOnOperation() override;

private:
// Merge the arguments into one struct.
LogicalResult merge(const CallPrepPrecomputed &);
};
} // anonymous namespace

void CallPrepPass::runOnOperation() {
CallPrepPrecomputed info(getOperation());

if (failed(merge(info))) {
signalPassFailure();
return;
}
}

LogicalResult CallPrepPass::merge(const CallPrepPrecomputed &info) {
// Set up a conversion and give it a set of laws.
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
target.addDynamicallyLegalOp<CallOp>([](CallOp call) {
auto argValues = call.getArgOperands();
return argValues.size() == 1 &&
hw::type_isa<hw::StructType>(argValues.front().getType());
});
target.addDynamicallyLegalOp<MethodOp>([](MethodOp func) {
ArrayRef<Type> argTypes = func.getFunctionType().getInputs();
return argTypes.size() == 1 &&
hw::type_isa<hw::StructType>(argTypes.front());
});

// Add patterns to merge the args on both the call and method sides.
RewritePatternSet patterns(&getContext());
patterns.insert<MergeCallArgs>(&getContext(), info);
patterns.insert<MergeMethodArgs>(&getContext(), info);

return applyPartialConversion(getOperation(), target, std::move(patterns));
}

std::unique_ptr<Pass> circt::ibis::createCallPrepPass() {
return std::make_unique<CallPrepPass>();
}
31 changes: 31 additions & 0 deletions lib/Dialect/Ibis/Transforms/PassDetails.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- PassDetails.h --------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// clang-tidy seems to expect the absolute path in the header guard on some
// systems, so just disable it.
// NOLINTNEXTLINE(llvm-header-guard)
#ifndef DIALECT_IBIS_TRANSFORMS_PASSDETAILS_H
#define DIALECT_IBIS_TRANSFORMS_PASSDETAILS_H

#include "circt/Dialect/DC/DCDialect.h"
#include "circt/Dialect/ESI/ESIDialect.h"
#include "circt/Dialect/Ibis/IbisOps.h"
#include "circt/Dialect/SV/SVDialect.h"

#include "mlir/Pass/Pass.h"

namespace circt {
namespace sandpiper {

#define GEN_PASS_CLASSES
#include "circt/Dialect/Ibis/Ibis.h.inc"

} // namespace sandpiper
} // namespace circt

#endif // DIALECT_IBIS_TRANSFORMS_PASSDETAILS_H
Loading

0 comments on commit d4b591e

Please sign in to comment.