mlir/include/mlir/IR/Value.h (10 additions, 0 deletions)
@@ -433,9 +433,19 @@ inline unsigned OpResultImpl::getResultNumber() const {
 template <typename Ty>
 struct TypedValue : Value {
   using Value::Value;
+  using ValueType = Ty;
 
   static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); }
 
+  /// TypedValue<B> can implicitly convert to TypedValue<A> if B is assignable
+  /// to A.
+  template <typename ToTy,
+            typename = typename std::enable_if<std::is_assignable<
+                typename ToTy::ValueType &, Ty>::value>::type>
+  operator ToTy() const {
+    return llvm::cast<ToTy>(*this);
+  }
+
   /// Return the known Type
   Ty getType() const { return llvm::cast<Ty>(Value::getType()); }
   void setType(Ty ty) { Value::setType(ty); }
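Note (illustration, not part of the patch): the new SFINAE-gated conversion operator lets a TypedValue of a more specific type be passed where a TypedValue of a compatible, more general type is expected, without spelling out the cast. A minimal sketch of the intended usage, assuming the builtin MemRefType/BaseMemRefType hierarchy and a hypothetical consumeMemRef helper:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper, for illustration only.
void consumeMemRef(TypedValue<BaseMemRefType> operand);

void example(TypedValue<MemRefType> ranked) {
  // Before this patch the call needed an explicit cast:
  //   consumeMemRef(cast<TypedValue<BaseMemRefType>>(ranked));
  // With the new operator the conversion is implicit, because MemRefType is
  // assignable to its base class BaseMemRefType.
  consumeMemRef(ranked);
}

The conversion still goes through llvm::cast, so a value whose actual type does not satisfy the target's classof would assert at runtime, just as an explicit cast would.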
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (2 additions, 2 deletions)
@@ -1744,11 +1744,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+  return getSource();
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+  return getDest();
 }
 
 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
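The explicit casts in getSourcePtr()/getTargetPtr() can be dropped because the conversion operator participates whenever the statically known source type is assignable to the destination's ValueType, presumably because the op's memref operand and result types implement PtrLikeTypeInterface. A rough compile-time sketch of that gating, using builtin types for illustration (assumptions, not code from the patch):

#include <type_traits>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// The conversion is SFINAE-gated on std::is_assignable, so it only exists for
// compatible type pairs.
static_assert(std::is_convertible_v<TypedValue<MemRefType>,
                                    TypedValue<BaseMemRefType>>,
              "a derived type converts to its base implicitly");
static_assert(!std::is_convertible_v<TypedValue<MemRefType>,
                                     TypedValue<IntegerType>>,
              "unrelated types still need an explicit cast");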
mlir/lib/Dialect/Shard/Transforms/Partition.cpp (7 additions, 9 deletions)
@@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                           Sharding sourceSharding,
                           TypedValue<ShapedType> sourceShard, GridOp grid,
                           int64_t splitTensorAxis, GridAxis splitGridAxis) {
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
+  TypedValue<ShapedType> targetShard =
       AllSliceOp::create(builder, sourceShard, grid,
                          ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
-          .getResult());
+          .getResult();
   Sharding targetSharding = targetShardingInSplitLastAxis(
       builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
   return {targetShard, targetSharding};
@@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
       APInt(64, splitTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      tensor::CastOp::create(builder, targetShape, allGatherResult)
-          .getResult());
+  TypedValue<ShapedType> targetShard =
+      tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
   return {targetShard, targetSharding};
 }

@@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
       APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
   ShapedType targetShape =
       shardShapedType(sourceUnshardedShape, grid, targetSharding);
-  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
-      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
+  TypedValue<ShapedType> targetShard =
+      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
   return {targetShard, targetSharding};
 }

@@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
   auto targetSharding = target.getSharding();
   ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
   return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
-                 cast<TypedValue<ShapedType>>(source.getSrc()),
-                 sourceShardValue);
+                 source.getSrc(), sourceShardValue);
 }

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,