Add conversions of GPU func with memory attributions to LLVM/NVVM
GPU functions use memory attributions, a combination of Op attributes and
region arguments, to specify function-wide buffers placed in workgroup or
private memory spaces. Introduce a lowering pattern that converts GPU
functions to LLVM functions and takes memory attributions into account.
Workgroup attributions are transformed into module-level globals with unique
names derived from the function name. Private attributions are converted into
llvm.allocas inside the function body. In both cases, we inject at the
beginning of the function the IR that obtains the raw pointer to the data and
populates a MemRef descriptor based on the MemRef type of the buffer, making
attributions compose with the rest of the MemRef lowering and keeping them
transparent for use with std.load and std.store. While using raw pointers
instead of descriptors might have been more efficient, that is better
implemented as a canonicalization or a separate transformation so that
non-attribution memrefs can also benefit from it.
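
As a rough, schematic illustration (the MLIR assembly below is abbreviated
and not taken verbatim from this commit), a kernel with a single workgroup
attribution

    gpu.func @kernel() workgroup(%buf : memref<32xf32, 3>) {
      ...
    }

lowers to a module-level global plus descriptor-population code at the top
of the resulting LLVM function, along the lines of

    llvm.mlir.global internal @__wg_kernel_0() : !llvm<"[32 x float]">
    llvm.func @kernel() {
      %addr = llvm.mlir.addressof @__wg_kernel_0 : ...
      %ptr = llvm.getelementptr %addr[0, 0] : ...
      // populate a memref descriptor from %ptr with constant sizes/strides
      ...
    }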

PiperOrigin-RevId: 284208396
ftynse authored and tensorflower-gardener committed Dec 6, 2019
1 parent 3c69ca1 commit e216a72
Showing 5 changed files with 371 additions and 9 deletions.
@@ -168,6 +168,13 @@ class MemRefDescriptor : public StructBuilder {
/// Builds IR creating an `undef` value of the descriptor type.
static MemRefDescriptor undef(OpBuilder &builder, Location loc,
Type descriptorType);
/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value *memory);

/// Builds IR extracting the allocated pointer from the descriptor.
Value *allocatedPtr(OpBuilder &builder, Location loc);
/// Builds IR inserting the allocated pointer into the descriptor.
@@ -184,18 +191,23 @@ class MemRefDescriptor : public StructBuilder {

/// Builds IR inserting the offset into the descriptor.
void setOffset(OpBuilder &builder, Location loc, Value *offset);
void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);

/// Builds IR extracting the pos-th size from the descriptor.
Value *size(OpBuilder &builder, Location loc, unsigned pos);

/// Builds IR inserting the pos-th size into the descriptor.
void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size);
void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
uint64_t size);

/// Builds IR extracting the pos-th stride from the descriptor.
Value *stride(OpBuilder &builder, Location loc, unsigned pos);

/// Builds IR inserting the pos-th stride into the descriptor.
void setStride(OpBuilder &builder, Location loc, unsigned pos, Value *stride);
void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
uint64_t stride);

/// Returns the (LLVM) type this descriptor points to.
LLVM::LLVMType getElementType();
16 changes: 10 additions & 6 deletions mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -61,6 +61,10 @@ class GPUDialect : public Dialect {
/// 'gpu.kernel' attribute.
static bool isKernel(Operation *op);

/// Returns the numeric value used to identify the workgroup memory address
/// space.
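/// Note: this value coincides with the shared memory address space in
/// NVVM/PTX.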
static int getWorkgroupAddressSpace() { return 3; }

LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute attr) override;
};
@@ -249,6 +253,12 @@ class GPUFuncOp : public Op<GPUFuncOp, OpTrait::FunctionLike,
return {begin, getBody().front().args_end()};
}

/// Returns the name of the attribute containing the number of buffers located
/// in the workgroup memory.
static StringRef getNumWorkgroupAttributionsAttrName() {
return "workgroup_attibutions";
}

private:
// FunctionLike trait needs access to the functions below.
friend class OpTrait::FunctionLike<GPUFuncOp>;
@@ -257,12 +267,6 @@
unsigned getNumFuncArguments() { return getType().getNumInputs(); }
unsigned getNumFuncResults() { return getType().getNumResults(); }

/// Returns the name of the attribute containing the number of buffers located
/// in the workgroup memory.
static StringRef getNumWorkgroupAttributionsAttrName() {
return "workgroup_attibutions";
}

/// Returns the keywords used in the custom syntax for this Op.
static StringRef getWorkgroupKeyword() { return "workgroup"; }
static StringRef getPrivateKeyword() { return "private"; }
145 changes: 144 additions & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -29,6 +29,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/Support/FormatVariadic.h"

#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"

@@ -451,6 +453,146 @@ struct GPUAllReduceOpLowering : public LLVMOpLowering {
static constexpr int kWarpSize = 32;
};

namespace {

struct FuncOpLowering : LLVMOpLowering {
explicit FuncOpLowering(LLVMTypeConverter &typeConverter)
: LLVMOpLowering(gpu::GPUFuncOp::getOperationName(),
typeConverter.getDialect()->getContext(),
typeConverter) {}

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
assert(operands.empty() && "func op is not expected to have operands");
auto gpuFuncOp = cast<gpu::GPUFuncOp>(op);
Location loc = gpuFuncOp.getLoc();

SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
Value *attribution = en.value();

auto type = attribution->getType().dyn_cast<MemRefType>();
assert(type && type.hasStaticShape() && "unexpected type in attribution");

uint64_t numElements = type.getNumElements();

auto elementType =
lowering.convertType(type.getElementType()).cast<LLVM::LLVMType>();
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
auto addSpaceAttr = rewriter.getNamedAttr(
"addr_space", rewriter.getI32IntegerAttr(
gpu::GPUDialect::getWorkgroupAddressSpace()));
std::string name =
llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index());
auto globalOp = rewriter.create<LLVM::GlobalOp>(
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(),
llvm::makeArrayRef(addSpaceAttr));
workgroupBuffers.push_back(globalOp);
}
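// For instance, the first workgroup attribution of a function named
// "kernel" becomes a module-level global named "__wg_kernel_0", following
// the formatv pattern above.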

// Rewrite the original GPU function to an LLVM function.
// TODO(zinenko): there is a hack in the std->llvm lowering that promotes
// structs to pointers that probably needs to be replicated here.
auto funcType = lowering.convertType(gpuFuncOp.getType())
.cast<LLVM::LLVMType>()
.getPointerElementTy();

// Remap proper input types.
TypeConverter::SignatureConversion signatureConversion(
gpuFuncOp.front().getNumArguments());
for (unsigned i = 0, e = funcType.getFunctionNumParams(); i < e; ++i)
signatureConversion.addInputs(i, funcType.getFunctionParamType(i));

// Create the new function operation. Only copy those attributes that are
// not specific to function modeling.
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : gpuFuncOp.getAttrs()) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is(gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()))
continue;
attributes.push_back(attr);
}
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, attributes);

{
// Insert operations that correspond to converted workgroup and private
// memory attributions to the body of the function. This must operate on
// the original function, before the body region is inlined in the new
// function to maintain the relation between block arguments and the
// parent operation that assigns their semantics.
OpBuilder::InsertionGuard guard(rewriter);

// Rewrite workgroup memory attributions to addresses of global buffers.
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
auto i32Type = LLVM::LLVMType::getInt32Ty(lowering.getDialect());

Value *zero = nullptr;
if (!workgroupBuffers.empty())
zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
rewriter.getI32IntegerAttr(0));
for (auto en : llvm::enumerate(workgroupBuffers)) {
LLVM::GlobalOp global = en.value();
Value *address = rewriter.create<LLVM::AddressOfOp>(loc, global);
auto elementType = global.getType().getArrayElementType();
Value *memory = rewriter.create<LLVM::GEPOp>(
loc, elementType.getPointerTo(global.addr_space().getZExtValue()),
address, ArrayRef<Value *>{zero, zero});

// Build a memref descriptor pointing to the buffer to plug with the
// existing memref infrastructure. This may use more registers than
// otherwise necessary given that memref sizes are fixed, but we can try
// and canonicalize that away later.
Value *attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
auto type = attribution->getType().cast<MemRefType>();
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, memory);
signatureConversion.remapInput(numProperArguments + en.index(), descr);
}

// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions =
gpuFuncOp.getNumWorkgroupAttributions();
auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect());
for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
Value *attribution = en.value();
auto type = attribution->getType().dyn_cast<MemRefType>();
assert(type && type.hasStaticShape() &&
"unexpected type in attribution");

auto ptrType = lowering.convertType(type.getElementType())
.cast<LLVM::LLVMType>()
.getPointerTo(type.getMemorySpace());
Value *numElements = rewriter.create<LLVM::ConstantOp>(
gpuFuncOp.getLoc(), int64Ty,
rewriter.getI64IntegerAttr(type.getNumElements()));
Value *allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
auto descr = MemRefDescriptor::fromStaticShape(rewriter, loc, lowering,
type, allocated);
signatureConversion.remapInput(
numProperArguments + numWorkgroupAttributions + en.index(), descr);
}
}
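
// At this point every remapped argument (proper arguments first, then
// workgroup attributions, then private attributions) has a replacement
// value, matching the index arithmetic in the two loops above, so the body
// can be inlined and its signature rewritten.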

rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
llvmFuncOp.end());
rewriter.applySignatureConversion(&llvmFuncOp.getBody(),
signatureConversion);

rewriter.eraseOp(gpuFuncOp);
return matchSuccess();
}
};

} // end namespace

/// Import the GPU Ops to NVVM Patterns.
#include "GPUToNVVM.cpp.inc"

@@ -479,12 +621,13 @@ class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
GPUAllReduceOpLowering>(converter);
GPUAllReduceOpLowering, FuncOpLowering>(converter);
patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
"__nv_exp");
ConversionTarget target(getContext());
target.addIllegalDialect<gpu::GPUDialect>();
target.addIllegalOp<LLVM::ExpOp>();
target.addIllegalOp<FuncOp>();
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalDialect<NVVM::NVVMDialect>();
// TODO(csigg): Remove once we support replacing non-root ops.
62 changes: 60 additions & 2 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -304,6 +304,36 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
return MemRefDescriptor(descriptor);
}

/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
MemRefDescriptor
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
LLVMTypeConverter &typeConverter,
MemRefType type, Value *memory) {
assert(type.hasStaticShape() && "unexpected dynamic shape");
assert(type.getAffineMaps().empty() && "unexpected layout map");

auto convertedType = typeConverter.convertType(type);
assert(convertedType && "unexpected failure in memref type conversion");

auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
descr.setAllocatedPtr(builder, loc, memory);
descr.setAlignedPtr(builder, loc, memory);
descr.setConstantOffset(builder, loc, 0);

// Fill in sizes and strides, in reverse order to simplify stride
// calculation.
uint64_t runningStride = 1;
for (unsigned i = type.getRank(); i > 0; --i) {
unsigned dim = i - 1;
descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
descr.setConstantStride(builder, loc, dim, runningStride);
runningStride *= type.getDimSize(dim);
}
return descr;
}
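
// Worked example (illustrative): for memref<2x3x4xf32>, the loop above
// produces sizes [2, 3, 4] and strides [12, 4, 1], the contiguous
// row-major layout; runningStride takes the values 1, 4, 12, 24 across
// iterations.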

/// Builds IR extracting the allocated pointer from the descriptor.
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
@@ -326,6 +356,14 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}

// Creates a constant Op producing a value of `resultType` from an index-typed
// integer attribute.
static Value *createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value) {
return builder.create<LLVM::ConstantOp>(
loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}
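
// Example (illustrative): createIndexAttrConstant(builder, loc, i64Ty, 42)
// creates an LLVM dialect constant with result type i64Ty from the
// attribute `42 : index`.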

/// Builds IR extracting the offset from the descriptor.
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
@@ -341,6 +379,13 @@ void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}

/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
uint64_t offset) {
setOffset(builder, loc,
createIndexAttrConstant(builder, loc, indexType, offset));
}

/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
@@ -356,6 +401,13 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th size into the descriptor.
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
unsigned pos, uint64_t size) {
setSize(builder, loc, pos,
createIndexAttrConstant(builder, loc, indexType, size));
}

/// Builds IR extracting the pos-th stride from the descriptor.
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
unsigned pos) {
@@ -372,6 +424,13 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th stride into the descriptor.
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
unsigned pos, uint64_t stride) {
setStride(builder, loc, pos,
createIndexAttrConstant(builder, loc, indexType, stride));
}

LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
@@ -448,8 +507,7 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
// Create an LLVM IR pseudo-operation defining the given index constant.
Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
return createIndexAttrConstant(builder, loc, getIndexType(), value);
}

protected:
