[mlir][spirv] Lower allocation/deallocations of workgroup memory.
An allocation of workgroup memory is lowered to a
spv.globalVariable. Only static-size allocations with an int or float
element type are handled. The lowering does account for element types
that are not supported in the lowered spv.module (based on the
extensions/capabilities) and adjusts the number of elements to
preserve the same byte length.
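
As a sketch of the intended rewrite (the memref type and SSA names below are illustrative; the tests added in this commit are authoritative):

// Before: a static-shape allocation in workgroup memory (memory space 3).
%0 = alloc() : memref<4x5xf32, 3>
dealloc %0 : memref<4x5xf32, 3>

// After: a module-level variable plus an address-of at the use site; the
// matching dealloc is simply erased.
spv.globalVariable @__workgroup_mem__0 :
    !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>
...
%1 = spv._address_of @__workgroup_mem__0 :
    !spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>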

Differential Revision: https://reviews.llvm.org/D80411
MaheshRavishankar committed May 27, 2020
1 parent 29f8056 commit 4d6f44f
Showing 7 changed files with 313 additions and 47 deletions.
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -41,6 +41,12 @@ class SPIRVTypeConverter : public TypeConverter {
public:
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);

/// Gets the number of bytes used for a type when converted to a SPIR-V
/// type. Note that it doesn't account for whether the type is legal for a
/// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on
/// failure.
static Optional<int64_t> getConvertedTypeNumBytes(Type);

/// Gets the SPIR-V correspondence for the standard index type.
static Type getIndexType(MLIRContext *context);

128 changes: 111 additions & 17 deletions mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -169,22 +169,51 @@ bool isUnsignedOp() {
return true; \
}

-CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp);
-CHECK_UNSIGNED_OP(spirv::AtomicUMinOp);
-CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp);
-CHECK_UNSIGNED_OP(spirv::ConvertUToFOp);
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp);
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp);
-CHECK_UNSIGNED_OP(spirv::UConvertOp);
-CHECK_UNSIGNED_OP(spirv::UDivOp);
-CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp);
-CHECK_UNSIGNED_OP(spirv::UGreaterThanOp);
-CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp);
-CHECK_UNSIGNED_OP(spirv::ULessThanOp);
-CHECK_UNSIGNED_OP(spirv::UModOp);
+CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
+CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
+CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
+CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
+CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
+CHECK_UNSIGNED_OP(spirv::UConvertOp)
+CHECK_UNSIGNED_OP(spirv::UDivOp)
+CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
+CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
+CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
+CHECK_UNSIGNED_OP(spirv::ULessThanOp)
+CHECK_UNSIGNED_OP(spirv::UModOp)

#undef CHECK_UNSIGNED_OP

/// Returns true if allocations of type `t` can be lowered to SPIR-V.
static bool isAllocationSupported(MemRefType t) {
// Currently only support workgroup local memory allocations with static
// shape and int or float element type.
return t.hasStaticShape() &&
SPIRVTypeConverter::getMemorySpaceForStorageClass(
spirv::StorageClass::Workgroup) == t.getMemorySpace() &&
t.getElementType().isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns None on failure.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
Optional<spirv::StorageClass> storageClass =
SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
if (!storageClass)
return {};
switch (*storageClass) {
case spirv::StorageClass::StorageBuffer:
return spirv::Scope::Device;
case spirv::StorageClass::Workgroup:
return spirv::Scope::Workgroup;
default: {
}
}
return {};
}
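
For example (a sketch, not output taken from this commit's tests), the i8-store emulation below now synchronizes at Workgroup scope when the containing 32-bit word lives in workgroup memory, instead of the previously hardcoded Device scope; %ptr, %mask, and %val are illustrative names:

// Clear the destination bits, then OR in the shifted source bits.
%old = spv.AtomicAnd "Workgroup" "AcquireRelease" %ptr, %mask : !spv.ptr<i32, Workgroup>
%new = spv.AtomicOr "Workgroup" "AcquireRelease" %ptr, %val : !spv.ptr<i32, Workgroup>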

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
@@ -195,6 +224,67 @@ CHECK_UNSIGNED_OP(spirv::UModOp);

namespace {

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it
/// will add global variables into the spv.module.
class AllocOpPattern final : public SPIRVOpLowering<AllocOp> {
public:
using SPIRVOpLowering<AllocOp>::SPIRVOpLowering;

LogicalResult
matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType allocType = operation.getType();
if (!isAllocationSupported(allocType))
return operation.emitError("unhandled allocation type");

// Get the SPIR-V type for the allocation.
Type spirvType = typeConverter.convertType(allocType);

// Insert spv.globalVariable for this allocation.
Operation *parent =
SymbolTable::getNearestSymbolTable(operation.getParentOp());
if (!parent)
return failure();
Location loc = operation.getLoc();
spirv::GlobalVariableOp varOp;
{
OpBuilder::InsertionGuard guard(rewriter);
Block &entryBlock = *parent->getRegion(0).begin();
rewriter.setInsertionPointToStart(&entryBlock);
auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
std::string varName =
std::string("__workgroup_mem__") +
std::to_string(std::distance(varOps.begin(), varOps.end()));
varOp = rewriter.create<spirv::GlobalVariableOp>(
loc, TypeAttr::get(spirvType), varName,
/*initializer = */ nullptr);
}

// Get pointer to global variable at the current scope.
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
return success();
}
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public SPIRVOpLowering<DeallocOp> {
public:
using SPIRVOpLowering<DeallocOp>::SPIRVOpLowering;

LogicalResult
matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
if (!isAllocationSupported(deallocType))
return operation.emitError("unhandled deallocation type");
rewriter.eraseOp(operation);
return success();
}
};

/// Converts unary and binary standard operations to SPIR-V operations.
template <typename StdOp, typename SPIRVOp>
class UnaryAndBinaryOpPattern final : public SPIRVOpLowering<StdOp> {
@@ -823,12 +913,15 @@ IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
+Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
+if (!scope)
+  return failure();
Value result = rewriter.create<spirv::AtomicAndOp>(
-    loc, dstType, adjustedPtr, spirv::Scope::Device,
-    spirv::MemorySemantics::AcquireRelease, clearBitsMask);
+    loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
+    clearBitsMask);
result = rewriter.create<spirv::AtomicOrOp>(
-    loc, dstType, adjustedPtr, spirv::Scope::Device,
-    spirv::MemorySemantics::AcquireRelease, storeVal);
+    loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
+    storeVal);

// The AtomicOrOp has no side effect. Since it is already inserted, we can
// just remove the original StoreOp. Note that rewriter.replaceOp()
@@ -913,6 +1006,7 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
AllocOpPattern, DeallocOpPattern,
BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
50 changes: 31 additions & 19 deletions mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
@@ -218,6 +218,10 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}

Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
return getTypeNumBytes(t);
}

/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
static Optional<Type>
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
@@ -383,8 +387,11 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
auto arrayType =
spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);

-// Wrap in a struct to satisfy Vulkan interface requirements.
-auto structType = spirv::StructType::get(arrayType, 0);
+// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
+// workgroup storage class do not need the struct to be laid out explicitly.
+auto structType = *storageClass == spirv::StorageClass::Workgroup
+                      ? spirv::StructType::get(arrayType)
+                      : spirv::StructType::get(arrayType, 0);
return spirv::PointerType::get(structType, *storageClass);
}
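
The effect, sketched for an illustrative memref<4x5xf32> (these types are written by hand, not taken from the tests): with StorageBuffer storage class the wrapping struct keeps its explicit member offset, while with Workgroup it is omitted.

// memref<4x5xf32> (StorageBuffer): explicit layout for the Vulkan interface.
!spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4> [0]>, StorageBuffer>
// memref<4x5xf32, 3> (Workgroup): no member offset on the struct.
!spv.ptr<!spv.struct<!spv.array<20 x f32, stride=4>>, Workgroup>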

@@ -574,35 +581,40 @@ spirv::AccessChainOp mlir::spirv::getElementPtr(
SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
ArrayRef<Value> indices, Location loc, OpBuilder &builder) {
// Get base and offset of the MemRefType and verify they are static.

int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(baseType, strides, offset)) ||
-    llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+    llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
+    offset == MemRefType::getDynamicStrideOrOffset()) {
return nullptr;
}

auto indexType = typeConverter.getIndexType(builder.getContext());

-Value ptrLoc = nullptr;
-assert(indices.size() == strides.size() &&
-       "must provide indices for all dimensions");
-for (auto index : enumerate(indices)) {
-  Value strideVal = builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-  Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-  ptrLoc =
-      (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-              : update);
-}
SmallVector<Value, 2> linearizedIndices;
// Add a '0' at the start to index into the struct.
auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
linearizedIndices.push_back(zero);
-// If it is a zero-rank memref type, extract the element directly.
-if (!ptrLoc) {
-  ptrLoc = zero;
-}
-linearizedIndices.push_back(ptrLoc);
+
+if (baseType.getRank() == 0) {
+  linearizedIndices.push_back(zero);
+} else {
+  // TODO: Instead of this logic, use affine.apply and add patterns for
+  // lowering affine.apply to standard ops. These will get lowered to SPIR-V
+  // ops by the DialectConversion framework.
+  Value ptrLoc = builder.create<spirv::ConstantOp>(
+      loc, indexType, IntegerAttr::get(indexType, offset));
+  assert(indices.size() == strides.size() &&
+         "must provide indices for all dimensions");
+  for (auto index : enumerate(indices)) {
+    Value strideVal = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value update =
+        builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
+  }
+  linearizedIndices.push_back(ptrLoc);
+}
return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
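
The rewritten getElementPtr thus materializes index = offset + sum_i(stride_i * index_i), with a leading zero to step through the wrapping struct. A hand-written sketch for a memref<12x4xf32> (offset 0, strides [4, 1]) accessed at (%i, %j), with illustrative SSA names:

%zero = spv.constant 0 : i32
%off = spv.constant 0 : i32
%s0 = spv.constant 4 : i32
%u0 = spv.IMul %s0, %i : i32
%a0 = spv.IAdd %off, %u0 : i32
%s1 = spv.constant 1 : i32
%u1 = spv.IMul %s1, %j : i32
%idx = spv.IAdd %a0, %u1 : i32
%ptr = spv.AccessChain %base[%zero, %idx] : !spv.ptr<!spv.struct<!spv.array<48 x f32, stride=4> [0]>, StorageBuffer>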

8 changes: 5 additions & 3 deletions mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -58,13 +58,15 @@ module attributes {
%12 = addi %arg3, %0 : index
// CHECK: %[[INDEX2:.*]] = spv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = addi %arg4, %3 : index
+// CHECK: %[[ZERO:.*]] = spv.constant 0 : i32
+// CHECK: %[[OFFSET1_0:.*]] = spv.constant 0 : i32
// CHECK: %[[STRIDE1_1:.*]] = spv.constant 4 : i32
-// CHECK: %[[OFFSET1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+// CHECK: %[[UPDATE1_1:.*]] = spv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
+// CHECK: %[[OFFSET1_1:.*]] = spv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
// CHECK: %[[STRIDE1_2:.*]] = spv.constant 1 : i32
// CHECK: %[[UPDATE1_2:.*]] = spv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
// CHECK: %[[OFFSET1_2:.*]] = spv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
-// CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
-// CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO1]], %[[OFFSET1_2]]{{\]}}
+// CHECK: %[[PTR1:.*]] = spv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spv.Load "StorageBuffer" %[[PTR1]]
%14 = load %arg0[%12, %13] : memref<12x4xf32>
// CHECK: %[[PTR2:.*]] = spv.AccessChain %[[ARG1]]{{\[}}{{%.*}}, {{%.*}}{{\]}}
12 changes: 8 additions & 4 deletions mlir/test/Conversion/GPUToSPIRV/loop.mlir
@@ -28,13 +28,17 @@ module attributes {
// CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR]], %[[UB]] : i32
// CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]]
// CHECK: ^[[BODY]]:
-// CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32
-// CHECK: %[[INDEX1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
// CHECK: %[[ZERO1:.*]] = spv.constant 0 : i32
+// CHECK: %[[OFFSET1:.*]] = spv.constant 0 : i32
+// CHECK: %[[STRIDE1:.*]] = spv.constant 1 : i32
+// CHECK: %[[UPDATE1:.*]] = spv.IMul %[[STRIDE1]], %[[INDVAR]] : i32
+// CHECK: %[[INDEX1:.*]] = spv.IAdd %[[OFFSET1]], %[[UPDATE1]] : i32
// CHECK: spv.AccessChain {{%.*}}{{\[}}%[[ZERO1]], %[[INDEX1]]{{\]}}
-// CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32
-// CHECK: %[[INDEX2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
// CHECK: %[[ZERO2:.*]] = spv.constant 0 : i32
+// CHECK: %[[OFFSET2:.*]] = spv.constant 0 : i32
+// CHECK: %[[STRIDE2:.*]] = spv.constant 1 : i32
+// CHECK: %[[UPDATE2:.*]] = spv.IMul %[[STRIDE2]], %[[INDVAR]] : i32
+// CHECK: %[[INDEX2:.*]] = spv.IAdd %[[OFFSET2]], %[[UPDATE2]] : i32
// CHECK: spv.AccessChain {{%.*}}[%[[ZERO2]], %[[INDEX2]]]
// CHECK: %[[INCREMENT:.*]] = spv.IAdd %[[INDVAR]], %[[STEP]] : i32
// CHECK: spv.Branch ^[[HEADER]](%[[INCREMENT]] : i32)