[Torch] support decomposition of aten.aminmax (#3513)
* unify decomposition of `aten.amax` and `aten.amin`
* support `aten.amax` with `dim=()`
qingyunqu committed Jun 29, 2024
1 parent f9fc741 commit 0e71a19
Showing 9 changed files with 254 additions and 106 deletions.
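
For reference, the semantics this commit targets, sketched at the PyTorch level (values illustrative; the empty-dim note reflects this patch's convention, since eager PyTorch's handling of dim=() has varied across versions):

import torch

x = torch.tensor([[1., 5., 2.],
                  [4., 0., 3.]])

# aten.aminmax bundles a min-reduction and a max-reduction into one op.
mn, mx = torch.aminmax(x, dim=1)   # mn = [1., 0.], mx = [5., 4.]
mn_all, mx_all = torch.aminmax(x)  # dim=None: 0-d results 0. and 5.

# Per the second bullet, aten.amax with an empty dim list, dim=(), is now
# decomposed as a reduction over all dimensions.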
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [
}];
}

def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchOptionalTensorType:$min,
AnyTorchOptionalTensorType:$max
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenAminmaxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
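
In torch-mlir, GeneratedTorchOps.td entries like the one above are emitted by the ODS generator rather than written by hand. A plausible one-line addition to the generator's Python (exact file and surrounding code assumed, not shown in this diff):

# In build_tools/torch_ods_gen.py (assumed path): registering the JIT
# signature regenerates the Torch_AtenAminmaxOp block above.
emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)")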

def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
AllowsTypeRefinement,
ReadOnly
16 changes: 10 additions & 6 deletions lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -488,14 +488,18 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "non-const integer `dim` is not supported");
}
-    for (auto d : inputDims) {
-      d = toPositiveDim(d, inputTy.getRank());
-      // Drop invalid dims
-      if (isValidDim(d, inputTy.getRank())) {
-        dims.push_back(d);
+    if (inputDims.size() == 0) {
+      dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank()));
+    } else {
+      for (auto d : inputDims) {
+        d = toPositiveDim(d, inputTy.getRank());
+        // Drop invalid dims
+        if (isValidDim(d, inputTy.getRank())) {
+          dims.push_back(d);
+        }
       }
+      llvm::sort(dims.begin(), dims.end());
     }
-    llvm::sort(dims.begin(), dims.end());
SmallVector<int64_t> reduceResultShape =
getReduceOutputShape(inputTy.getShape(), dims);
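
A hypothetical Python mirror of the new dim handling, to make the empty-list convention explicit:

def normalize_reduce_dims(input_dims, rank):
    # An empty dim list means "reduce over every dimension".
    if not input_dims:
        return list(range(rank))
    dims = []
    for d in input_dims:
        d = d + rank if d < 0 else d  # toPositiveDim
        if 0 <= d < rank:             # drop invalid dims, as the C++ does
            dims.append(d)
    return sorted(dims)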

21 changes: 21 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.tuple<list<int>, list<int>>) {\n"
" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %3 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" } else {\n"
" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" torch.prim.If.yield %4 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
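
The aminmax shape function above is generated IR; a plausible reconstruction of the Python it is generated from (function and helper names assumed, following torch-mlir's abstract_interp_lib_gen.py conventions, where upstream_shape_functions aliases torch.jit._shape_functions):

from typing import List, Optional, Tuple

def aten_aminmax_shape(self: List[int], dim: Optional[int] = None,
                       keepdim: bool = False) -> Tuple[List[int], List[int]]:
    if dim is None:
        return [], []  # full reduction: both results are 0-d
    # argmax's shape helper computes the reduced shape for a single dim;
    # min and max share it.
    reduced = upstream_shape_functions.argmax(self, dim, keepdim)
    return reduced, reduced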
@@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.bool) -> !torch.tuple<int, int> {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
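
The matching dtype function is simpler; a plausible Python source (same caveats as above):

def aten_aminmax_dtype(self_rank_dtype: Tuple[int, int],
                       dim: Optional[int] = None,
                       keepdim: bool = False) -> Tuple[int, int]:
    _, self_dtype = self_rank_dtype
    return self_dtype, self_dtype  # both results keep the input dtype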
208 changes: 109 additions & 99 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
.getValues();
}

// Reduction function to calculate min along given `dim`.
static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType = cast<BaseTensorType>(computeReductionType(
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim));
if (!valueType)
return nullptr;
BaseTensorType indexType =
cast<BaseTensorType>(valueType.getWithSizesAndDtype(
!valueType.hasSizes() ? std::optional<ArrayRef<int64_t>>()
: llvm::ArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed)));
return rewriter
.create<AtenMinDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.getValues();
}
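
At the torch level the helper amounts to keeping only the values result of aten.min.dim; roughly, with x, d, and keep standing in for the op's operands (illustrative Python):

values, _indices = torch.min(x, dim=d, keepdim=keep)  # aten.min.dim
# createMinAlongDimension returns only `values`; the index tensor that
# aten.min.dim also produces is discarded.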

// Helper for creating `aten::sub_tensor_op`.
static Value createTensorSub(PatternRewriter &rewriter, Location loc,
Type tensorType, Value lhs, Value rhs) {
@@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
return out;
}

-namespace {
-/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
-/// number of dimensions across which the max needs to be computed.
-/// Eg:
-///   INPUT:
-///      final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
-///
-///   OUTPUT:
-///      input_1 = aten.max.dim(initial_input, 2, keepdim)  #1
-///      input_2 = aten.max.dim(input_1, 1, keepdim)   #2
-///      final_output = aten.max.dim(input_2, 0, keepdim) #3
-///
-/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
-///       of the `aten.amax` op and create an `aten.amax.dim` op.
-///       Input tensor to the next `aten.amax.dim` op is thus the output of the
-///       previous `aten.amax.dim` op.
-class DecomposeAtenAmaxOp : public OpRewritePattern<AtenAmaxOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAmaxOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    SmallVector<int64_t, 4> dims;
-    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
-      return rewriter.notifyMatchFailure(op,
-                                         "non-const dim parameter unsupported");
-
-    bool keepDim;
-    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
-      return rewriter.notifyMatchFailure(
-          op, "Expected a constant boolean value for keepDim");
-
-    Value input = op.getSelf();
-    auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
-    if (!inputTy || !inputTy.hasSizes()) {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected input type having sizes");
-    }
-    // For every dimension included in `dim` of the op, iterated over in
-    // reverse order, we create a call to aten.max.dim.
-    std::sort(dims.rbegin(), dims.rend());
-    for (int64_t dimInt : dims) {
-      int64_t inputRank = inputTy.getSizes().size();
-      dimInt = toPositiveDim(dimInt, inputRank);
-      if (!isValidDim(dimInt, inputRank))
-        return rewriter.notifyMatchFailure(op, "dim is statically invalid");
-      Value dim = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dimInt));
-      // The input to the next invocation of aten.max.dim is the output of the
-      // previous aten.max.dim op.
-      input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
-    }
-    rewriter.replaceOp(op, input);
-    return success();
-  }
-};
-} // end namespace

namespace {
class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
public:
@@ -1880,52 +1840,69 @@ class DecomposeAten_LogSoftmaxBackwardDataOp
} // namespace

namespace {
-class DecomposeAtenAMinMaxOp : public OpRewritePattern<Torch::AtenAminOp> {
+/// We decompose aten.amax (and aten.amin) into a set of aten.max.dim (or
+/// aten.min.dim) op(s) depending on the number of dimensions across which
+/// the max (or min) needs to be computed.
+/// Eg:
+///   INPUT:
+///      final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False)
+///
+///   OUTPUT:
+///      input_1 = aten.max.dim(initial_input, 2, keepdim)  #1
+///      input_2 = aten.max.dim(input_1, 1, keepdim)   #2
+///      final_output = aten.max.dim(input_2, 0, keepdim) #3
+///
+/// NOTE: We iterate over, in reverse order, every dimension included in `dim`
+///       of the op and create an `aten.max.dim` (or `aten.min.dim`) op.
+///       Input tensor to the next such op is thus the output of the
+///       previous one.
+template <typename OpTy, typename DecompOpTy>
+class DecomposeAtenAminAmaxOp : public OpRewritePattern<OpTy> {
 public:
-  using OpRewritePattern<Torch::AtenAminOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(Torch::AtenAminOp op,
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
                                 PatternRewriter &rewriter) const override {
-    llvm::SmallVector<int64_t> dimList;
-    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
-      return rewriter.notifyMatchFailure(op, "dims not foldable constants");
+    Location loc = op.getLoc();
+    bool keepDim;
+    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a constant boolean value for keepDim");
+
+    Value input = op.getSelf();
+    auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected input type having sizes");
     }

-    bool keepdim;
-    if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) {
-      return rewriter.notifyMatchFailure(op, "keepdims not foldable constants");
+    SmallVector<int64_t, 4> dims;
+    if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims)))
+      return rewriter.notifyMatchFailure(op,
+                                         "non-const dim parameter unsupported");
+    if (dims.size() == 0) {
+      dims = llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getSizes().size()));
     }

-    auto loc = op.getLoc();
-    std::sort(dimList.begin(), dimList.end(), std::greater<int64_t>());
-
-    Value reduction = op.getSelf();
-    auto resultTy = cast<Torch::ValueTensorType>(op.getType());
-    auto reductionTy = cast<Torch::ValueTensorType>(reduction.getType());
-    llvm::SmallVector<int64_t> reductionShape(reductionTy.getSizes());
-
-    for (auto dim : dimList) {
-      auto dimValue = rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(dim));
-      reductionShape[dim] = 1;
-      if (!keepdim) {
-        for (int i = dim, s = reductionShape.size() - 1; i < s; ++i)
-          reductionShape[i] = reductionShape[i + 1];
-        reductionShape.resize(reductionShape.size() - 1);
+    // For every dimension included in `dim` of the op, iterated over in
+    // reverse order, we create a call to aten.max.dim (or aten.min.dim).
+    std::sort(dims.rbegin(), dims.rend());
+    for (int64_t dimInt : dims) {
+      int64_t inputRank = inputTy.getSizes().size();
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank))
+        return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(dimInt));
+      // The input to the next invocation of aten.max.dim is the output of the
+      // previous aten.max.dim op.
+      static_assert(std::is_same_v<OpTy, AtenAmaxOp> ||
+                    std::is_same_v<OpTy, AtenAminOp>);
+      if (std::is_same_v<OpTy, AtenAmaxOp>) {
+        input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim);
+      } else if (std::is_same_v<OpTy, AtenAminOp>) {
+        input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim);
      }

-      reductionTy = rewriter.getType<Torch::ValueTensorType>(
-          reductionShape, resultTy.getOptionalDtype());
-      auto idxTy = rewriter.getType<Torch::ValueTensorType>(
-          reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true));
-      llvm::SmallVector<Type, 2> types{reductionTy, idxTy};
-
-      reduction = rewriter
-                      .create<Torch::AtenMinDimOp>(loc, types, reduction,
-                                                   dimValue, op.getKeepdim())
-                      .getResult(0);
    }

-    rewriter.replaceOp(op, reduction);
+    rewriter.replaceOp(op, input);
return success();
}
};
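
A hypothetical eager-mode mirror of this pattern shows why the dims are visited in descending order: with keepdim=False each reduction removes an axis, so reducing larger dims first keeps the remaining dim indices valid.

import torch

def decompose_amin_amax(x, dims, keepdim, is_max):
    rank = x.dim()
    dims = list(range(rank)) if not dims else [d % rank for d in dims]
    reduce_dim = torch.max if is_max else torch.min  # aten.max.dim / aten.min.dim
    for d in sorted(dims, reverse=True):
        x = reduce_dim(x, dim=d, keepdim=keepdim).values
    return x

x = torch.randn(2, 3, 4)
assert torch.equal(decompose_amin_amax(x, [0, 2], False, True),
                   torch.amax(x, dim=(0, 2)))

Note that the C++ dispatch uses a plain `if` on std::is_same_v rather than `if constexpr`; that works here because both helper calls compile for either op type.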
@@ -1987,6 +1964,36 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
};
} // namespace

// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp`
namespace {
class DecomposeAtenAminmaxOp : public OpRewritePattern<AtenAminmaxOp> {
public:
using OpRewritePattern<AtenAminmaxOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAminmaxOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Torch::ListType listType =
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>());
Value dimList;
if (isa<Torch::NoneType>(op.getDim().getType())) {
dimList = rewriter.create<Torch::PrimListConstructOp>(loc, listType,
ArrayRef<Value>{});
} else {
dimList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, ArrayRef<Value>{op.getDim()});
}

auto amin = rewriter.create<AtenAminOp>(
loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim());
auto amax = rewriter.create<AtenAmaxOp>(
loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim());
rewriter.replaceOp(op, {amin, amax});
return success();
}
};
} // namespace
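
An eager-mode sketch of this rewrite (hypothetical helper; a None dim becomes an empty dim list, which DecomposeAtenAminAmaxOp above then expands to all dimensions):

def decompose_aminmax(x, dim=None, keepdim=False):
    # aten.aminmax -> one aten.amin plus one aten.amax over the same dims.
    # Inside torch-mlir the empty list means "reduce everything"; eager
    # PyTorch's treatment of dim=() has differed across versions.
    dims = () if dim is None else (dim,)
    return (torch.amin(x, dim=dims, keepdim=keepdim),
            torch.amax(x, dim=dims, keepdim=keepdim))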

// Decompose `aten.bucketize` into the following op sequence:
//
// def aten_bucketize(input, boundaries, out_int32, right):
@@ -8598,7 +8605,6 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenAddmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMeanDimOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAMinMaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
@@ -8631,10 +8637,15 @@
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsIotaOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinspaceOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAminAmaxOp<AtenAmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenAminAmaxOp<AtenAminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAminmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
@@ -8707,7 +8718,6 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenNumpyTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarDimOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
4 changes: 3 additions & 1 deletion lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLinspaceOp>();
target.addIllegalOp<AtenArgmaxOp>();
target.addIllegalOp<AtenArgminOp>();
target.addIllegalOp<AtenAminmaxOp>();
target.addIllegalOp<AtenAmaxOp>();
target.addIllegalOp<AtenAminOp>();
target.addIllegalOp<AtenSquareOp>();
target.addIllegalOp<AtenVarOp>();
target.addIllegalOp<AtenStdOp>();
@@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNumpyTOp>();
target.addIllegalOp<AtenSelectScatterOp>();
target.addIllegalOp<AtenVarDimOp>();
-  target.addIllegalOp<AtenAmaxOp>();
target.addIllegalOp<AtenVarCorrectionOp>();
target.addIllegalOp<AtenStdDimOp>();
target.addIllegalOp<AtenStdCorrectionOp>();
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -830,6 +830,9 @@
}

STABLEHLO_PASS_SET = {
"ReduceAminmaxSingleDim_basic",
"ReduceAminmaxAllDims_basic",
"ReduceAmaxEmptyDim_basic",
"ReduceMinAlongDimNegative_basic",
"ReduceMinAlongDim_basic",
"ArgminModule_with_dim",
(Diffs for the remaining three changed files are not rendered here.)