Skip to content

Commit

Permalink
Retain address space during MLIR > LLVM conversion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 267206460
  • Loading branch information
MLIR Team authored and tensorflower-gardener committed Sep 4, 2019
1 parent c6f8ada commit b565272
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
21 changes: 9 additions & 12 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -124,14 +124,14 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {

// Convert a MemRef to an LLVM type. If the memref is statically-shaped, then
// we return a pointer to the converted element type. Otherwise we return an
// LLVM stucture type, where the first element of the structure type is a
// LLVM structure type, where the first element of the structure type is a
// pointer to the elemental type of the MemRef and the following N elements are
// values of the Index type, one for each of N dynamic dimensions of the MemRef.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrType = elementType.getPointerTo();
auto ptrType = elementType.getPointerTo(type.getMemorySpace());

// Extra value for the memory space.
unsigned numDynamicSizes = type.getNumDynamicDims();
Expand Down Expand Up @@ -189,7 +189,7 @@ static Type getMemRefElementPtrType(MemRefType t, LLVMTypeConverter &lowering) {
auto converted = lowering.convertType(elementType);
if (!converted)
return {};
return converted.cast<LLVM::LLVMType>().getPointerTo();
return converted.cast<LLVM::LLVMType>().getPointerTo(t.getMemorySpace());
}

LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
Expand Down Expand Up @@ -514,14 +514,11 @@ struct ConstLLVMOpLowering
using Super::Super;
};

// Check if the MemRefType `type` is supported by the lowering. We currently do
// not support memrefs with affine maps and non-default memory spaces.
// Check if the MemRefType `type` is supported by the lowering. We currently
// only support memrefs with identity maps.
static bool isSupportedMemRefType(MemRefType type) {
if (!type.getAffineMaps().empty())
return false;
if (type.getMemorySpace() != 0)
return false;
return true;
return llvm::all_of(type.getAffineMaps(),
[](AffineMap map) { return map.isIdentity(); });
}

// An `alloc` is converted into a definition of a memref descriptor value and
Expand Down Expand Up @@ -598,8 +595,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
cumulativeSize)
.getResult(0);
auto structElementType = lowering.convertType(elementType);
auto elementPtrType =
structElementType.cast<LLVM::LLVMType>().getPointerTo();
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
ArrayRef<Value *>(allocated));

Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
@@ -0,0 +1,11 @@
// RUN: mlir-opt %s -lower-to-llvm | FileCheck %s

// CHECK-LABEL: func @address_space(%{{.*}}: !llvm<"float addrspace(7)*">)
func @address_space(%arg0 : memref<32xf32, (d0) -> (d0), 7>) {
%0 = alloc() : memref<32xf32, (d0) -> (d0), 5>
%1 = constant 7 : index
// CHECK: llvm.load %{{.*}} : !llvm<"float addrspace(5)*">
%2 = load %0[%1] : memref<32xf32, (d0) -> (d0), 5>
std.return
}

0 comments on commit b565272

Please sign in to comment.