Skip to content

Commit

Permalink
[mlir][spirv] Add a pass to deduce version/extension/capability
Browse files Browse the repository at this point in the history
Creates an operation pass that deduces and attaches the minimal version/
capabilities/extensions requirements for spv.module ops.

For each spv.module op, this pass requires a `spv.target_env` attribute on
it or an enclosing module-like op to drive the deduction. The reason is
that an op can be enabled by multiple extensions/capabilities. So we need
to know which one to pick. `spv.target_env` gives the hard limit as for
what the target environment can support; this pass deduces what are
actually needed for a specific spv.module op.

Differential Revision: https://reviews.llvm.org/D75870
  • Loading branch information
antiagainst committed Mar 12, 2020
1 parent 66c378d commit 9414db1
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 13 deletions.
24 changes: 18 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,24 @@ class ModuleOp;
std::unique_ptr<OpPassBase<mlir::ModuleOp>>
createDecorateSPIRVCompositeTypeLayoutPass();

/// Creates a module pass that lowers the ABI attributes specified during SPIR-V
/// Lowering. Specifically,
/// 1) Creates the global variables for arguments of entry point function using
/// the specification in the ABI attributes for each argument.
/// 2) Inserts the EntryPointOp and the ExecutionModeOp for entry point
/// functions using the specification in the EntryPointAttr.
/// Creates an operation pass that deduces and attaches the minimal version/
/// capabilities/extensions requirements for spv.module ops.
/// For each spv.module op, this pass requires a `spv.target_env` attribute on
/// it or an enclosing module-like op to drive the deduction. The reason is
/// that an op can be enabled by multiple extensions/capabilities. So we need
/// to know which one to pick. `spv.target_env` gives the hard limit as for
/// what the target environment can support; this pass deduces what are
/// actually needed for a specific spv.module op.
std::unique_ptr<OpPassBase<spirv::ModuleOp>>
createUpdateVersionCapabilityExtensionPass();

/// Creates an operation pass that lowers the ABI attributes specified during
/// SPIR-V Lowering. Specifically,
/// 1. Creates the global variables for arguments of entry point function using
/// the specification in the `spv.interface_var_abi` attribute for each
/// argument.
/// 2. Inserts the EntryPointOp and the ExecutionModeOp for entry point
/// functions using the specification in the `spv.entry_point_abi` attribute.
std::unique_ptr<OpPassBase<spirv::ModuleOp>> createLowerABIAttributesPass();

} // namespace spirv
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ StringRef getTargetEnvAttrName();
/// and no extra extensions.
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);

/// Queries the target environment recursively from enclosing symbol table ops
/// containing the given `op`.
TargetEnvAttr lookupTargetEnv(Operation *op);

/// Queries the target environment recursively from enclosing symbol table ops
/// containing the given `op` or returns the default target environment as
/// returned by getDefaultTargetEnv() if not provided.
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/InitAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ inline void registerAllPasses() {
// SPIR-V
spirv::createDecorateSPIRVCompositeTypeLayoutPass();
spirv::createLowerABIAttributesPass();
spirv::createUpdateVersionCapabilityExtensionPass();
createConvertGPUToSPIRVPass();
createConvertStandardToSPIRVPass();
createLegalizeStdOpsForSPIRVLoweringPass();
Expand Down
20 changes: 13 additions & 7 deletions mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,19 +294,25 @@ spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) {
spirv::getDefaultResourceLimits(context));
}

spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
Operation *symTable = op;
while (symTable) {
symTable = SymbolTable::getNearestSymbolTable(symTable);
if (!symTable)
spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
if (!op)
break;

if (auto attr = symTable->getAttrOfType<spirv::TargetEnvAttr>(
if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
return attr;

symTable = symTable->getParentOp();
op = op->getParentOp();
}

return {};
}

spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) {
if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op))
return attr;

return getDefaultTargetEnv(op->getContext());
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRSPIRVTransforms
DecorateSPIRVCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
UpdateVCEPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
Expand Down
164 changes: 164 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to deduce minimal version/extension/capability
// requirements for a spirv::ModuleOp.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"

using namespace mlir;

namespace {
/// Pass to deduce minimal version/extension/capability requirements for a
/// spirv::ModuleOp.
class UpdateVCEPass final
: public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
private:
void runOnOperation() override;
};
} // namespace

void UpdateVCEPass::runOnOperation() {
spirv::ModuleOp module = getOperation();

spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
if (!targetEnv) {
module.emitError("missing 'spv.target_env' attribute");
return signalPassFailure();
}

spirv::Version allowedVersion = targetEnv.getVersion();

// Build a set for available extensions in the target environment.
llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
for (spirv::Extension ext : targetEnv.getExtensions())
allowedExtensions.insert(ext);

// Add extensions implied by the current version.
for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
allowedExtensions.insert(ext);

// Build a set for available capabilities in the target environment.
llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
for (spirv::Capability cap : targetEnv.getCapabilities()) {
allowedCapabilities.insert(cap);

// Add capabilities implied by the current capability.
for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
allowedCapabilities.insert(c);
}

spirv::Version deducedVersion = spirv::Version::V_1_0;
llvm::SetVector<spirv::Extension> deducedExtensions;
llvm::SetVector<spirv::Capability> deducedCapabilities;

// Walk each SPIR-V op to deduce the minimal version/extension/capability
// requirements.
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
if (deducedVersion > allowedVersion) {
return op->emitError("'") << op->getName() << "' requires min version "
<< spirv::stringifyVersion(deducedVersion)
<< " but target environment allows up to "
<< spirv::stringifyVersion(allowedVersion);
}
}

// Deduce this op's extension requirement. For each op, the query interfacce
// returns a vector of vector for its extension requirements following
// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
// convention. Ops not implementing QueryExtensionInterface do not require
// extensions to be available.
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) {
for (const auto &ors : extensions.getExtensions()) {
bool satisfied = false; // True when at least one extension can be used
for (spirv::Extension ext : ors) {
if (allowedExtensions.count(ext)) {
deducedExtensions.insert(ext);
satisfied = true;
break;
}
}

if (!satisfied) {
SmallVector<StringRef, 4> extStrings;
for (spirv::Extension ext : ors)
extStrings.push_back(spirv::stringifyExtension(ext));

return op->emitError("'")
<< op->getName() << "' requires at least one extension in ["
<< llvm::join(extStrings, ", ")
<< "] but none allowed in target environment";
}
}
}

// Deduce this op's capability requirement. For each op, the queryinterface
// returns a vector of vector for its capability requirements following
// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D))
// convention. Ops not implementing QueryExtensionInterface do not require
// extensions to be available.
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
for (const auto &ors : capabilities.getCapabilities()) {
bool satisfied = false; // True when at least one capability can be used
for (spirv::Capability cap : ors) {
if (allowedCapabilities.count(cap)) {
deducedCapabilities.insert(cap);
satisfied = true;
break;
}
}

if (!satisfied) {
SmallVector<StringRef, 4> capStrings;
for (spirv::Capability cap : ors)
capStrings.push_back(spirv::stringifyCapability(cap));

return op->emitError("'")
<< op->getName() << "' requires at least one capability in ["
<< llvm::join(capStrings, ", ")
<< "] but none allowed in target environment";
}
}
}

return WalkResult::advance();
});

if (walkResult.wasInterrupted())
return signalPassFailure();

// TODO(antiagainst): verify that the deduced version is consistent with
// SPIR-V ops' maximal version requirements.

auto triple = spirv::VerCapExtAttr::get(
deducedVersion, deducedCapabilities.getArrayRef(),
deducedExtensions.getArrayRef(), &getContext());
module.setAttr("vce_triple", triple);
}

std::unique_ptr<OpPassBase<spirv::ModuleOp>>
mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
return std::make_unique<UpdateVCEPass>();
}

static PassRegistration<UpdateVCEPass>
pass("spirv-update-vce",
"Deduce and attach minimal (version, capabilities, extensions) "
"requirements to spv.module ops");
Loading

0 comments on commit 9414db1

Please sign in to comment.