26 changes: 13 additions & 13 deletions mlir/include/mlir/IR/FunctionInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// has less parameters we drop the extra attributes, if there are more
/// parameters they won't have any attributes.
void setType(Type newType) {
function_interface_impl::setFunctionType(this->getOperation(), newType);
function_interface_impl::setFunctionType($_op, newType);
}

//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -316,7 +316,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
Type newType = $_op.getTypeWithArgsAndResults(
argIndices, argTypes, /*resultIndices=*/{}, /*resultTypes=*/{});
function_interface_impl::insertFunctionArguments(
this->getOperation(), argIndices, argTypes, argAttrs, argLocs,
$_op, argIndices, argTypes, argAttrs, argLocs,
originalNumArgs, newType);
}

Expand All @@ -336,7 +336,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
Type newType = $_op.getTypeWithArgsAndResults(
/*argIndices=*/{}, /*argTypes=*/{}, resultIndices, resultTypes);
function_interface_impl::insertFunctionResults(
this->getOperation(), resultIndices, resultTypes, resultAttrs,
$_op, resultIndices, resultTypes, resultAttrs,
originalNumResults, newType);
}

Expand All @@ -351,7 +351,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void eraseArguments(const BitVector &argIndices) {
Type newType = $_op.getTypeWithoutArgs(argIndices);
function_interface_impl::eraseFunctionArguments(
this->getOperation(), argIndices, newType);
$_op, argIndices, newType);
}

/// Erase a single result at `resultIndex`.
Expand All @@ -365,7 +365,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void eraseResults(const BitVector &resultIndices) {
Type newType = $_op.getTypeWithoutResults(resultIndices);
function_interface_impl::eraseFunctionResults(
this->getOperation(), resultIndices, newType);
$_op, resultIndices, newType);
}

/// Return the type of this function with the specified arguments and
Expand Down Expand Up @@ -414,7 +414,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {

/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
return function_interface_impl::getArgAttrs(this->getOperation(), index);
return function_interface_impl::getArgAttrs($_op, index);
}

/// Return an ArrayAttr containing all argument attribute dictionaries of
Expand Down Expand Up @@ -464,11 +464,11 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
}
void setAllArgAttrs(ArrayRef<DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumArguments());
function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
void setAllArgAttrs(ArrayRef<Attribute> attributes) {
assert(attributes.size() == $_op.getNumArguments());
function_interface_impl::setAllArgAttrDicts(this->getOperation(), attributes);
function_interface_impl::setAllArgAttrDicts($_op, attributes);
}
void setAllArgAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumArguments());
Expand Down Expand Up @@ -503,7 +503,7 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {

/// Return all of the attributes for the result at 'index'.
ArrayRef<NamedAttribute> getResultAttrs(unsigned index) {
return function_interface_impl::getResultAttrs(this->getOperation(), index);
return function_interface_impl::getResultAttrs($_op, index);
}

/// Return an ArrayAttr containing all result attribute dictionaries of this
Expand Down Expand Up @@ -554,12 +554,12 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
void setAllResultAttrs(ArrayRef<DictionaryAttr> attributes) {
assert(attributes.size() == $_op.getNumResults());
function_interface_impl::setAllResultAttrDicts(
this->getOperation(), attributes);
$_op, attributes);
}
void setAllResultAttrs(ArrayRef<Attribute> attributes) {
assert(attributes.size() == $_op.getNumResults());
function_interface_impl::setAllResultAttrDicts(
this->getOperation(), attributes);
$_op, attributes);
}
void setAllResultAttrs(ArrayAttr attributes) {
assert(attributes.size() == $_op.getNumResults());
Expand Down Expand Up @@ -589,15 +589,15 @@ def FunctionOpInterface : OpInterface<"FunctionOpInterface", [Symbol]> {
/// attribute is returned.
DictionaryAttr getArgAttrDict(unsigned index) {
assert(index < $_op.getNumArguments() && "invalid argument number");
return function_interface_impl::getArgAttrDict(this->getOperation(), index);
return function_interface_impl::getArgAttrDict($_op, index);
}

/// Returns the dictionary attribute corresponding to the result at 'index'.
/// If there are no result attributes at 'index', a null attribute is
/// returned.
DictionaryAttr getResultAttrDict(unsigned index) {
assert(index < $_op.getNumResults() && "invalid result number");
return function_interface_impl::getResultAttrDict(this->getOperation(), index);
return function_interface_impl::getResultAttrDict($_op, index);
}
}];

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Support/InterfaceSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class Interface : public BaseType {
};

/// Construct an interface from an instance of the value type.
Interface(ValueT t = ValueT())
explicit Interface(ValueT t = ValueT())
: BaseType(t),
conceptImpl(t ? ConcreteType::getInterfaceFor(t) : nullptr) {
assert((!t || conceptImpl) &&
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,22 +296,22 @@ LogicalResult MulIExtendedOpLowering<ArithMulOp, IsSigned>::matchAndRewrite(
// on operands zero-extended to i(2*N) bits, and truncate the results back to
// iN types.
if (!resultType.isa<LLVM::LLVMArrayType>()) {
Type wideType;
// Shift amount necessary to extract the high bits from widened result.
Attribute shiftValAttr;
TypedAttr shiftValAttr;

if (auto intTy = resultType.dyn_cast<IntegerType>()) {
unsigned resultBitwidth = intTy.getWidth();
wideType = rewriter.getIntegerType(resultBitwidth * 2);
shiftValAttr = rewriter.getIntegerAttr(wideType, resultBitwidth);
auto attrTy = rewriter.getIntegerType(resultBitwidth * 2);
shiftValAttr = rewriter.getIntegerAttr(attrTy, resultBitwidth);
} else {
auto vecTy = resultType.cast<VectorType>();
unsigned resultBitwidth = vecTy.getElementTypeBitWidth();
wideType = VectorType::get(vecTy.getShape(),
rewriter.getIntegerType(resultBitwidth * 2));
auto attrTy = VectorType::get(
vecTy.getShape(), rewriter.getIntegerType(resultBitwidth * 2));
shiftValAttr = SplatElementsAttr::get(
wideType, APInt(resultBitwidth * 2, resultBitwidth));
attrTy, APInt(resultBitwidth * 2, resultBitwidth));
}
Type wideType = shiftValAttr.getType();
assert(LLVM::isCompatibleType(wideType) &&
"LLVM dialect should support all signless integer types");

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Type matchContainerType(Type element, Type container) {
return element;
}

Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
Type eTy = shapedTy.getElementType();
APInt valueInt(eTy.getIntOrFloatBitWidth(), value);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation,

// Returns the constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
return rewriter.getFloatAttr(elementTy, 0.0);
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
Attribute padAttr, OpBuilder &rewriter) {
TypedAttr padAttr, OpBuilder &rewriter) {
// Input should be padded if necessary.
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;
Expand Down Expand Up @@ -224,7 +224,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto weightShape = weightTy.getShape();

// Apply padding as necessary.
Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
int64_t iZp = quantizationInfo.getInputZp();
Expand Down Expand Up @@ -269,7 +269,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);

Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Expand Down Expand Up @@ -391,7 +391,7 @@ class DepthwiseConvConverter
auto resultShape = resultTy.getShape();

// Apply padding as necessary.
Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo =
op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
Expand Down Expand Up @@ -439,7 +439,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, linalgConvTy.getShape(), resultETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Expand Down Expand Up @@ -604,7 +604,7 @@ class FullyConnectedConverter
loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);

// When quantized, the input elemeny type is not the same as the output
Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
Expand Down Expand Up @@ -688,7 +688,7 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
SmallVector<Value> dynamicDims = *dynamicDimsOr;

// Determine what the initial value needs to be for the max pool op.
Attribute initialAttr;
TypedAttr initialAttr;
if (resultETy.isF32())
initialAttr = rewriter.getFloatAttr(
resultETy,
Expand Down Expand Up @@ -768,10 +768,10 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);
Attribute padAttr = rewriter.getZeroAttr(inElementTy);
TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

Attribute initialAttr = rewriter.getZeroAttr(accETy);
auto initialAttr = rewriter.getZeroAttr(accETy);
Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

ArrayRef<int64_t> kernel = op.getKernel();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class PadConverter : public OpRewritePattern<tosa::PadOp> {
padConstant = rewriter.createOrFold<tensor::ExtractOp>(
loc, padOp.getPadConst(), ValueRange({}));
} else {
Attribute constantAttr;
TypedAttr constantAttr;
if (elementTy.isa<FloatType>()) {
constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
} else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ void AffineDialect::initialize() {
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

/// A utility function to check if a value is defined at the top level of an
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1677,8 +1677,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType,
dynIdx++;
} else {
// Create ConstantOp for static dimension.
Attribute constantAttr =
b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]);
inAffineApply.emplace_back(
b.create<arith::ConstantOp>(allocOp->getLoc(), constantAttr));
}
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def AddIntAttrs : NativeCodeCall<"addIntegerAttrs($_builder, $0, $1, $2)">;
// Subtract two integer attributes and createa a new one with the result.
def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;

class cast<string type> : NativeCodeCall<"::mlir::cast<" # type # ">($0)">;

//===----------------------------------------------------------------------===//
// AddIOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -320,8 +322,8 @@ def TruncationMatchesShiftAmount :
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
(Arith_ShRSIOp $x, (ConstantLikeMatcher AnyAttr:$c0))),
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp $c0))),
(Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;

// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Arith/IR/ArithDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ void arith::ArithDialect::initialize() {
Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, value, type);
return ConstantOp::materialize(builder, value, type, loc);
}
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,13 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
}

ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
Type type, Location loc) {
if (isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
return nullptr;
}

OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }

void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
Expand Down Expand Up @@ -2306,7 +2313,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

/// Returns the identity value attribute associated with an AtomicRMWKind op.
Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
OpBuilder &builder, Location loc) {
switch (kind) {
case AtomicRMWKind::maxf:
Expand Down Expand Up @@ -2355,7 +2362,7 @@ Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
/// Returns the identity value associated with an AtomicRMWKind op.
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
OpBuilder &builder, Location loc) {
Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
auto attr = getIdentityValueAttr(op, resultType, builder, loc);
return builder.create<arith::ConstantOp>(loc, attr);
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static Type reduceInnermostDim(VectorType type) {
static Value createScalarOrSplatConstant(ConversionPatternRewriter &rewriter,
Location loc, Type type,
const APInt &value) {
Attribute attr;
TypedAttr attr;
if (auto intTy = type.dyn_cast<IntegerType>()) {
attr = rewriter.getIntegerAttr(type, value);
} else {
Expand Down Expand Up @@ -989,7 +989,7 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt);

int64_t pow2Int = int64_t(1) << newBitWidth;
Attribute pow2Attr =
TypedAttr pow2Attr =
rewriter.getFloatAttr(resultElemTy, static_cast<double>(pow2Int));
if (auto vecTy = dyn_cast<VectorType>(resultTy))
pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr);
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
}
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return nullptr;
return arith::ConstantOp::materialize(builder, value, type, loc);
}

#define GET_ATTRDEF_CLASSES
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ LogicalResult emitc::ConstantOp::verify() {
if (getValueAttr().isa<emitc::OpaqueAttr>())
return success();

TypedAttr value = getValueAttr();
auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
Expand Down Expand Up @@ -177,7 +177,7 @@ LogicalResult emitc::VariableOp::verify() {
if (getValueAttr().isa<emitc::OpaqueAttr>())
return success();

TypedAttr value = getValueAttr();
auto value = cast<TypedAttr>(getValueAttr());
Type type = getType();
if (!value.getType().isa<NoneType>() && type != value.getType())
return emitOpError() << "requires attribute's type (" << value.getType()
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
llvm::dbgs() << "\n");

// Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
auto comparator = [&](DeviceMappingAttrInterface a,
DeviceMappingAttrInterface b) -> bool {
return a.getMappingId() < b.getMappingId();
auto comparator = [&](Attribute a, Attribute b) -> bool {
return cast<DeviceMappingAttrInterface>(a).getMappingId() <
cast<DeviceMappingAttrInterface>(b).getMappingId();
};
SmallVector<int64_t> forallMappingSizes =
getValuesSortedByKey(forallMappingAttrs, tmpMappingSizes, comparator);
Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,10 +389,7 @@ class RegionBuilderHelper {
OpBuilder builder = getBuilder();
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
Type type = NoneType::get(builder.getContext());
if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
type = typedAttr.getType();
return builder.create<arith::ConstantOp>(loc, type, valueAttr);
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}

Value index(int64_t dim) {
Expand Down Expand Up @@ -2109,5 +2106,5 @@ void LinalgDialect::getCanonicalizationPatterns(
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,8 +1688,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
}

// Create a constant scalar value from the splat constant.
Value scalarConstant = rewriter.create<arith::ConstantOp>(
def->getLoc(), constantAttr, constantAttr.getType());
Value scalarConstant =
rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);

SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp = rewriter.create<GenericOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
staticStridesVector));
}

Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes);
LinalgOp clonedOp = clone(b, producer, resultTypes, clonedShapes);

// Shift all IndexOp results by the tile offset.
SmallVector<OpFoldResult> allIvs = llvm::to_vector(
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using namespace linalg;
static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
ArrayRef<int64_t> tiledLoopDims) {
// Get the consumer operand indexing map.
LinalgOp consumerOp = consumerOperand->getOwner();
auto consumerOp = cast<LinalgOp>(consumerOperand->getOwner());
AffineMap indexingMap = consumerOp.getMatchingIndexingMap(consumerOperand);

// Search the slice dimensions tiled by a tile loop dimension.
Expand All @@ -65,7 +65,7 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
static SmallVector<int64_t>
getTiledProducerLoops(OpResult producerResult,
ArrayRef<int64_t> tiledSliceDimIndices) {
LinalgOp producerOp = producerResult.getOwner();
auto producerOp = cast<LinalgOp>(producerResult.getOwner());

// Get the indexing map of the `producerOp` output operand that matches
// ´producerResult´.
Expand Down Expand Up @@ -137,7 +137,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
b.setInsertionPointAfter(sliceOp);

// Get the producer.
LinalgOp producerOp = producerResult.getOwner();
auto producerOp = cast<LinalgOp>(producerResult.getOwner());
Location loc = producerOp.getLoc();

// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
Expand Down Expand Up @@ -345,7 +345,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
return failure();

// Check `sliceOp` and `consumerOp` are in the same block.
LinalgOp consumerOp = consumerOpOperand->getOwner();
auto consumerOp = cast<LinalgOp>(consumerOpOperand->getOwner());
if (sliceOp->getBlock() != rootOp->getBlock() ||
consumerOp->getBlock() != rootOp->getBlock())
return failure();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
"expected the number of loops and induction variables to match");
// Replace the index operations in the body of the innermost loop op.
if (!loopOps.empty()) {
LoopLikeOpInterface loopOp = loopOps.back();
auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
return b.notifyMatchFailure(op, "Cannot match the reduction pattern");

Operation *reductionOp = combinerOps[0];
std::optional<Attribute> identity = getNeutralElement(reductionOp);
std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
if (!identity.has_value())
return b.notifyMatchFailure(op, "Unknown identity value for the reduction");

Expand Down Expand Up @@ -272,9 +272,9 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
return b.notifyMatchFailure(op, "cannot match a reduction pattern");

SmallVector<Attribute> neutralElements;
SmallVector<TypedAttr> neutralElements;
for (Operation *reductionOp : combinerOps) {
std::optional<Attribute> neutralElement = getNeutralElement(reductionOp);
std::optional<TypedAttr> neutralElement = getNeutralElement(reductionOp);
if (!neutralElement.has_value())
return b.notifyMatchFailure(op, "cannot find neutral element.");
neutralElements.push_back(*neutralElement);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ struct LinalgOpPartialReductionInterface
return op->emitOpError("Failed to anaysis the reduction operation.");

Operation *reductionOp = combinerOps[0];
std::optional<Attribute> identity = getNeutralElement(reductionOp);
std::optional<TypedAttr> identity = getNeutralElement(reductionOp);
if (!identity.has_value())
return op->emitOpError(
"Failed to get an identity value for the reduction operation.");
Expand Down Expand Up @@ -328,8 +328,8 @@ struct LinalgOpPartialReductionInterface

// Step 1: Extract a slice of the input operands.
SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
SmallVector<Value, 4> tiledOperands =
makeTiledShapes(b, loc, op, valuesToTile, offsets, sizes, {}, true);
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);

// Step 2: Extract the accumulator operands
SmallVector<OpFoldResult> strides(offsets.size(), b.getIndexAttr(1));
Expand Down
7 changes: 2 additions & 5 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,8 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
}
Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
Type paddingType = rewriter.getType<NoneType>();
if (auto typedAttr = paddingAttr.dyn_cast<TypedAttr>())
paddingType = typedAttr.getType();
Value paddingValue = rewriter.create<arith::ConstantOp>(
opToPad.getLoc(), paddingType, paddingAttr);
opToPad.getLoc(), cast<TypedAttr>(paddingAttr));

// Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
OpOperand *currOpOperand = opOperand;
Expand Down Expand Up @@ -576,7 +573,7 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
rewriter, loc, operand, innerPackSizes, innerPos,
/*outerDimsPerm=*/{});
// TODO: value of the padding attribute should be determined by consumers.
Attribute zeroAttr =
auto zeroAttr =
rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
packOps.push_back(rewriter.create<tensor::PackOp>(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
}

/// Return the identity numeric value associated to the give op.
std::optional<Attribute> getNeutralElement(Operation *op) {
std::optional<TypedAttr> getNeutralElement(Operation *op) {
// Builder only used as helper for attribute creation.
OpBuilder b(op->getContext());
Type resultType = op->getResult(0).getType();
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Math/IR/MathOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,5 +522,5 @@ OpFoldResult math::TruncOp::fold(FoldAdaptor adaptor) {
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, value, type);
return arith::ConstantOp::materialize(builder, value, type, loc);
}
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ CbrtApproximation::matchAndRewrite(math::CbrtOp op,
floatTy = broadcast(floatTy, shape);
intTy = broadcast(intTy, shape);

auto bconst = [&](Attribute attr) -> Value {
auto bconst = [&](TypedAttr attr) -> Value {
Value value = b.create<arith::ConstantOp>(attr);
return broadcast(b, value, shape);
};
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ struct Wrapper {
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, value, type);
return nullptr;
return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
size =
rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
sizes[i] = sizeAttr;
}
strides[i] = stride;
if (i > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Attribute getScalarOrSplatAttr(Type type, int64_t value) {
if (auto intTy = type.dyn_cast<IntegerType>())
return IntegerAttr::get(intTy, sizedValue);

return SplatElementsAttr::get(type, sizedValue);
return SplatElementsAttr::get(cast<ShapedType>(type), sizedValue);
}

Value lowerExtendedMultiplication(Operation *mulOp, PatternRewriter &rewriter,
Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,7 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
if (type.isa<WitnessType>())
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return nullptr;
return arith::ConstantOp::materialize(builder, value, type, loc);
}

LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}

mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
if (tp.isa<FloatType>())
return builder.getFloatAttr(tp, 1.0);
if (tp.isa<IndexType>())
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
/// returning a null attribute.
Attribute getOneAttr(Builder &builder, Type tp);
TypedAttr getOneAttr(Builder &builder, Type tp);

/// Generates the comparison `v != 0` where `v` is of numeric type.
/// For floating types, we use the "unordered" comparator (i.e., returns
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ using namespace mlir::tensor;
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, value, type);
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
return op;
if (complex::ConstantOp::isBuildableWith(value, type))
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ void VectorDialect::initialize() {
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
Expand Down Expand Up @@ -1729,7 +1729,7 @@ class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
if (!splat)
return failure();
Attribute newAttr = splat.getSplatValue<Attribute>();
TypedAttr newAttr = splat.getSplatValue<TypedAttr>();
if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
Expand Down Expand Up @@ -1767,9 +1767,9 @@ class ExtractOpNonSplatConstantFolder final
copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
int64_t elemBeginPosition =
linearize(completePositions, computeStrides(vecTy.getShape()));
auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
auto denseValuesBegin = dense.value_begin<TypedAttr>() + elemBeginPosition;

Attribute newAttr;
TypedAttr newAttr;
if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
SmallVector<Attribute> elementValues(
denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const final {
MaskableOpInterface maskableOp = maskOp.getMaskableOp();
auto maskableOp = cast<MaskableOpInterface>(maskOp.getMaskableOp());
SourceOp sourceOp = dyn_cast<SourceOp>(maskableOp.getOperation());
if (!sourceOp)
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,8 @@ struct WarpOpConstant : public OpRewritePattern<WarpExecuteOnLane0Op> {
return failure();
unsigned operandIndex = yieldOperand->getOperandNumber();
Attribute scalarAttr = dense.getSplatValue<Attribute>();
Attribute newAttr = DenseElementsAttr::get(
warpOp.getResult(operandIndex).getType(), scalarAttr);
auto newAttr = DenseElementsAttr::get(
cast<ShapedType>(warpOp.getResult(operandIndex).getType()), scalarAttr);
Location loc = warpOp.getLoc();
rewriter.setInsertionPointAfter(warpOp);
Value distConstant = rewriter.create<arith::ConstantOp>(loc, newAttr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct MaskCompressOpConversion
src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
op.getConstantSrcAttr());
} else {
Attribute zeroAttr = rewriter.getZeroAttr(opType);
auto zeroAttr = rewriter.getZeroAttr(opType);
src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
return getArrayAttr(attrs);
}

Attribute Builder::getZeroAttr(Type type) {
TypedAttr Builder::getZeroAttr(Type type) {
if (type.isa<FloatType>())
return getFloatAttr(type, 0.0);
if (type.isa<IndexType>())
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/IR/BuiltinAttributeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//

Type ElementsAttr::getElementType(ElementsAttr elementsAttr) {
return elementsAttr.getType().getElementType();
return elementsAttr.getShapedType().getElementType();
}

int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) {
return elementsAttr.getType().getNumElements();
return elementsAttr.getShapedType().getNumElements();
}

bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
Expand All @@ -49,7 +49,7 @@ bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
}
bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr,
ArrayRef<uint64_t> index) {
return isValidIndex(elementsAttr.getType(), index);
return isValidIndex(elementsAttr.getShapedType(), index);
}

uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
elementType.getContext());

// Wrap AffineMap into Attribute.
Attribute layout = AffineMapAttr::get(map);
auto layout = AffineMapAttr::get(map);

// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
Expand All @@ -559,7 +559,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
elementType.getContext());

// Wrap AffineMap into Attribute.
Attribute layout = AffineMapAttr::get(map);
auto layout = AffineMapAttr::get(map);

// Drop default memory space value and replace it with empty attribute.
memorySpace = skipDefaultMemorySpace(memorySpace);
Expand All @@ -577,7 +577,7 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
elementType.getContext());

// Wrap AffineMap into Attribute.
Attribute layout = AffineMapAttr::get(map);
auto layout = AffineMapAttr::get(map);

// Convert deprecated integer-like memory space to Attribute.
Attribute memorySpace =
Expand All @@ -598,7 +598,7 @@ MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
elementType.getContext());

// Wrap AffineMap into Attribute.
Attribute layout = AffineMapAttr::get(map);
auto layout = AffineMapAttr::get(map);

// Convert deprecated integer-like memory space to Attribute.
Attribute memorySpace =
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(

// Fall back to element-by-element construction otherwise.
if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
assert(elementsAttr.getType().hasStaticShape());
assert(!elementsAttr.getType().getShape().empty() &&
assert(elementsAttr.getShapedType().hasStaticShape());
assert(!elementsAttr.getShapedType().getShape().empty() &&
"unexpected empty elements attribute shape");

SmallVector<llvm::Constant *, 8> constants;
Expand All @@ -422,7 +422,7 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
}
ArrayRef<llvm::Constant *> constantsRef = constants;
llvm::Constant *result = buildSequentialConstant(
constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc);
assert(constantsRef.empty() && "did not consume all elemental constants");
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {

spirv::SpecConstantOp
spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
Attribute defaultValue) {
TypedAttr defaultValue) {
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
defaultValue);
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class Deserializer {

/// Creates a spirv::SpecConstantOp.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
Attribute defaultValue);
TypedAttr defaultValue);

/// Processes the OpVariable instructions at current `offset` into `binary`.
/// It is expected that this method is used for variables that are to be
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Transforms/Utils/InliningUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, IRMapping &mapper,
ValueRange resultsToReplace, TypeRange regionResultTypes,
std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, Operation *call = nullptr) {
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
assert(resultsToReplace.size() == regionResultTypes.size());
// We expect the region to have at least one block.
if (src->empty())
Expand Down Expand Up @@ -328,7 +328,7 @@ static LogicalResult
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
Block::iterator inlinePoint, ValueRange inlinedOperands,
ValueRange resultsToReplace, std::optional<Location> inlineLoc,
bool shouldCloneInlinedRegion, Operation *call = nullptr) {
bool shouldCloneInlinedRegion, CallOpInterface call = {}) {
// We expect the region to have at least one block.
if (src->empty())
return failure();
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Transforms/ViewOpGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ class PrintOpPass : public impl::ViewOpGraphBase<PrintOpPass> {
// Elide "big" elements attributes.
auto elements = attr.dyn_cast<ElementsAttr>();
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getType().getRank(), '[') << "..."
<< std::string(elements.getType().getRank(), ']') << " : "
os << std::string(elements.getShapedType().getRank(), '[') << "..."
<< std::string(elements.getShapedType().getRank(), ']') << " : "
<< elements.getType();
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
auto *originalOp = info->originalProducer.getOperation();
auto *originalOpInLinalgOpsVector =
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
*originalOpInLinalgOpsVector = info->fusedProducer;
// Don't mark for erasure in the tensor case, let DCE handle this.
changed = true;
}
Expand Down
8 changes: 2 additions & 6 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
}

// Test support for ElementsAttrInterface.
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
ElementsAttrInterface, TypedAttrInterface
]> {
def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> {
let mnemonic = "i64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
Expand Down Expand Up @@ -269,9 +267,7 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
}

// Test simple extern 1D vector using ElementsAttrInterface.
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
ElementsAttrInterface, TypedAttrInterface
]> {
def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ElementsAttrInterface]> {
let mnemonic = "e1di64_elements";
let parameters = (ins
AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
Expand Down