Skip to content

Commit

Permalink
Refactor linalg lowering to LLVM
Browse files Browse the repository at this point in the history
The linalg.view type used to be lowered to a struct containing a data pointer, an offset, and sizes/strides information. This was problematic when passing views to external functions because of ABI, struct padding, and alignment issues.

The linalg.view type is now lowered to LLVMIR as a *pointer* to a struct containing the data pointer, offset and sizes/strides. This simplifies interfacing with external library functions and makes it trivial to add new functions without creating a shim that converts from a value-type struct to a pointer type.

The consequences are that:
1. lowering explicitly uses llvm.alloca in lieu of llvm.undef and performs the proper llvm.load/llvm.store where relevant.
2. the shim creation function `getLLVMLibraryCallDefinition` disappears.
3. views are passed by pointer, scalars are passed by value. In the future, other structs will be passed by pointer (on an as-needed basis).

PiperOrigin-RevId: 264183671
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Aug 19, 2019
1 parent c9f37fc commit 9bf69e6
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 223 deletions.
22 changes: 16 additions & 6 deletions mlir/include/mlir/LLVMIR/LLVMOps.td
Expand Up @@ -234,14 +234,12 @@ def LLVM_AllocaOp :
$res = alloca;
}];
let builders = [OpBuilder<
"Builder *b, OperationState *result, Type resultType, Value *arraySize,"
"unsigned alignment = 0",
"Builder *b, OperationState *result, Type resultType, Value *arraySize, "
"unsigned alignment",
[{
if (!alignment)
if (alignment == 0)
return build(b, result, resultType, arraySize, IntegerAttr());
auto *ctx = resultType.getContext();
auto align = IntegerAttr::get(IntegerType::get(64, ctx), alignment);
build(b, result, resultType, arraySize, align);
build(b, result, resultType, arraySize, b->getI64IntegerAttr(alignment));
}]>];
let parser = [{ return parseAllocaOp(parser, result); }];
let printer = [{ printAllocaOp(p, *this); }];
Expand All @@ -262,6 +260,12 @@ def LLVM_GEPOp : LLVM_OneResultOp<"getelementptr", [NoSideEffect]>,
}
def LLVM_LoadOp : LLVM_OneResultOp<"load">, Arguments<(ins LLVM_Type:$addr)>,
LLVM_Builder<"$res = builder.CreateLoad($addr);"> {
let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *addr",
[{
auto type = addr->getType().cast<LLVM::LLVMType>().getPointerElementTy();
build(b, result, type, addr);
}]>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
}
Expand Down Expand Up @@ -344,6 +348,12 @@ def LLVM_InsertValueOp : LLVM_OneResultOp<"insertvalue", [NoSideEffect]>,
$res = builder.CreateInsertValue($container, $value,
extractPosition($position));
}];
let builders = [OpBuilder<
"Builder *b, OperationState *result, Value *container, Value *value, "
"ArrayAttr position",
[{
build(b, result, container->getType(), container, value, position);
}]>];
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
}
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Linalg/IR/LinalgBase.td
Expand Up @@ -30,6 +30,21 @@ include "mlir/IR/OpBase.td"

def Linalg_Dialect : Dialect {
let name = "linalg";
let description = [{
The Linalg dialect groups together a set of types and operations that are
useful to implement a "linear algebra"-like abstraction where ops can lower
to scalar load/store and operations or to more general library calls.

The Linalg dialect adopts a convention that is similar to BLAS when
offloading operations to fast library implementations: pass a non-owning
pointer to input and output data with additional metadata. This convention
is also found in libraries such as MKL, OpenBLAS, cuBLAS, cuDNN, etc., and
more generally at interface points across language boundaries (e.g. C++ /
Python).

Generally, Linalg passes non-owning pointers to View data structures to
precompiled library calls linked externally.
}];
}

// Whether a type is a BufferType.
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
Expand Up @@ -111,7 +111,8 @@ class GpuLaunchFuncToCudaCallsPass
Value *allocatePointer(OpBuilder &builder, Location loc) {
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(1));
return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one);
return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one,
/*alignment=*/0);
}

void declareCudaFunctions(Location loc);
Expand Down Expand Up @@ -233,13 +234,13 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(launchOp.getNumKernelOperands()));
auto array =
builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), arraySize);
auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
arraySize, /*alignment=*/0);
for (int idx = 0, e = launchOp.getNumKernelOperands(); idx < e; ++idx) {
auto operand = launchOp.getKernelOperand(idx);
auto llvmType = operand->getType().cast<LLVM::LLVMType>();
auto memLocation =
builder.create<LLVM::AllocaOp>(loc, llvmType.getPointerTo(), one);
auto memLocation = builder.create<LLVM::AllocaOp>(
loc, llvmType.getPointerTo(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, operand, memLocation);
auto casted =
builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
Expand Down Expand Up @@ -267,8 +268,8 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
auto kernelName =
builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
auto kernelName = builder.create<LLVM::AllocaOp>(
loc, getPointerType(), kernelNameSize, /*alignment=*/1);
for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -808,10 +808,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {

Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
auto elementType = lowering.convertType(type.getElementType());

rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
ArrayRef<Value *>{dataPtr});
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return matchSuccess();
}
};
Expand Down

0 comments on commit 9bf69e6

Please sign in to comment.