Skip to content

Commit

Permalink
Misc changes to lowering to SPIR-V.
Browse files Browse the repository at this point in the history
These changes to SPIR-V lowering while adding support for lowering
SUbViewOp, but are not directly related.
- Change the lowering of MemRefType to
  !spv.ptr<!spv.struct<!spv.array<...>[offset]>, ..>
  This is consistent with the Vulkan spec.
- To enable testing a simple pattern of lowering functions is added to
  ConvertStandardToSPIRVPass. This is just used to convert the type of
  the arguments of the function. The added function lowering itself is
  not meant to be the way functions are eventually lowered into SPIR-V
  dialect.

PiperOrigin-RevId: 282589644
  • Loading branch information
Mahesh Ravishankar authored and tensorflower-gardener committed Nov 26, 2019
1 parent 9059cf3 commit 03620fa
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 69 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
Expand Up @@ -39,6 +39,9 @@ class SPIRVTypeConverter final : public TypeConverter {

/// Converts types to SPIR-V types using the basic type converter.
Type convertType(Type t) override;

/// Gets the index type equivalent in SPIR-V.
Type getIndexType(MLIRContext *context);
};

/// Base class to define a conversion pattern to translate Ops into SPIR-V.
Expand Down
51 changes: 34 additions & 17 deletions mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineMap.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;
Expand Down Expand Up @@ -127,6 +128,32 @@ class IntegerOpConversion final : public SPIRVOpLowering<StdOp> {
}
};

// If 'basePtr' is the result of lowering a value of MemRefType, and 'indices'
// are the indices used to index into the original value (for load/store),
// perform the equivalent address calculation in SPIR-V.
spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
Value *basePtr, ArrayRef<Value *> indices,
SPIRVTypeConverter &typeConverter) {
// MemRefType is converted to a
// spirv::StructType<spirv::ArrayType<spirv:ArrayType...>>>
auto ptrType = basePtr->getType().cast<spirv::PointerType>();
(void)ptrType;
auto structType = ptrType.getPointeeType().cast<spirv::StructType>();
(void)structType;
assert(structType.getNumElements() == 1);
auto indexType = typeConverter.getIndexType(builder.getContext());

// Need to add a '0' at the beginning of the index list for accessing into the
// struct that wraps the nested array types.
Value *zero = builder.create<spirv::ConstantOp>(
loc, indexType, builder.getIntegerAttr(indexType, 0));
SmallVector<Value *, 4> accessIndices;
accessIndices.reserve(1 + indices.size());
accessIndices.push_back(zero);
accessIndices.append(indices.begin(), indices.end());
return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndices);
}

/// Convert load -> spv.LoadOp. The operands of the replaced operation are of
/// IndexType while that of the replacement operation are of type i32. This is
/// not supported in tablegen based pattern specification.
Expand All @@ -141,17 +168,11 @@ class LoadOpConversion final : public SPIRVOpLowering<LoadOp> {
ConversionPatternRewriter &rewriter) const override {
LoadOpOperandAdaptor loadOperands(operands);
auto basePtr = loadOperands.memref();
auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return matchFailure();
}
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
loadOp.getLoc(), basePtr, loadOperands.indices());
auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>();
rewriter.replaceOpWithNewOp<spirv::LoadOp>(
loadOp, loadPtrType.getPointeeType(), loadPtr,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
auto loadPtr = getElementPtr(rewriter, loadOp.getLoc(), basePtr,
loadOperands.indices(), typeConverter);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
return matchSuccess();
}
};
Expand Down Expand Up @@ -202,12 +223,8 @@ class StoreOpConversion final : public SPIRVOpLowering<StoreOp> {
StoreOpOperandAdaptor storeOperands(operands);
auto value = storeOperands.value();
auto basePtr = storeOperands.memref();
auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>();
if (!ptrType) {
return matchFailure();
}
auto storePtr = rewriter.create<spirv::AccessChainOp>(
storeOp.getLoc(), basePtr, storeOperands.indices());
auto storePtr = getElementPtr(rewriter, storeOp.getLoc(), basePtr,
storeOperands.indices(), typeConverter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr, value,
/*memory_access =*/nullptr,
/*alignment =*/nullptr);
Expand Down
44 changes: 42 additions & 2 deletions mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.cpp
Expand Up @@ -29,22 +29,62 @@
using namespace mlir;

namespace {

/// A simple pattern for rewriting function signature to convert arguments of
/// functions to be of valid SPIR-V types.
class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;

PatternMatchResult
matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};

/// A pass converting MLIR Standard operations into the SPIR-V dialect.
class ConvertStandardToSPIRVPass
: public ModulePass<ConvertStandardToSPIRVPass> {
void runOnModule() override;
};
} // namespace

PatternMatchResult
FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
return matchFailure();
}

TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
{
for (auto argType : enumerate(funcOp.getType().getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
signatureConverter.addInputs(argType.index(), convertedType);
}
}
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
return matchSuccess();
}

void ConvertStandardToSPIRVPass::runOnModule() {
OwningRewritePatternList patterns;
auto context = &getContext();
auto module = getModule();

SPIRVTypeConverter typeConverter;
populateStandardToSPIRVPatterns(module.getContext(), typeConverter, patterns);
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
patterns.insert<FuncOpConversion>(context, typeConverter);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });

if (failed(applyPartialConversion(module, target, patterns))) {
return signalPassFailure();
Expand Down
58 changes: 52 additions & 6 deletions mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
Expand Up @@ -80,6 +80,18 @@ Type convertIndexType(MLIRContext *context) {
return IntegerType::get(32, context);
}

// TODO(ravishankarm): This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
Optional<int64_t> getTypeNumBytes(Type t) {
if (auto integerType = t.dyn_cast<IntegerType>()) {
return integerType.getWidth() / 8;
} else if (auto floatType = t.dyn_cast<FloatType>()) {
return floatType.getWidth() / 8;
}
// TODO: Add size computation for other types.
return llvm::None;
}

Type typeConversionImpl(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirv::SPIRVDialect::isValidType(t)) {
Expand All @@ -91,16 +103,46 @@ Type typeConversionImpl(Type t) {
}

if (auto memRefType = t.dyn_cast<MemRefType>()) {
auto elementType = memRefType.getElementType();
// TODO(ravishankarm) : Handle dynamic shapes and memref with strides.
if (memRefType.hasStaticShape() && memRefType.getAffineMaps().empty()) {
// TODO(ravishankarm): For now only support default memory space. The memory
// space description is not set is stone within MLIR, i.e. it depends on the
// context it is being used. To map this to SPIR-V storage classes, we
// should rely on the ABI attributes, and not on the memory space. This is
// still evolving, and needs to be revisited when there is more clarity.
if (memRefType.getMemorySpace()) {
return Type();
}
auto elementType = typeConversionImpl(memRefType.getElementType());
if (!elementType) {
return Type();
}
auto elementSize = getTypeNumBytes(elementType);
if (!elementSize) {
return Type();
}
// TODO(ravishankarm) : Handle dynamic shapes.
if (memRefType.hasStaticShape()) {
// Get the strides and offset
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
offset == MemRefType::getDynamicStrideOrOffset() ||
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
// TODO(ravishankarm) : Handle dynamic strides and offsets.
return Type();
}
// Convert to a multi-dimensional spv.array if size is known.
for (auto size : reverse(memRefType.getShape())) {
elementType = spirv::ArrayType::get(elementType, size);
auto shape = memRefType.getShape();
assert(shape.size() == strides.size());
for (int i = shape.size(); i > 0; --i) {
elementType = spirv::ArrayType::get(
elementType, shape[i - 1], strides[i - 1] * elementSize.getValue());
}
// For the offset, need to wrap the array in a struct.
auto structType =
spirv::StructType::get(elementType, offset * elementSize.getValue());
// For now initialize the storage class to StorageBuffer. This will be
// updated later based on whats passed in w.r.t to the ABI attributes.
return spirv::PointerType::get(elementType,
return spirv::PointerType::get(structType,
spirv::StorageClass::StorageBuffer);
}
}
Expand All @@ -110,6 +152,10 @@ Type typeConversionImpl(Type t) {

Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }

Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
return convertType(IndexType::get(context));
}

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
Expand Up @@ -57,10 +57,6 @@ createGlobalVariableForArg(FuncOp funcOp, OpBuilder &builder, unsigned argNum,
if (isScalarOrVectorType(varType)) {
varType =
spirv::PointerType::get(spirv::StructType::get(varType), storageClass);
} else {
auto varPtrType = varType.cast<spirv::PointerType>();
varType = spirv::PointerType::get(
spirv::StructType::get(varPtrType.getPointeeType()), storageClass);
}
auto varPtrType = varType.cast<spirv::PointerType>();
auto varPointeeType = varPtrType.getPointeeType().cast<spirv::StructType>();
Expand Down Expand Up @@ -180,25 +176,29 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
}
auto var =
createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo);
if (!var) {
return matchFailure();
}

OpBuilder::InsertionGuard funcInsertionGuard(rewriter);
rewriter.setInsertionPointToStart(&funcOp.front());
// Inserts spirv::AddressOf and spirv::AccessChain operations.
auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
auto indexType =
typeConverter.convertType(IndexType::get(funcOp.getContext()));
auto zero = rewriter.create<spirv::ConstantOp>(
funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
Value *replacement = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), addressOf.pointer(), zero.constant());
// Insert spirv::AddressOf and spirv::AccessChain operations.
Value *replacement =
rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
// Check if the arg is a scalar or vector type. In that case, the value
// needs to be loaded into registers.
// TODO(ravishankarm) : This is loading value of the scalar into registers
// at the start of the function. It is probably better to do the load just
// before the use. There might be multiple loads and currently there is no
// easy way to replace all uses with a sequence of operations.
if (isScalarOrVectorType(argType.value())) {
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), replacement,
auto indexType =
typeConverter.convertType(IndexType::get(funcOp.getContext()));
auto zero = rewriter.create<spirv::ConstantOp>(
funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), replacement, zero.constant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,
/*memory_access=*/nullptr,
/*alignment=*/nullptr);
}
Expand Down
15 changes: 9 additions & 6 deletions mlir/test/Conversion/GPUToSPIRV/load-store.mlir
Expand Up @@ -22,9 +22,9 @@ module attributes {gpu.container_module} {
// CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
// CHECK-LABEL: func @load_store_kernel
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.array<12 x !spv.array<4 x f32>>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32 [4]> [16]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
Expand Down Expand Up @@ -53,15 +53,18 @@ module attributes {gpu.container_module} {
%12 = addi %arg3, %0 : index
// CHECK: [[INDEX2:%.*]] = spv.IAdd [[ARG4]], [[LOCALINVOCATIONIDX]]
%13 = addi %arg4, %3 : index
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
// CHECK: [[ZERO1:%.*]] = spv.constant 0 : i32
// CHECK: [[PTR1:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[ZERO1]], [[INDEX1]], [[INDEX2]]{{\]}}
// CHECK-NEXT: [[VAL1:%.*]] = spv.Load "StorageBuffer" [[PTR1]]
%14 = load %arg0[%12, %13] : memref<12x4xf32>
// CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
// CHECK: [[ZERO2:%.*]] = spv.constant 0 : i32
// CHECK: [[PTR2:%.*]] = spv.AccessChain [[ARG1]]{{\[}}[[ZERO2]], [[INDEX1]], [[INDEX2]]{{\]}}
// CHECK-NEXT: [[VAL2:%.*]] = spv.Load "StorageBuffer" [[PTR2]]
%15 = load %arg1[%12, %13] : memref<12x4xf32>
// CHECK: [[VAL3:%.*]] = spv.FAdd [[VAL1]], [[VAL2]]
%16 = addf %14, %15 : f32
// CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
// CHECK: [[ZERO3:%.*]] = spv.constant 0 : i32
// CHECK: [[PTR3:%.*]] = spv.AccessChain [[ARG2]]{{\[}}[[ZERO3]], [[INDEX1]], [[INDEX2]]{{\]}}
// CHECK-NEXT: spv.Store "StorageBuffer" [[PTR3]], [[VAL3]]
store %16, %arg2[%12, %13] : memref<12x4xf32>
return
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/GPUToSPIRV/loop.mlir
Expand Up @@ -22,8 +22,8 @@ module attributes {gpu.container_module} {
// CHECK: [[CMP:%.*]] = spv.SLessThan [[INDVAR]], [[UB]] : i32
// CHECK: spv.BranchConditional [[CMP]], [[BODY:\^.*]], [[MERGE:\^.*]]
// CHECK: [[BODY]]:
// CHECK: spv.AccessChain {{%.*}}{{\[}}[[INDVAR]]{{\]}} : {{.*}}
// CHECK: spv.AccessChain {{%.*}}{{\[}}[[INDVAR]]{{\]}} : {{.*}}
// CHECK: spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[INDVAR]]{{\]}} : {{.*}}
// CHECK: spv.AccessChain {{%.*}}{{\[}}{{%.*}}, [[INDVAR]]{{\]}} : {{.*}}
// CHECK: [[INCREMENT:%.*]] = spv.IAdd [[INDVAR]], [[STEP]] : i32
// CHECK: spv.Branch [[HEADER]]([[INCREMENT]] : i32)
// CHECK: [[MERGE]]
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/GPUToSPIRV/simple.mlir
Expand Up @@ -6,8 +6,8 @@ module attributes {gpu.container_module} {
// CHECK: spv.module "Logical" "GLSL450" {
// CHECK-LABEL: func @kernel_1
// CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.array<12 x f32>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32, 1>)
// CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32>)
attributes { gpu.kernel } {
// CHECK: spv.Return
return
Expand All @@ -16,10 +16,10 @@ module attributes {gpu.container_module} {

func @foo() {
%0 = "op"() : () -> (f32)
%1 = "op"() : () -> (memref<12xf32, 1>)
%1 = "op"() : () -> (memref<12xf32>)
%cst = constant 1 : index
"gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1) { kernel = "kernel_1", kernel_module = @kernels }
: (index, index, index, index, index, index, f32, memref<12xf32, 1>) -> ()
: (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
return
}
}

0 comments on commit 03620fa

Please sign in to comment.