Skip to content

Commit

Permalink
[mlir] speed up construction of LLVM IR constants when possible
Browse files Browse the repository at this point in the history
The translation to LLVM IR used to construct sequential constants by recurring
down to individual elements, creating constant values for them, and wrapping
them into aggregate constants in post-order. This is highly inefficient for
large constants with known data such as DenseElementsAttr. Use LLVM's
ConstantData for the innermost dimension instead. LLVM does not seem to support
data constants for nested sequential constants, so the outer dimensions are
still handled recursively. Nevertheless, this speeds up the translation of
large constants with equal dimensions by up to 30x.

Users are advised to rewrite large constants to use flat types before
translating to LLVM IR if more efficiency in translation is necessary. This is
not done automatically as the translation is not aware of the expectations of
the overall compilation flow about type changes and indexing, in particular for
global constants with external linkage.

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D109152
  • Loading branch information
ftynse committed Sep 2, 2021
1 parent 00f8aec commit f9be7a7
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
94 changes: 94 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Expand Up @@ -101,6 +101,92 @@ static llvm::Type *getInnermostElementType(llvm::Type *type) {
} while (true);
}

/// Convert a dense elements attribute to an LLVM IR constant using its raw data
/// storage if possible. This supports elements attributes of tensor or vector
/// type and avoids constructing separate objects for individual values of the
/// innermost dimension. Constants for other dimensions are still constructed
/// recursively via `buildSequentialConstant`.
///
/// Returns null, so that the caller can fall back to another construction
/// mechanism, when:
///   - `denseElementsAttr` is null;
///   - the innermost LLVM element type is rejected by
///     `llvm::ConstantDataSequential::isElementTypeCompatible`;
///   - the attribute has no elements, or has rank 0 with a non-vector element
///     type (there is no innermost dimension to slice the raw data by);
///   - the attribute is a splat that neither is of vector type nor has a
///     vector element type;
///   - the attribute is a tensor whose vector element type has a rank other
///     than 1, or is neither a tensor nor a vector.
/// Reports other errors at `loc`.
static llvm::Constant *
convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
                         llvm::Type *llvmType,
                         const ModuleTranslation &moduleTranslation) {
  if (!denseElementsAttr)
    return nullptr;

  llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
  if (!llvm::ConstantDataSequential::isElementTypeCompatible(innermostLLVMType))
    return nullptr;

  ShapedType type = denseElementsAttr.getType();
  ArrayRef<int64_t> shape = type.getShape();
  bool hasVectorElementType = type.getElementType().isa<VectorType>();

  // An attribute without elements has a zero-sized dimension, which would lead
  // to a division by zero below. A rank-0 attribute has no innermost dimension
  // to query. Let the caller fall back to the generic path in both cases.
  if (denseElementsAttr.getNumElements() == 0)
    return nullptr;
  if (!hasVectorElementType && shape.empty())
    return nullptr;

  // Compute the shape of all dimensions but the innermost. Note that the
  // innermost dimension may be that of the vector element type.
  unsigned numAggregates = denseElementsAttr.getNumElements() /
                           (hasVectorElementType ? 1 : shape.back());
  ArrayRef<int64_t> outerShape = shape;
  if (!hasVectorElementType)
    outerShape = outerShape.drop_back();

  // Handle the case of vector splat, LLVM has special support for it.
  if (denseElementsAttr.isSplat() &&
      (type.isa<VectorType>() || hasVectorElementType)) {
    llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
        innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
        moduleTranslation, /*isTopLevel=*/false);
    // NOTE(review): the splat is requested with an element count of 0 rather
    // than the innermost vector length -- confirm this is intended.
    llvm::Constant *splatVector =
        llvm::ConstantDataVector::getSplat(0, splatValue);
    SmallVector<llvm::Constant *> constants(numAggregates, splatVector);
    ArrayRef<llvm::Constant *> constantsRef = constants;
    return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
  }
  if (denseElementsAttr.isSplat())
    return nullptr;

  // In case of non-splat, decide up front which kind of innermost constant to
  // build from a piece of raw data and how many elements it has. Both are
  // plain values so that nothing refers to a variable whose scope has ended
  // by the time the constants are actually created.
  bool buildVector = false;
  uint64_t numInnermostElements = 0;
  if (type.isa<TensorType>()) {
    auto vectorElementType = type.getElementType().dyn_cast<VectorType>();
    if (vectorElementType && vectorElementType.getRank() == 1) {
      buildVector = true;
      numInnermostElements = vectorElementType.getShape().back();
    } else if (!vectorElementType) {
      buildVector = false;
      numInnermostElements = shape.back();
    } else {
      // Tensors of multi-dimensional vectors are not supported here.
      return nullptr;
    }
  } else if (type.isa<VectorType>()) {
    buildVector = true;
    numInnermostElements = shape.back();
  } else {
    return nullptr;
  }

  // Number of raw bytes consumed per innermost constant. A non-splat attribute
  // has at least two elements and therefore a non-empty shape here.
  // NOTE(review): for a tensor with a vector element type this uses the last
  // tensor dimension rather than the vector length -- confirm against the raw
  // data layout of such attributes.
  size_t aggregateSize =
      static_cast<size_t>(shape.back()) *
      (innermostLLVMType->getScalarSizeInBits() / 8);

  // Create innermost constants and defer to the default constant creation
  // mechanism for other dimensions. Offsets are computed in `size_t` so that
  // they do not wrap around for large constants.
  const char *rawData = denseElementsAttr.getRawData().data();
  SmallVector<llvm::Constant *> constants;
  constants.reserve(numAggregates);
  for (size_t i = 0; i < numAggregates; ++i) {
    StringRef data(rawData + i * aggregateSize, aggregateSize);
    constants.push_back(
        buildVector ? llvm::ConstantDataVector::getRaw(
                          data, numInnermostElements, innermostLLVMType)
                    : llvm::ConstantDataArray::getRaw(
                          data, numInnermostElements, innermostLLVMType));
  }

  ArrayRef<llvm::Constant *> constantsRef = constants;
  return buildSequentialConstant(constantsRef, outerShape, llvmType, loc);
}

/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
/// This currently supports integer, floating point, splat and dense element
/// attributes and combinations thereof. Also, an array attribute with two
Expand Down Expand Up @@ -178,6 +264,14 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
}

// Try using raw elements data if possible.
if (llvm::Constant *result =
convertDenseElementsAttr(loc, attr.dyn_cast<DenseElementsAttr>(),
llvmType, moduleTranslation)) {
return result;
}

// Fall back to element-by-element construction otherwise.
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
assert(elementsAttr.getType().hasStaticShape());
assert(!elementsAttr.getType().getShape().empty() &&
Expand Down
32 changes: 31 additions & 1 deletion mlir/test/Target/LLVMIR/llvmir.mlir
Expand Up @@ -50,6 +50,36 @@ llvm.mlir.global internal constant @int_gep() : !llvm.ptr<i32> {
llvm.return %gepinit : !llvm.ptr<i32>
}

//
// Dense and splat vector constants.
//
// Each global below is initialized from a dense elements attribute of vector
// type; the CHECK line above it pins the expected LLVM IR constant.

// One-dimensional vectors are expected to become LLVM vector constants.
// CHECK{LITERAL}: @dense_float_vector = internal global <3 x float> <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00>
llvm.mlir.global internal @dense_float_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf32>) : vector<3xf32>

// CHECK{LITERAL}: @splat_float_vector = internal global <3 x float> <float 4.200000e+01, float 4.200000e+01, float 4.200000e+01>
llvm.mlir.global internal @splat_float_vector(dense<42.0> : vector<3xf32>) : vector<3xf32>

// CHECK{LITERAL}: @dense_double_vector = internal global <3 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00>
llvm.mlir.global internal @dense_double_vector(dense<[1.0, 2.0, 3.0]> : vector<3xf64>) : vector<3xf64>

// CHECK{LITERAL}: @splat_double_vector = internal global <3 x double> <double 4.200000e+01, double 4.200000e+01, double 4.200000e+01>
llvm.mlir.global internal @splat_double_vector(dense<42.0> : vector<3xf64>) : vector<3xf64>

// CHECK{LITERAL}: @dense_i64_vector = internal global <3 x i64> <i64 1, i64 2, i64 3>
llvm.mlir.global internal @dense_i64_vector(dense<[1, 2, 3]> : vector<3xi64>) : vector<3xi64>

// CHECK{LITERAL}: @splat_i64_vector = internal global <3 x i64> <i64 42, i64 42, i64 42>
llvm.mlir.global internal @splat_i64_vector(dense<42> : vector<3xi64>) : vector<3xi64>

// Multi-dimensional vectors are declared with an LLVM array-of-vectors type;
// only the innermost dimension is expected to be an LLVM vector.
// CHECK{LITERAL}: @dense_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>]
llvm.mlir.global internal @dense_float_vector_2d(dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>

// CHECK{LITERAL}: @splat_float_vector_2d = internal global [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]
llvm.mlir.global internal @splat_float_vector_2d(dense<42.0> : vector<2x2xf32>) : !llvm.array<2 x vector<2xf32>>

// CHECK{LITERAL}: @dense_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 1.000000e+00, float 2.000000e+00>, <2 x float> <float 3.000000e+00, float 4.000000e+00>], [2 x <2 x float>] [<2 x float> <float 5.000000e+00, float 6.000000e+00>, <2 x float> <float 7.000000e+00, float 8.000000e+00>]]
llvm.mlir.global internal @dense_float_vector_3d(dense<[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>

// CHECK{LITERAL}: @splat_float_vector_3d = internal global [2 x [2 x <2 x float>]] [[2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>], [2 x <2 x float>] [<2 x float> <float 4.200000e+01, float 4.200000e+01>, <2 x float> <float 4.200000e+01, float 4.200000e+01>]]
llvm.mlir.global internal @splat_float_vector_3d(dense<42.0> : vector<2x2x2xf32>) : !llvm.array<2 x !llvm.array<2 x vector<2xf32>>>


//
// Linkage attribute.
//
Expand All @@ -67,7 +97,7 @@ llvm.mlir.global weak @weak(42 : i32) : i32
// CHECK: @common = common global i32 0
llvm.mlir.global common @common(0 : i32) : i32
// CHECK: @appending = appending global [3 x i32] [i32 1, i32 2, i32 3]
llvm.mlir.global appending @appending(dense<[1,2,3]> : vector<3xi32>) : !llvm.array<3xi32>
llvm.mlir.global appending @appending(dense<[1,2,3]> : tensor<3xi32>) : !llvm.array<3xi32>
// CHECK: @extern_weak = extern_weak global i32
llvm.mlir.global extern_weak @extern_weak() : i32
// CHECK: @linkonce_odr = linkonce_odr global i32 42
Expand Down

0 comments on commit f9be7a7

Please sign in to comment.