Skip to content

Conversation

@matthias-springer
Copy link
Member

Allow implicit conversion from TypedValue<B> to TypedValue<A> if B is assignable to A.

Example:

TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val;  // this is now valid

@llvmbot
Copy link
Member

llvmbot commented Oct 22, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Allow implicit conversion from TypedValue&lt;B&gt; to TypedValue&lt;A&gt; if B is assignable to A.

Example:

TypedValue&lt;MemRefType&gt; val;
TypedValue&lt;ShapedType&gt; shapedVal = val;  // this is now valid

Full diff: https://github.com/llvm/llvm-project/pull/164621.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/Value.h (+10)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Shard/Transforms/Partition.cpp (+7-9)
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 4d6d89fa69a07..af58778a0a13e 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
 template <typename Ty>
 struct TypedValue : Value {
   using Value::Value;
+  using ValueType = Ty;
 
   static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
 
+  /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
+  /// to A.
+  template <typename ToTy,
+            typename = typename std::enable_if<std::is_assignable<
+                typename ToTy::ValueType &, Ty>::value>::type>
+  operator ToTy() const {
+    return llvm::cast<ToTy>(*this);
+  }
+
   /// Return the known Type
   Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
   void setType(Ty ty) { Value::setType(ty); }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b760251e..0edbc15f8a10c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1744,11 +1744,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+  return getSource();
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+  return getDest();
 }
 
 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
index 5dc61a2147038..335ca1a60f8f3 100644
--- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           Sharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, GridOp grid,
                           int64_t splitTensorAxis, GridAxis splitGridAxis) {
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> targetShard =
       AllSliceOp::create(builder, sourceShard, grid,
                          ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
-          .getResult());
+          .getResult();
   Sharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
   return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      tensor::CastOp::create(builder, targetShape, allGatherResult)
-          .getResult());
+  TypedValue<ShapedType> targetShard =
+      tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
   return {targetShard, targetSharding};
 }
 
@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
   return {targetShard, targetSharding};
 }
 
@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
-                 cast<TypedValue<ShapedType>>(source.getSrc()),
-                 sourceShardValue);
+                 source.getSrc(), sourceShardValue);
 }
 
 TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice UI improvement!

@matthias-springer matthias-springer merged commit fb4c05c into main Oct 23, 2025
14 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/implicit_conv_typed_value branch October 23, 2025 07:24
mikolaj-pirog pushed a commit to mikolaj-pirog/llvm-project that referenced this pull request Oct 23, 2025
)

Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B`
is assignable to `A`.

Example:
```c++
TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val;  // this is now valid
```
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
)

Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B`
is assignable to `A`.

Example:
```c++
TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val;  // this is now valid
```
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
)

Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B`
is assignable to `A`.

Example:
```c++
TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val;  // this is now valid
```
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
)

Allow implicit conversion from `TypedValue<B>` to `TypedValue<A>` if `B`
is assignable to `A`.

Example:
```c++
TypedValue<MemRefType> val;
TypedValue<ShapedType> shapedVal = val;  // this is now valid
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:memref mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants