Skip to content

Commit 0813700

Browse files
[mlir][NFC] Cleanup: Move helper functions to StaticValueUtils
Reduce code duplication: Move various helper functions, that are duplicated in TensorDialect, MemRefDialect, LinalgDialect, StandardDialect, into a new StaticValueUtils.cpp. Differential Revision: https://reviews.llvm.org/D104687
1 parent 81f6d7c commit 0813700

File tree

17 files changed

+159
-164
lines changed

17 files changed

+159
-164
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,13 +269,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
269269
// Return true if low padding is guaranteed to be 0.
270270
bool hasZeroLowPad() {
271271
return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) {
272-
return mlir::isEqualConstantInt(ofr, 0);
272+
return getConstantIntValue(ofr) == static_cast<int64_t>(0);
273273
});
274274
}
275275
// Return true if high padding is guaranteed to be 0.
276276
bool hasZeroHighPad() {
277277
return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) {
278-
return mlir::isEqualConstantInt(ofr, 0);
278+
return getConstantIntValue(ofr) == static_cast<int64_t>(0);
279279
});
280280
}
281281
}];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1414
#include "mlir/Dialect/SCF/Utils.h"
1515
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1617
#include "mlir/Dialect/Vector/VectorOps.h"
1718
#include "mlir/IR/Identifier.h"
1819
#include "mlir/IR/PatternMatch.h"

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
114114
bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
115115
const APFloat &rhs);
116116

117-
/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
118-
/// IntegerAttr, return the integer.
119-
llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
120-
121-
/// Return true if ofr and value are the same integer.
122-
/// Ignore integer bitwidth and type mismatch that come from the fact there is
123-
/// no IndexAttr and that IndexType has no bitwidth.
124-
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
125-
126-
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
127-
/// or the same SSA value.
128-
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
129-
/// no IndexAttr and that IndexType have no bitwidth.
130-
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
131-
132117
/// Returns the identity value attribute associated with an AtomicRMWKind op.
133118
Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
134119
OpBuilder &builder, Location loc);
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines utilities for dealing with static values, e.g.,
10+
// converting back and forth between Value and OpFoldResult. Such functionality
11+
// is used in multiple dialects.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
16+
#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
17+
18+
#include "mlir/IR/OpDefinition.h"
19+
#include "mlir/Support/LLVM.h"
20+
#include "llvm/ADT/SmallVector.h"
21+
22+
namespace mlir {
23+
24+
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
25+
/// it is a Value or into `staticVec` if it is an IntegerAttr.
26+
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
27+
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
28+
/// come from an AttrSizedOperandSegments trait.
29+
void dispatchIndexOpFoldResult(OpFoldResult ofr,
30+
SmallVectorImpl<Value> &dynamicVec,
31+
SmallVectorImpl<int64_t> &staticVec,
32+
int64_t sentinel);
33+
34+
/// Helper function to dispatch multiple OpFoldResults into either the
35+
/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs).
36+
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
37+
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
38+
/// come from an AttrSizedOperandSegments trait.
39+
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
40+
SmallVectorImpl<Value> &dynamicVec,
41+
SmallVectorImpl<int64_t> &staticVec,
42+
int64_t sentinel);
43+
44+
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
45+
SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr);
46+
47+
/// If ofr is a constant integer or an IntegerAttr, return the integer.
48+
Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
49+
50+
/// Return true if ofr1 and ofr2 are the same integer constant attribute values
51+
/// or the same SSA value.
52+
/// Ignore integer bitwitdh and type mismatch that come from the fact there is
53+
/// no IndexAttr and that IndexType have no bitwidth.
54+
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
55+
56+
} // namespace mlir
57+
58+
#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

mlir/include/mlir/Interfaces/ViewLikeInterface.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
1414
#define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_
1515

16+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1617
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinAttributes.h"
1819
#include "mlir/IR/BuiltinTypes.h"
@@ -30,8 +31,6 @@ struct Range {
3031

3132
class OffsetSizeAndStrideOpInterface;
3233

33-
bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
34-
3534
namespace detail {
3635
LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
3736

mlir/include/mlir/Interfaces/ViewLikeInterface.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
444444
/*methodBody=*/"",
445445
/*defaultImplementation=*/[{
446446
return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
447-
return ::mlir::isEqualConstantInt(ofr, 1);
447+
return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(1);
448448
});
449449
}]
450450
>,
@@ -456,7 +456,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
456456
/*methodBody=*/"",
457457
/*defaultImplementation=*/[{
458458
return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
459-
return ::mlir::isEqualConstantInt(ofr, 0);
459+
return ::mlir::getConstantIntValue(ofr) == static_cast<int64_t>(0);
460460
});
461461
}]
462462
>,

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/Math/IR/Math.h"
2121
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2222
#include "mlir/Dialect/StandardOps/IR/Ops.h"
23+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2324
#include "mlir/IR/Attributes.h"
2425
#include "mlir/IR/BlockAndValueMapping.h"
2526
#include "mlir/IR/Builders.h"
@@ -3388,14 +3389,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
33883389
}
33893390
};
33903391

3391-
/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr.
3392-
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
3393-
return llvm::to_vector<4>(
3394-
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
3395-
return a.cast<IntegerAttr>().getInt();
3396-
}));
3397-
}
3398-
33993392
/// Conversion pattern that transforms a subview op into:
34003393
/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
34013394
/// 2. Updates to the descriptor to introduce the data ptr, offset, size

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1919
#include "mlir/Dialect/Tensor/IR/Tensor.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021
#include "mlir/IR/AffineExprVisitor.h"
2122
#include "mlir/IR/Matchers.h"
2223
#include "mlir/IR/OpImplementation.h"
@@ -116,24 +117,6 @@ static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
116117
}));
117118
}
118119

119-
/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if
120-
/// it is a Value or into `staticVec` if it is an IntegerAttr.
121-
/// In the case of a Value, a copy of the `sentinel` value is also pushed to
122-
/// `staticVec`. This is useful to extract mixed static and dynamic entries that
123-
/// come from an AttrSizedOperandSegments trait.
124-
static void dispatchIndexOpFoldResult(OpFoldResult ofr,
125-
SmallVectorImpl<Value> &dynamicVec,
126-
SmallVectorImpl<int64_t> &staticVec,
127-
int64_t sentinel) {
128-
if (auto v = ofr.dyn_cast<Value>()) {
129-
dynamicVec.push_back(v);
130-
staticVec.push_back(sentinel);
131-
return;
132-
}
133-
APInt apInt = ofr.dyn_cast<Attribute>().cast<IntegerAttr>().getValue();
134-
staticVec.push_back(apInt.getSExtValue());
135-
}
136-
137120
/// This is a common class used for patterns of the form
138121
/// ```
139122
/// someop(memrefcast(%src)) -> someop(%src)
@@ -819,14 +802,6 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim(
819802
// PadTensorOp
820803
//===----------------------------------------------------------------------===//
821804

822-
/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
823-
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
824-
return llvm::to_vector<4>(
825-
llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
826-
return a.cast<IntegerAttr>().getInt();
827-
}));
828-
}
829-
830805
static LogicalResult verify(PadTensorOp op) {
831806
auto sourceType = op.source().getType().cast<RankedTensorType>();
832807
auto resultType = op.result().getType().cast<RankedTensorType>();

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@
110110
#include "mlir/Dialect/Linalg/Passes.h"
111111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
112112
#include "mlir/Dialect/SCF/SCF.h"
113+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
113114
#include "mlir/Dialect/Vector/VectorOps.h"
114115
#include "mlir/IR/Operation.h"
115116
#include "mlir/Pass/Pass.h"

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,8 @@ struct GenericPadTensorOpVectorizationPattern
814814
readInBounds.push_back(false);
815815
// Write is out-of-bounds if low padding > 0.
816816
writeInBounds.push_back(
817-
isEqualConstantIntOrValue(padOp.getMixedLowPad()[i],
818-
rewriter.getIndexAttr(0)));
817+
getConstantIntValue(padOp.getMixedLowPad()[i]) ==
818+
static_cast<int64_t>(0));
819819
} else {
820820
// Neither source nor result dim of padOp is static. Cannot vectorize
821821
// the copy.
@@ -1098,9 +1098,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern
10981098
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
10991099
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
11001100
if (!llvm::all_of(
1101-
llvm::zip(insertOp.getMixedSizes(), expectedSizes),
1102-
[](auto it) { return isEqualConstantInt(std::get<0>(it),
1103-
std::get<1>(it)); }))
1101+
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
1102+
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
1103+
}))
11041104
return failure();
11051105

11061106
// Generate TransferReadOp: Read entire source tensor and add high padding.

0 commit comments

Comments
 (0)