Skip to content

Commit

Permalink
[mlir][spirv] Use SingleBlock + NoTerminator for spv.module
Browse files Browse the repository at this point in the history
This allows us to remove the `spv.mlir.endmodule` op and
all the code associated with it.

Along the way, tightened the APIs for `spv.module` a bit
by removing some aliases. Now we use `getRegion` to get
the only region, and `getBody` to get the region's only
block.

Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D103265
  • Loading branch information
antiagainst committed Jun 9, 2021
1 parent 64b2fb7 commit 56f60a1
Show file tree
Hide file tree
Showing 15 changed files with 45 additions and 141 deletions.
4 changes: 2 additions & 2 deletions mlir/docs/Dialects/SPIR-V.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ The SPIR-V dialect adopts the following conventions for IR:
(de)serialization.
* Ops with `mlir.snake_case` names are those that have no corresponding
instructions (or concepts) in the binary format. They are introduced to
satisfy MLIR structural requirements. For example, `spv.mlir.endmodule` and
`spv.mlir.merge`. They map to no instructions during (de)serialization.
satisfy MLIR structural requirements. For example, `spv.mlir.merge`. They
map to no instructions during (de)serialization.

(TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for
them.)
Expand Down
2 changes: 0 additions & 2 deletions mlir/docs/SPIRVToLLVMDialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,6 @@ Module in SPIR-V has one region that contains one block. It is defined via
`spv.module` is converted into `ModuleOp`. This plays a role of enclosing scope
to LLVM ops. At the moment, SPIR-V module attributes are ignored.

`spv.mlir.endmodule` is mapped to an equivalent terminator `ModuleTerminatorOp`.

## `mlir-spirv-cpu-runner`

`mlir-spirv-cpu-runner` allows to execute `gpu` dialect kernel on the CPU via
Expand Down
38 changes: 4 additions & 34 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,8 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
// -----

def SPV_ModuleOp : SPV_Op<"module",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ModuleEndOp">,
SymbolTable, Symbol]> {
[IsolatedFromAbove, NoRegionArguments, NoTerminator,
SingleBlock, SymbolTable, Symbol]> {
let summary = "The top-level op that defines a SPIR-V module";

let description = [{
Expand All @@ -426,7 +425,7 @@ def SPV_ModuleOp : SPV_Op<"module",
implicitly capture values from the enclosing environment.

This op has only one region, which only contains one block. The block
must be terminated via the `spv.mlir.endmodule` op.
has no terminator.

<!-- End of AutoGen section -->

Expand Down Expand Up @@ -463,7 +462,7 @@ def SPV_ModuleOp : SPV_Op<"module",

let results = (outs);

let regions = (region SizedRegion<1>:$body);
let regions = (region AnyRegion);

let builders = [
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
Expand All @@ -487,40 +486,11 @@ def SPV_ModuleOp : SPV_Op<"module",
Optional<StringRef> getName() { return sym_name(); }

static StringRef getVCETripleAttrName() { return "vce_triple"; }

Block& getBlock() {
return this->getOperation()->getRegion(0).front();
}
}];
}

// -----

def SPV_ModuleEndOp : SPV_Op<"mlir.endmodule", [InModuleScope, Terminator]> {
let summary = "The pseudo op that ends a SPIR-V module";

let description = [{
This op terminates the only block inside a `spv.module`'s only region.
This op does not have a corresponding SPIR-V instruction and thus will
not be serialized into the binary format; it is used solely to satisfy
the structual requirement that an block must be ended with a terminator.
}];

let arguments = (ins);

let results = (outs);

let assemblyFormat = "attr-dict";

let verifier = [{ return success(); }];

let hasOpcode = 0;

let autogenSerialization = 0;
}

// -----

def SPV_ReferenceOfOp : SPV_Op<"mlir.referenceof", [NoSideEffect]> {
let summary = "Reference a specialization constant.";

Expand Down
5 changes: 0 additions & 5 deletions mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
set(LLVM_TARGET_DEFINITIONS GPUToSPIRV.td)
mlir_tablegen(GPUToSPIRV.cpp.inc -gen-rewriters)
add_public_tablegen_target(MLIRGPUToSPIRVIncGen)

add_mlir_conversion_library(MLIRGPUToSPIRV
GPUToSPIRV.cpp
GPUToSPIRVPass.cpp

DEPENDS
MLIRConversionPassIncGen
MLIRGPUToSPIRVIncGen

LINK_LIBS PUBLIC
MLIRGPU
Expand Down
27 changes: 17 additions & 10 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
ConversionPatternRewriter &rewriter) const override;
};

class GPUModuleEndConversion final
: public OpConversionPattern<gpu::ModuleEndOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(endOp);
return success();
}
};

/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
Expand Down Expand Up @@ -301,12 +314,10 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
StringRef(spvModuleName));

// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.body();
Region &spvModuleRegion = spvModule.getRegion();
rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
spvModuleRegion.begin());
// The spv.module build method adds a block with a terminator. Remove that
// block. The terminator of the module op in the remaining block will be
// legalized later.
// The spv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());
rewriter.eraseOp(moduleOp);
return success();
Expand All @@ -330,15 +341,11 @@ LogicalResult GPUReturnOpConversion::matchAndRewrite(
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//

namespace {
#include "GPUToSPIRV.cpp.inc"
}

void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
populateWithGenerated(patterns);
patterns.add<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
GPUReturnOpConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::ThreadIdOp,
Expand Down
22 changes: 0 additions & 22 deletions mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.td

This file was deleted.

19 changes: 2 additions & 17 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {

auto newModuleOp =
rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());

// Remove the terminator block that was automatically added by builder
rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
Expand All @@ -1351,20 +1351,6 @@ class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
}
};

class ModuleEndConversionPattern
: public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
public:
using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;

LogicalResult
matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {

rewriter.eraseOp(moduleEndOp);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1507,8 +1493,7 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(

void mlir::populateSPIRVToLLVMModuleConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<ModuleConversionPattern, ModuleEndConversionPattern>(
patterns.getContext(), typeConverter);
patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
}

//===----------------------------------------------------------------------===//
Expand Down
16 changes: 10 additions & 6 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2529,7 +2529,8 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {

void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> name) {
ensureTerminator(*state.addRegion(), builder, state.location);
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
Expand All @@ -2545,7 +2546,8 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
state.addAttribute("memory_model", builder.getI32IntegerAttr(
static_cast<int32_t>(memoryModel)));
ensureTerminator(*state.addRegion(), builder, state.location);
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
Expand Down Expand Up @@ -2581,7 +2583,10 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
return failure();

spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
// Make sure we have at least one block.
if (body->empty())
body->push_back(new Block());

return success();
}

Expand All @@ -2608,8 +2613,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
}

printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
printer.printRegion(moduleOp.getRegion());
}

static LogicalResult verify(spirv::ModuleOp moduleOp) {
Expand All @@ -2619,7 +2623,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
entryPoints;
SymbolTable table(moduleOp);

for (auto &op : moduleOp.getBlock()) {
for (auto &op : *moduleOp.getBody()) {
if (op.getDialect() != dialect)
return op.emitError("'spv.module' can only contain spv.* ops");

Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,

auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
modules[0].getLoc(), addressingModel, memoryModel);
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());

// In some cases, a symbol in the (current state of the) combined module is
// renamed in order to maintain the conflicting symbol in the input module
Expand All @@ -160,7 +160,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
// for spv.funcs. This way, if the conflicting op in the input module is
// non-spv.func, we rename that symbol instead and maintain the spv.func in
// the combined module name as it is.
for (auto &op : combinedModule.getBlock().without_terminator()) {
for (auto &op : *combinedModule.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();

Expand Down Expand Up @@ -195,7 +195,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,

// In the current input module, rename all symbols that conflict with
// symbols from the combined module. This includes renaming spv.funcs.
for (auto &op : moduleClone.getBlock().without_terminator()) {
for (auto &op : *moduleClone.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();

Expand Down Expand Up @@ -225,15 +225,15 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
}

// Clone all the module's ops to the combined module.
for (auto &op : moduleClone.getBlock().without_terminator())
for (auto &op : *moduleClone.getBody())
combinedModuleBuilder.insert(op.clone());
}

// Deduplicate identical global variables, spec constants, and functions.
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
SmallVector<SymbolOpInterface, 0> eraseList;

for (auto &op : combinedModule.getBlock().without_terminator()) {
for (auto &op : *combinedModule.getBody()) {
llvm::hash_code hashCode(0);
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,

OpBuilder::InsertionGuard moduleInsertionGuard(builder);
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
builder.setInsertionPoint(spirvModule.body().front().getTerminator());
builder.setInsertionPointToEnd(spirvModule.getBody());

// Adds the spv.EntryPointOp after collecting all the interface variables
// needed.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static inline bool isFnEntryBlock(Block *block) {
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
MLIRContext *context)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
module(createModuleOp()), opBuilder(module->body()) {}
module(createModuleOp()), opBuilder(module->getRegion()) {}

LogicalResult spirv::Deserializer::deserialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ LogicalResult Serializer::serialize() {

// Iterate over the module body to serialize it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : module.getBlock()) {
for (auto &op : *module.getBody()) {
if (failed(processOperation(&op))) {
return failure();
}
Expand Down Expand Up @@ -1090,7 +1090,6 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::ModuleEndOp) { return success(); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
Expand Down
6 changes: 0 additions & 6 deletions mlir/test/Conversion/SPIRVToLLVM/module-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ spv.module @foo Logical GLSL450 {}
// CHECK: module
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}

// CHECK: module
spv.module Logical GLSL450 {
// CHECK: }
spv.mlir.endmodule
}

// CHECK: module
spv.module Logical GLSL450 {
// CHECK-LABEL: llvm.func @empty()
Expand Down
Loading

0 comments on commit 56f60a1

Please sign in to comment.