diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 4d6d89fa69a07..af58778a0a13e 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const { template struct TypedValue : Value { using Value::Value; + using ValueType = Ty; static bool classof(Value value) { return llvm::isa(value.getType()); } + /// TypedValue can implicitly convert to TypedValue if B is assignable + /// to A. + template ::value>::type> + operator ToTy() const { + return llvm::cast(*this); + } + /// Return the known Type Ty getType() const { return llvm::cast(Value::getType()); } void setType(Ty ty) { Value::setType(ty); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 94947b760251e..0edbc15f8a10c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1744,11 +1744,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) { } TypedValue MemorySpaceCastOp::getSourcePtr() { - return cast>(getSource()); + return getSource(); } TypedValue MemorySpaceCastOp::getTargetPtr() { - return cast>(getDest()); + return getDest(); } bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt, diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dc61a2147038..335ca1a60f8f3 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { - TypedValue targetShard = cast>( + TypedValue targetShard = AllSliceOp::create(builder, sourceShard, grid, ArrayRef(splitGridAxis), splitTensorAxis) - .getResult()); + .getResult(); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; @@ -204,9 +204,8 @@ static std::tuple, Sharding> unsplitLastAxisInResharding( APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue targetShard = cast>( - tensor::CastOp::create(builder, targetShape, allGatherResult) - .getResult()); + TypedValue targetShard = + tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); return {targetShard, targetSharding}; } @@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue targetShard = cast>( - tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); + TypedValue targetShard = + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult(); return {targetShard, targetSharding}; } @@ -510,8 +509,7 @@ TypedValue reshard(OpBuilder &builder, GridOp grid, ShardOp source, auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, - cast>(source.getSrc()), - sourceShardValue); + source.getSrc(), sourceShardValue); } TypedValue reshard(OpBuilder &builder, ShardOp source,