Skip to content

Commit

Permalink
[MLIR][DataLayout] Add support for scalable vectors (#89349)
Browse files Browse the repository at this point in the history
This commit extends the data layout to support scalable vectors. For
scalable vectors, the `TypeSize`'s scalable field is set accordingly,
and the alignment information remains the same as for normal vectors.
This behavior is in sync with what LLVM's data layout queries are
producing.

Before this change, scalable vectors incorrectly returned the same size
as "normal" vectors.
  • Loading branch information
Dinistro authored Apr 19, 2024
1 parent 4d7f3d9 commit df411fb
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
15 changes: 9 additions & 6 deletions mlir/lib/Interfaces/DataLayoutInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
// there is no bit-packing at the moment element sizes are taken in bytes and
// multiplied with 8 bits.
// TODO: make this extensible.
if (auto vecType = dyn_cast<VectorType>(type))
return vecType.getNumElements() / vecType.getShape().back() *
llvm::PowerOf2Ceil(vecType.getShape().back()) *
dataLayout.getTypeSize(vecType.getElementType()) * 8;
if (auto vecType = dyn_cast<VectorType>(type)) {
uint64_t baseSize = vecType.getNumElements() / vecType.getShape().back() *
llvm::PowerOf2Ceil(vecType.getShape().back()) *
dataLayout.getTypeSize(vecType.getElementType()) * 8;
return llvm::TypeSize::get(baseSize, vecType.isScalable());
}

if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
return typeInterface.getTypeSizeInBits(dataLayout, params);
Expand Down Expand Up @@ -138,9 +140,10 @@ getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
uint64_t mlir::detail::getDefaultABIAlignment(
Type type, const DataLayout &dataLayout,
ArrayRef<DataLayoutEntryInterface> params) {
// Natural alignment is the closest power-of-two number above.
// Natural alignment is the closest power-of-two number above. For scalable
// vectors, aligning them to the same as the base vector is sufficient.
if (isa<VectorType>(type))
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type).getKnownMinValue());

if (auto fltType = dyn_cast<FloatType>(type))
return getFloatTypeABIAlignment(fltType, dataLayout, params);
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Interfaces/DataLayoutInterfaces/query.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ func.func @no_layout_builtin() {
// CHECK: preferred = 8
// CHECK: size = 8
"test.data_layout_query"() : () -> index
// CHECK: alignment = 16
// CHECK: bitsize = 128
// CHECK: index = 0
// CHECK: preferred = 16
// CHECK: size = 16
"test.data_layout_query"() : () -> vector<4xi32>
// CHECK: alignment = 16
// CHECK: bitsize = {minimal_size = 128 : index, scalable}
// CHECK: index = 0
// CHECK: preferred = 16
// CHECK: size = {minimal_size = 16 : index, scalable}
"test.data_layout_query"() : () -> vector<[4]xi32>
return

}
Expand Down
17 changes: 15 additions & 2 deletions mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,22 @@ struct TestDataLayoutQuery
Attribute programMemorySpace = layout.getProgramMemorySpace();
Attribute globalMemorySpace = layout.getGlobalMemorySpace();
uint64_t stackAlignment = layout.getStackAlignment();

auto convertTypeSizeToAttr = [&](llvm::TypeSize typeSize) -> Attribute {
if (!typeSize.isScalable())
return builder.getIndexAttr(typeSize);

return builder.getDictionaryAttr({
builder.getNamedAttr("scalable", builder.getUnitAttr()),
builder.getNamedAttr(
"minimal_size",
builder.getIndexAttr(typeSize.getKnownMinValue())),
});
};

op->setAttrs(
{builder.getNamedAttr("size", builder.getIndexAttr(size)),
builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)),
{builder.getNamedAttr("size", convertTypeSizeToAttr(size)),
builder.getNamedAttr("bitsize", convertTypeSizeToAttr(bitsize)),
builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
builder.getNamedAttr("index", builder.getIndexAttr(index)),
Expand Down

0 comments on commit df411fb

Please sign in to comment.