Skip to content

Commit

Permalink
Make positions of elements in MemRef descriptor private
Browse files Browse the repository at this point in the history
Previous commits removed all uses of LLVMTypeConverter::k*PosInMemRefDescriptor
outside of the MemRefDescriptor class. These numbers are an implementation
detail and can be hidden under a layer of more semantic APIs.

PiperOrigin-RevId: 280442444
  • Loading branch information
ftynse authored and tensorflower-gardener committed Nov 14, 2019
1 parent bf5916e commit b34a861
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 31 deletions.
Expand Up @@ -82,12 +82,6 @@ class LLVMTypeConverter : public TypeConverter {
Value *promoteOneMemRefDescriptor(Location loc, Value *operand,
OpBuilder &builder);

static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
static constexpr unsigned kStridePosInMemRefDescriptor = 4;

protected:
/// LLVM IR module used to parse/create types.
llvm::Module *module;
Expand Down
42 changes: 17 additions & 25 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -157,11 +157,11 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
constexpr unsigned LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kOffsetPosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kSizePosInMemRefDescriptor;
constexpr unsigned LLVMTypeConverter::kStridePosInMemRefDescriptor;
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
Expand Down Expand Up @@ -243,7 +243,7 @@ MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) {
if (value) {
structType = value->getType().cast<LLVM::LLVMType>();
indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType(
LLVMTypeConverter::kOffsetPosInMemRefDescriptor);
kOffsetPosInMemRefDescriptor);
}
}

Expand All @@ -257,78 +257,70 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,

/// Builds IR extracting the allocated pointer from the descriptor.
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor);
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
}

/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor,
ptr);
setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
}

/// Builds IR extracting the aligned pointer from the descriptor.
Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc,
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
}

/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
Value *ptr) {
setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor,
ptr);
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}

/// Builds IR extracting the offset from the descriptor.
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}

/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
Value *offset) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, offset,
builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}

/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
Value *size) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, size,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kSizePosInMemRefDescriptor, pos}));
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, indexType, value,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
Value *stride) {
value = builder.create<LLVM::InsertValueOp>(
loc, structType, value, stride,
builder.getI64ArrayAttr(
{LLVMTypeConverter::kStridePosInMemRefDescriptor, pos}));
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc,
Expand All @@ -346,7 +338,7 @@ void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos,

LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor);
kAlignedPtrPosInMemRefDescriptor);
}

namespace {
Expand Down

0 comments on commit b34a861

Please sign in to comment.