[mlir][LLVM] Make SplitStores pattern capable of writing to sub-aggregates

The pattern was previously only capable of storing into struct fields of primitive type. If the struct contained a nested struct, the pattern rewrite was immediately aborted.

This patch introduces the capability of recursively splitting stores into sub-structs as well. This is achieved by splitting an aggregate-sized integer out of the original store argument and letting repeated pattern applications further split it into field stores.
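
For illustration, a condensed sketch of the intermediate IR after the first split, based on the new @overlapping_int_aggregate_store test added below (value names are shortened, the lshr-by-zero emitted for the first field is omitted, and the GEPs are shown in their canonicalized struct-indexed form):

llvm.func @split_example(%arg: i64) {
  %c1  = llvm.mlir.constant(1 : i32) : i32
  %c16 = llvm.mlir.constant(16 : i64) : i64
  %ptr = llvm.alloca %c1 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr

  // Before the rewrite: one type-inconsistent store covering the whole struct.
  //   llvm.store %arg, %ptr : i64, !llvm.ptr

  // After the rewrite, the low i16 goes into field 0 ...
  %f0  = llvm.getelementptr %ptr[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
  %low = llvm.trunc %arg : i64 to i16
  llvm.store %low, %f0 : i16, !llvm.ptr

  // ... and the remaining i48 is written into the nested struct.
  %shr  = llvm.lshr %arg, %c16 : i64
  %rest = llvm.trunc %shr : i64 to i48
  %sub  = llvm.getelementptr %ptr[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
  llvm.store %rest, %sub : i48, !llvm.ptr
  llvm.return
}

The i48 store into the nested struct is itself type-inconsistent, so subsequent applications of the pattern split it into the three i16 field stores that the test checks for.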

Additionally, the pattern can now handle partial writes into aggregates, a pattern that clang may generate as well. Special care had to be taken to ensure that no stores are created that weren't in the original code.
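
The partial-write case works the same way, except that the value stored into the nested struct is truncated to min(fieldSize, storeSize) bits so that no bytes beyond the original store are written. A condensed sketch mirroring the new @partially_overlapping_aggregate_store test below, with the same simplifications as above:

llvm.func @partial_write_example(%arg: i64) {
  %c1  = llvm.mlir.constant(1 : i32) : i32
  %c16 = llvm.mlir.constant(16 : i64) : i64
  %ptr = llvm.alloca %c1 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr

  // Original code: llvm.store %arg, %ptr : i64, !llvm.ptr
  // The i64 covers field 0 plus only 6 of the nested struct's 8 bytes.

  %f0  = llvm.getelementptr %ptr[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
  %low = llvm.trunc %arg : i64 to i16
  llvm.store %low, %f0 : i16, !llvm.ptr

  // The tail is truncated to min(fieldSize, storeSize) = 48 bits, so the store
  // into the nested struct covers exactly the bytes the original store wrote.
  %shr  = llvm.lshr %arg, %c16 : i64
  %rest = llvm.trunc %shr : i64 to i48
  %sub  = llvm.getelementptr %ptr[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
  llvm.store %rest, %sub : i48, !llvm.ptr
  llvm.return
}

Further splitting of the i48 store produces exactly three i16 field stores; no store into the nested struct's fourth element is ever created, preserving the semantics of the original code.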

Differential Revision: https://reviews.llvm.org/D154707
zero9178 committed Jul 10, 2023
1 parent 63ca93c commit 7786449
Showing 2 changed files with 175 additions and 24 deletions.
68 changes: 44 additions & 24 deletions mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -357,7 +357,7 @@ CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
/// types, failure is returned.
static FailureOr<ArrayRef<Type>>
getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
int storeSize, unsigned storeOffset) {
unsigned storeSize, unsigned storeOffset) {
ArrayRef<Type> body = structType.getBody();
unsigned currentOffset = 0;
body = body.drop_until([&](Type type) {
@@ -381,10 +381,6 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,

size_t exclusiveEnd = 0;
for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
// Not yet recursively handling aggregates, only primitives.
if (!isa<IntegerType, FloatType>(body[exclusiveEnd]))
return failure();

if (!structType.isPacked()) {
unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
// No padding allowed in between fields at this point in time.
@@ -393,13 +389,29 @@ getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
}

unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
if (fieldSize > storeSize) {
// Partial writes into an aggregate are okay since subsequent pattern
// applications can further split these up into writes into the
// sub-elements.
auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
if (!subStruct)
return failure();

// Avoid splitting redundantly by making sure the store into the struct
// can actually be split.
if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
/*storeOffset=*/0)))
return failure();

return body.take_front(exclusiveEnd + 1);
}
currentOffset += fieldSize;
storeSize -= fieldSize;
}

// If the storeSize is not 0 at this point we are either partially writing
// into a field or writing past the aggregate as a whole. Abort.
if (storeSize != 0)
// If the storeSize is not 0 at this point we are writing past the aggregate
// as a whole. Abort.
if (storeSize > 0)
return failure();
return body.take_front(exclusiveEnd);
}
@@ -435,7 +447,8 @@ static void splitVectorStore(const DataLayout &dataLayout, Location loc,
/// type-consistent.
static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
RewriterBase &rewriter, Value address,
Value value, unsigned storeOffset,
Value value, unsigned storeSize,
unsigned storeOffset,
ArrayRef<Type> writtenToFields) {
unsigned currentOffset = storeOffset;
for (Type type : writtenToFields) {
@@ -449,7 +462,12 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,

auto shrOp = rewriter.create<LShrOp>(loc, value, pos);

IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
// If we are doing a partial write into a direct field the remaining
// `storeSize` will be less than the size of the field. We have to truncate
// to the `storeSize` to avoid creating a store that wasn't in the original
// code.
IntegerType fieldIntType =
rewriter.getIntegerType(std::min(fieldSize, storeSize) * 8);
Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);

// We create an `i8` indexed GEP here as that is the easiest (offset is
@@ -462,6 +480,7 @@ static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
// No need to care about padding here since we already checked previously
// that no padding exists in this range.
currentOffset += fieldSize;
storeSize -= fieldSize;
}
}

@@ -481,28 +500,31 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,

auto dataLayout = DataLayout::closest(store);

unsigned storeSize = dataLayout.getTypeSize(sourceType);
unsigned offset = 0;
Value address = store.getAddr();
if (auto gepOp = address.getDefiningOp<GEPOp>()) {
// Currently only handle canonical GEPs with exactly two indices,
// indexing a single aggregate deep.
// Recursing into sub-structs is left as a future exercise.
// If the GEP is not canonical we have to fail, otherwise we would not
// create type-consistent IR.
if (gepOp.getIndices().size() != 2 ||
succeeded(getRequiredConsistentGEPType(gepOp)))
return failure();

// A GEP might point somewhere into the middle of an aggregate with the
// store storing into multiple adjacent elements. Destructure into
// the base address with an offset.
std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
if (!byteOffset)
return failure();
// If the size of the element indexed by the GEP is smaller than the store
// size, it is pointing into the middle of an aggregate with the store
// storing into multiple adjacent elements. Destructure into the base
// address of the aggregate with a store offset.
if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
if (!byteOffset)
return failure();

offset = *byteOffset;
typeHint = gepOp.getSourceElementType();
address = gepOp.getBase();
offset = *byteOffset;
typeHint = gepOp.getSourceElementType();
address = gepOp.getBase();
}
}

auto structType = typeHint.dyn_cast<LLVMStructType>();
@@ -512,9 +534,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,
}

FailureOr<ArrayRef<Type>> writtenToFields =
getWrittenToFields(dataLayout, structType,
/*storeSize=*/dataLayout.getTypeSize(sourceType),
/*storeOffset=*/offset);
getWrittenToFields(dataLayout, structType, storeSize, offset);
if (failed(writtenToFields))
return failure();

@@ -526,7 +546,7 @@ LogicalResult SplitStores::matchAndRewrite(StoreOp store,

if (isa<IntegerType>(sourceType)) {
splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
store.getValue(), offset, *writtenToFields);
store.getValue(), storeSize, offset, *writtenToFields);
rewriter.eraseOp(store);
return success();
}
131 changes: 131 additions & 0 deletions mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -493,3 +493,134 @@ llvm.func @gep_result_ptr_type_dynamic(%arg: i64) {
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @overlapping_int_aggregate_store
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @overlapping_int_aggregate_store(%arg: i64) {
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64

%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr

// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
// CHECK: llvm.store %[[TRUNC]], %[[GEP]]

// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
// CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
// CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>

// Normal integer splitting of [[TRUNC]] follows:

// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]

llvm.store %arg, %1 : i64, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @overlapping_vector_aggregate_store
// CHECK-SAME: %[[ARG:.*]]: vector<4xi16>
llvm.func @overlapping_vector_aggregate_store(%arg: vector<4 x i16>) {
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32

%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr

// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32]
// CHECK: llvm.store %[[EXTRACT]], %[[GEP]]

// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32]
// CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]

// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32]
// CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]

// CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32]
// CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
// CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
// CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]

llvm.store %arg, %1 : vector<4 x i16>, !llvm.ptr
// CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
llvm.return
}

// -----

// CHECK-LABEL: llvm.func @partially_overlapping_aggregate_store
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
// CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64

%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr

// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
// CHECK: llvm.store %[[TRUNC]], %[[GEP]]

// CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
// CHECK: [[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
// CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>

// Normal integer splitting of [[TRUNC]] follows:

// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
// CHECK: llvm.store %{{.*}}, %[[GEP]]

// It is important that there are no more stores at this point.
// Specifically a store into the fourth field of %[[TOP_GEP]] would
// incorrectly change the semantics of the code.
// CHECK-NOT: llvm.store %{{.*}}, %{{.*}}

llvm.store %arg, %1 : i64, !llvm.ptr

llvm.return
}

// -----

// Here a split is undesirable since the store does a partial store into the field.

// CHECK-LABEL: llvm.func @undesirable_overlapping_aggregate_store
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @undesirable_overlapping_aggregate_store(%arg: i64) {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)> : (i32) -> !llvm.ptr
// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
%2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
// CHECK: llvm.store %[[ARG]], %[[GEP]]
llvm.store %arg, %2 : i64, !llvm.ptr

llvm.return
}
