-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Adopt cast function objects. NFC. #168228
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-core Author: Jakub Kuderski (kuhar) ChangesThese were added in #165803. Full diff: https://github.com/llvm/llvm-project/pull/168228.diff 5 Files Affected:
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index eaad8a87aab9b..6636f0ea73ec9 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -159,9 +159,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
return wrap(DIExpressionAttr::get(
unwrap(ctx),
- llvm::map_to_vector(
- unwrapList(nOperations, operations, attrStorage),
- [](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
+ llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage),
+ llvm::CastTo<DIExpressionElemAttr>)));
}
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -202,7 +201,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
cast<DIExpressionAttr>(unwrap(allocated)),
cast<DIExpressionAttr>(unwrap(associated)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIDerivedTypeAttrGet(
@@ -308,7 +307,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
return wrap(DISubroutineTypeAttr::get(
unwrap(ctx), callingConvention,
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
- [](Attribute a) { return cast<DITypeAttr>(a); })));
+ llvm::CastTo<DITypeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) {
@@ -338,10 +337,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet(
cast<DISubroutineTypeAttr>(unwrap(type)),
llvm::map_to_vector(
unwrapList(nRetainedNodes, retainedNodes, nodesStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); }),
+ llvm::CastTo<DINodeAttr>),
llvm::map_to_vector(
unwrapList(nAnnotations, annotations, annotationsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) {
@@ -398,7 +397,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet(
cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line,
cast<StringAttr>(unwrap(name)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdce72ea4..b1893f0868ac5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
auto tensorTypes =
- llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- }));
+ llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
int64_t concatRank = tensorTypes[0].getRank();
// The concatenation dim must be in the range [0, rank).
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 8859541c78c91..24b048795b136 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
- return llvm::to_vector(llvm::map_range(
- range, [](transform::MappedValue value) { return cast<T>(value); }));
+ return llvm::map_to_vector(range, llvm::CastTo<T>);
}
void transform::detail::setApplyToOneResults(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bbd7733e89c29..4455811a2e681 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -926,8 +926,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
getAsOpFoldResult(origOffsets));
- newCoods = llvm::to_vector(llvm::map_range(
- ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
return newCoods;
}
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index e438631ffe1f5..199744d208143 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
- auto shapedTypes = llvm::map_to_vector<8>(
- types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
+ auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>);
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
|
|
@llvm/pr-subscribers-mlir-gpu Author: Jakub Kuderski (kuhar) ChangesThese were added in #165803. Full diff: https://github.com/llvm/llvm-project/pull/168228.diff 5 Files Affected:
diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp
index eaad8a87aab9b..6636f0ea73ec9 100644
--- a/mlir/lib/CAPI/Dialect/LLVM.cpp
+++ b/mlir/lib/CAPI/Dialect/LLVM.cpp
@@ -159,9 +159,8 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations,
return wrap(DIExpressionAttr::get(
unwrap(ctx),
- llvm::map_to_vector(
- unwrapList(nOperations, operations, attrStorage),
- [](Attribute a) { return cast<DIExpressionElemAttr>(a); })));
+ llvm::map_to_vector(unwrapList(nOperations, operations, attrStorage),
+ llvm::CastTo<DIExpressionElemAttr>)));
}
MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) {
@@ -202,7 +201,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet(
cast<DIExpressionAttr>(unwrap(allocated)),
cast<DIExpressionAttr>(unwrap(associated)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIDerivedTypeAttrGet(
@@ -308,7 +307,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx,
return wrap(DISubroutineTypeAttr::get(
unwrap(ctx), callingConvention,
llvm::map_to_vector(unwrapList(nTypes, types, attrStorage),
- [](Attribute a) { return cast<DITypeAttr>(a); })));
+ llvm::CastTo<DITypeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) {
@@ -338,10 +337,10 @@ MlirAttribute mlirLLVMDISubprogramAttrGet(
cast<DISubroutineTypeAttr>(unwrap(type)),
llvm::map_to_vector(
unwrapList(nRetainedNodes, retainedNodes, nodesStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); }),
+ llvm::CastTo<DINodeAttr>),
llvm::map_to_vector(
unwrapList(nAnnotations, annotations, annotationsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) {
@@ -398,7 +397,7 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet(
cast<DINodeAttr>(unwrap(entity)), cast<DIFileAttr>(unwrap(file)), line,
cast<StringAttr>(unwrap(name)),
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage),
- [](Attribute a) { return cast<DINodeAttr>(a); })));
+ llvm::CastTo<DINodeAttr>)));
}
MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name,
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 110bfdce72ea4..b1893f0868ac5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -551,9 +551,7 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
auto tensorTypes =
- llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
- return llvm::cast<RankedTensorType>(type);
- }));
+ llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
int64_t concatRank = tensorTypes[0].getRank();
// The concatenation dim must be in the range [0, rank).
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index 8859541c78c91..24b048795b136 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -1495,8 +1495,7 @@ transform::detail::checkApplyToOne(Operation *transformOp,
template <typename T>
static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
- return llvm::to_vector(llvm::map_range(
- range, [](transform::MappedValue value) { return cast<T>(value); }));
+ return llvm::map_to_vector(range, llvm::CastTo<T>);
}
void transform::detail::setApplyToOneResults(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bbd7733e89c29..4455811a2e681 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -926,8 +926,7 @@ static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
getAsOpFoldResult(origOffsets));
- newCoods = llvm::to_vector(llvm::map_range(
- ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ newCoods = llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
return newCoods;
}
diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp
index e438631ffe1f5..199744d208143 100644
--- a/mlir/lib/IR/TypeUtilities.cpp
+++ b/mlir/lib/IR/TypeUtilities.cpp
@@ -118,8 +118,7 @@ LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
/// dims are equal. The element type does not matter.
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
- auto shapedTypes = llvm::map_to_vector<8>(
- types, [](auto type) { return llvm::dyn_cast<ShapedType>(type); });
+ auto shapedTypes = llvm::map_to_vector<8>(types, llvm::DynCastTo<ShapedType>);
// Return failure if some, but not all are not shaped. Return early if none
// are shaped also.
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
|
kazutakahirata
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
These were added in #165803.