[mlir][IR] Add implicit conversion operator to TypedValue
#164621
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Allow implicit conversion from TypedValue<B> to TypedValue<A> if B is assignable to A. Example:

TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val; // this is now valid

Full diff: https://github.com/llvm/llvm-project/pull/164621.diff

3 Files Affected:
- mlir/include/mlir/IR/Value.h
- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
- mlir/lib/Dialect/Shard/Transforms/Partition.cpp
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 4d6d89fa69a07..af58778a0a13e 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
template <typename Ty>
struct TypedValue : Value {
using Value::Value;
+ using ValueType = Ty;
static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
+ /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
+ /// to A.
+ template <typename ToTy,
+ typename = typename std::enable_if<std::is_assignable<
+ typename ToTy::ValueType &, Ty>::value>::type>
+ operator ToTy() const {
+ return llvm::cast<ToTy>(*this);
+ }
+
/// Return the known Type
Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
void setType(Ty ty) { Value::setType(ty); }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..0edbc15f8a10c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1744,11 +1744,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+ return getSource();
}
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
- return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+ return getDest();
}
bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2147038..335ca1a60f8f3 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
Sharding sourceSharding,
TypedValue<ShapedType> sourceShard, GridOp grid,
int64_t splitTensorAxis, GridAxis splitGridAxis) {
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+ TypedValue<ShapedType> targetShard =
AllSliceOp::create(builder, sourceShard, grid,
ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
- .getResult());
+ .getResult();
Sharding targetSharding = targetShardingInSplitLastAxis(
builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
APInt(64, splitTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allGatherResult)
- .getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
return {targetShard, targetSharding};
}
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
ShapedType targetShape =
shardShapedType(sourceUnshardedShape, grid, targetSharding);
- TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
- tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+ TypedValue<ShapedType> targetShard =
+ tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
return {targetShard, targetSharding};
}
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
auto targetSharding = target.getSharding();
ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
- cast<TypedValue<ShapedType>>(source.getSrc()),
- sourceShardValue);
+ source.getSrc(), sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
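
For reference, here is a minimal standalone sketch (using hypothetical stand-in types, not MLIR types) of how the enable_if condition in the Value.h change gates the conversion: the operator only participates in overload resolution when the destination's ValueType is assignable from the source's type, i.e. for widening conversions along an implicit-conversion (interface-like) relationship.

// Standalone sketch (hypothetical types, not part of the patch) of the
// SFINAE gate used by the new conversion operator in Value.h.
#include <type_traits>

// Concrete implicitly converts to Interface, mirroring how MemRefType
// converts to ShapedType.
struct Interface {};
struct Concrete {
  operator Interface() const { return Interface{}; }
};

template <typename Ty>
struct Typed {
  using ValueType = Ty;

  // Enabled only if Ty is assignable to ToTy::ValueType. The real operator
  // returns llvm::cast<ToTy>(*this); this sketch just default-constructs.
  template <typename ToTy,
            typename = std::enable_if_t<
                std::is_assignable<typename ToTy::ValueType &, Ty>::value>>
  operator ToTy() const {
    return ToTy{};
  }
};

int main() {
  Typed<Concrete> concreteVal;
  Typed<Interface> interfaceVal = concreteVal; // OK: widening conversion.
  (void)interfaceVal;
  // Typed<Concrete> narrowed = interfaceVal;  // Error: Interface is not
  //                                           // assignable to Concrete.
  return 0;
}

Under these assumptions the commented-out narrowing assignment would not compile, matching the intent that only widening conversions become implicit.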
Nice UI improvement!
Allow implicit conversion from TypedValue<B> to TypedValue<A> if B is assignable to A. Example:

TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val; // this is now valid
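
For the opposite direction the std::is_assignable condition is false, so the new operator does not apply and narrowing still goes through an explicit cast, as before. A small sketch, assuming the MLIR headers below and a hypothetical helper (not part of the patch) that receives a value whose type really is a MemRefType:

#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Hypothetical helper for illustration: narrowing from the ShapedType
// interface to the concrete MemRefType is not covered by the new implicit
// conversion (ShapedType is not assignable to MemRefType), so an explicit
// cast is still required and asserts if the underlying type does not match.
static mlir::TypedValue<mlir::MemRefType>
narrowToMemRef(mlir::TypedValue<mlir::ShapedType> val) {
  return llvm::cast<mlir::TypedValue<mlir::MemRefType>>(val);
}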