Skip to content

Commit

Permalink
Fold Tensor.extract_slice into a constant splat.
Browse files Browse the repository at this point in the history
Fold arith.extract_slice into arith.constant when the source is a constant
splat and the result type is statically shaped.
  • Loading branch information
okkwon committed Feb 22, 2022
1 parent 210bb04 commit f79f430
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/BuiltinAttributes.h
Expand Up @@ -655,6 +655,11 @@ class DenseElementsAttr : public Attribute {
/// same total number of elements as well as element type.
DenseElementsAttr reshape(ShapedType newType);

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but with a different shape for a splat type. The new type must
/// have the same element type.
DenseElementsAttr resizeSplat(ShapedType newType);

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has bitcast elements to 'newElType'. The new type must have
/// the same bitwidth as the current element type.
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Expand Up @@ -1227,7 +1227,12 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
return {};
}

OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
auto resultType = result().getType().cast<ShapedType>();
if (resultType.hasStaticShape())
return splat.resizeSplat(resultType);
}
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->source();
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/IR/BuiltinAttributes.cpp
Expand Up @@ -967,6 +967,18 @@ DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
}

DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
assert(isSplat() && "expected a splat type");

ShapedType curType = getType();
if (curType == newType)
return *this;

assert(newType.getElementType() == curType.getElementType() &&
"expected the same element type");
return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
}

/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has bitcast elements such that it is now 'newType'. The new
/// type must have the same shape and element types of the same bitwidth as the
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
Expand Up @@ -621,6 +621,17 @@ func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>,

// -----

// CHECK-LABEL: func @fold_extract_constant_splat
// CHECK-NOT: tensor.extract_slice
// CHECK: arith.constant dense<42> : tensor<4x4xi32>
func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
%cst = arith.constant dense<42> : tensor<1024x1024xi32>
%1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
return %1 : tensor<4x4xi32>
}

// -----

// CHECK-LABEL: func @fold_overlapping_insert
// CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
Expand Down

0 comments on commit f79f430

Please sign in to comment.