[mlir][NFC] Remove a few op builders that simply swap parameter order
Differential Revision: https://reviews.llvm.org/D119093
River707 committed Feb 8, 2022
1 parent d7f0083 commit 3c69bc4
Showing 12 changed files with 39 additions and 50 deletions.
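The change is mechanical and identical across all twelve files: the deleted convenience builders accepted the operand before the result type, so every call site now uses the ODS-generated builder, which takes the result type first. A minimal sketch of the resulting call-site convention (the helper function below is illustrative only and not part of the patch):

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative helper: cast an integer value to index using the generated
// arith.index_cast builder (result type first, then the operand).
static Value castToIndex(OpBuilder &b, Location loc, Value v) {
  // Before this commit, a removed convenience builder also accepted
  //   b.create<arith::IndexCastOp>(loc, v, b.getIndexType());  // operand first
  // With it gone, only the result-type-first form remains:
  return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
}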
6 changes: 0 additions & 6 deletions mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -73,12 +73,6 @@ class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-      impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];

let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
}
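For reference, the removed `builders` entry above generated an extra C++ overload of the form `build(OpBuilder &, OperationState &, Value source, Type destType)` that simply forwarded to `impl::buildCastOp`. A rough, hypothetical stand-in for what such a helper records on the operation state (the real helper lives elsewhere in MLIR and is not shown in this diff):

#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

// Hypothetical stand-in for impl::buildCastOp: register the single operand
// and the single result type, mirroring the removed builder's body.
static void buildCastOpLike(mlir::OpBuilder &builder,
                            mlir::OperationState &state,
                            mlir::Value source, mlir::Type destType) {
  (void)builder; // Unused here; kept to mirror the builder signature.
  state.addOperands(source);
  state.addTypes(destType);
}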

5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -374,11 +374,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let builders = [
-    OpBuilder<(ins "Value":$source, "Type":$destType), [{
-      impl::buildCastOp($_builder, $_state, source, destType);
-    }]>
-  ];

let extraClassDeclaration = [{
/// Fold the given CastOp into consumer op.
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1003,11 +1003,11 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
- loc, value, IntegerType::get(rewriter.getContext(), 64));
+ loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::None:
break;
@@ -94,8 +94,8 @@ struct IndexCastOpInterface
getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
layout, sourceType.getMemorySpace());

- replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source,
-   resultType);
+ replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
+   source);
return success();
}
};
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -835,15 +835,15 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
scalingFactor);
}
Value numWorkersIndex =
- b.create<arith::IndexCastOp>(numWorkerThreadsVal, b.getI32Type());
+ b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal);
Value numWorkersFloat =
- b.create<arith::SIToFPOp>(numWorkersIndex, b.getF32Type());
+ b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex);
Value scaledNumWorkers =
b.create<arith::MulFOp>(scalingFactor, numWorkersFloat);
Value scaledNumInt =
- b.create<arith::FPToSIOp>(scaledNumWorkers, b.getI32Type());
+ b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers);
Value scaledWorkers =
- b.create<arith::IndexCastOp>(scaledNumInt, b.getIndexType());
+ b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt);

Value maxComputeBlocks = b.create<arith::MaxSIOp>(
b.create<arith::ConstantIndexOp>(1), scaledWorkers);
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -887,7 +887,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
auto i32Vec = broadcast(builder.getI32Type(), shape);

// exp2(k)
- Value k = builder.create<arith::FPToSIOp>(kF32, i32Vec);
+ Value k = builder.create<arith::FPToSIOp>(i32Vec, kF32);
Value exp2KValue = exp2I32(builder, k);

// exp(x) = exp(y) * exp2(k)
@@ -1042,7 +1042,7 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(

auto i32Vec = broadcast(builder.getI32Type(), shape);
auto fPToSingedInteger = [&](Value a) -> Value {
- return builder.create<arith::FPToSIOp>(a, i32Vec);
+ return builder.create<arith::FPToSIOp>(i32Vec, a);
};

auto modulo4 = [&](Value a) -> Value {
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -165,7 +165,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
alloc.alignmentAttr());
// Insert a cast so we have the same type as the old alloc.
auto resultCast =
- rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
+ rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);

rewriter.replaceOp(alloc, {resultCast});
return success();
@@ -2156,8 +2156,8 @@ class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
rewriter.replaceOp(subViewOp, subViewOp.source());
return success();
}
- rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.source(),
-   subViewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
+   subViewOp.source());
return success();
}
};
@@ -2177,7 +2177,7 @@ struct SubViewReturnTypeCanonicalizer {
/// A canonicalizer wrapper to replace SubViewOps.
struct SubViewCanonicalizer {
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
- rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
}
};

@@ -2422,7 +2422,7 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
viewOp.getOperand(0),
viewOp.byte_shift(), newOperands);
// Insert a cast so we have the same type as the old memref type.
- rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
+ rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
return success();
}
};
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp
@@ -101,8 +101,8 @@ struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.shape(), index);
if (!size.getType().isa<IndexType>())
- size = rewriter.create<arith::IndexCastOp>(loc, size,
-   rewriter.getIndexType());
+ size = rewriter.create<arith::IndexCastOp>(
+   loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
@@ -309,7 +309,7 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
Value val = rewriter.create<tensor::ExtractOp>(loc, indices,
ValueRange{ivs[0], idx});
val =
- rewriter.create<arith::IndexCastOp>(loc, val, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
rewriter.create<memref::StoreOp>(loc, val, ind, idx);
}
return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -831,11 +831,11 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
if (!etp.isa<IndexType>()) {
if (etp.getIntOrFloatBitWidth() < 32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI32Type()));
+ loc, vectorType(codegen, rewriter.getI32Type()), vload);
else if (etp.getIntOrFloatBitWidth() < 64 &&
!codegen.options.enableSIMDIndex32)
vload = rewriter.create<arith::ExtUIOp>(
- loc, vload, vectorType(codegen, rewriter.getI64Type()));
+ loc, vectorType(codegen, rewriter.getI64Type()), vload);
}
return vload;
}
@@ -846,9 +846,9 @@ static Value genLoad(CodeGen &codegen, PatternRewriter &rewriter, Location loc,
Value load = rewriter.create<memref::LoadOp>(loc, ptr, s);
if (!load.getType().isa<IndexType>()) {
if (load.getType().getIntOrFloatBitWidth() < 64)
- load = rewriter.create<arith::ExtUIOp>(loc, load, rewriter.getI64Type());
+ load = rewriter.create<arith::ExtUIOp>(loc, rewriter.getI64Type(), load);
load =
- rewriter.create<arith::IndexCastOp>(loc, load, rewriter.getIndexType());
+ rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), load);
}
return load;
}
@@ -868,7 +868,7 @@ static Value genAddress(CodeGen &codegen, PatternRewriter &rewriter,
Value mul = rewriter.create<arith::MulIOp>(loc, size, p);
if (auto vtp = i.getType().dyn_cast<VectorType>()) {
Value inv =
- rewriter.create<arith::IndexCastOp>(loc, mul, vtp.getElementType());
+ rewriter.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
mul = genVectorInvariantValue(codegen, rewriter, inv);
}
return rewriter.create<arith::AddIOp>(loc, mul, i);
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -671,25 +671,25 @@ Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
rewriter.getZeroAttr(v0.getType())),
v0);
case kTruncF:
- return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
case kExtF:
- return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
case kCastFS:
- return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
case kCastFU:
- return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
case kCastSF:
- return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
case kCastUF:
- return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
case kCastS:
- return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
case kCastU:
- return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
case kTruncI:
- return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kBitCast:
- return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
+ return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
case kMulF:
return rewriter.create<arith::MulFOp>(loc, v0, v1);
@@ -255,7 +255,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
@@ -271,7 +271,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
alloc);
b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -309,7 +309,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(),
xferOp.indices().end());
@@ -324,7 +324,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
loc, MemRefType::get({}, vector.getType()), alloc));

Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
zero);
@@ -360,7 +360,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
[&](OpBuilder &b, Location loc) {
Value res = memref;
if (compatibleMemRefType != xferOp.getShapedType())
- res = b.create<memref::CastOp>(loc, memref, compatibleMemRefType);
+ res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
scf::ValueVector viewAndIndices{res};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.indices().begin(),
@@ -369,7 +369,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
},
[&](OpBuilder &b, Location loc) {
Value casted =
- b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
+ b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
scf::ValueVector viewAndIndices{casted};
viewAndIndices.insert(viewAndIndices.end(),
xferOp.getTransferRank(), zero);
