From 0e71a192d82fdfcfe5d3eb90882d9f07eca077ae Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sat, 29 Jun 2024 21:44:05 +0800 Subject: [PATCH] [Torch] support decomposition of aten.aminmax (#3513) * unify decompisition of `aten.amax` and `aten.amin` * support `aten.amax` with `dim=()` --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++ lib/Conversion/TorchToStablehlo/Reduction.cpp | 16 +- .../Transforms/AbstractInterpLibrary.cpp | 21 ++ .../Torch/Transforms/DecomposeComplexOps.cpp | 208 +++++++++--------- .../Transforms/LowerToBackendContract.cpp | 4 +- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 12 + .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 69 ++++++ 9 files changed, 254 insertions(+), 106 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f4223b1f4bf..9428e749b5f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11463,6 +11463,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [ }]; } +def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchOptionalTensorType:$min, + AnyTorchOptionalTensorType:$max + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenAminmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index c9a2ad2e7ff..bc77a860ade 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -488,14 +488,18 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { return rewriter.notifyMatchFailure( op, "non-const integer `dim` is not supported"); } - for (auto d : inputDims) { - d = toPositiveDim(d, inputTy.getRank()); - // Drop invalid dims - if (isValidDim(d, inputTy.getRank())) { - dims.push_back(d); + if (inputDims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getRank())); + } else { + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } } + llvm::sort(dims.begin(), dims.end()); } - llvm::sort(dims.begin(), dims.end()); SmallVector reduceResultShape = getReduceOutputShape(inputTy.getShape(), dims); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f8a5409b8a7..8bf50fd21cc 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7371,6 +7371,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.tuple, list>) {\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.prim.TupleConstruct %2, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %4 : !torch.tuple, list>\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %4 = torch.prim.TupleConstruct %3, %3 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %4 : !torch.tuple, list>\n" +" }\n" +" return %1 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.optional to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -13568,6 +13584,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2086fb68afa..36e79736381 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -113,6 +113,25 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, .getValues(); } +// Reduction function to calculate min along given `dim`. +static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, + Operation *op, Value input, Value dim, + bool keepDim) { + Value keepDimCst = rewriter.create(loc, keepDim); + BaseTensorType valueType = cast(computeReductionType( + rewriter, op, cast(input.getType()), dim, keepDim)); + if (!valueType) + return nullptr; + BaseTensorType indexType = + cast(valueType.getWithSizesAndDtype( + !valueType.hasSizes() ? std::optional>() + : llvm::ArrayRef(valueType.getSizes()), + IntegerType::get(op->getContext(), 64, IntegerType::Signed))); + return rewriter + .create(loc, valueType, indexType, input, dim, keepDimCst) + .getValues(); +} + // Helper for creating `aten::sub_tensor_op`. static Value createTensorSub(PatternRewriter &rewriter, Location loc, Type tensorType, Value lhs, Value rhs) { @@ -605,65 +624,6 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, return out; } -namespace { -/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the -/// number of dimensions across which the max needs to be computed. -/// Eg: -/// INPUT: -/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) -/// -/// OUTPUT: -/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 -/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 -/// final_output = aten.max.dim(input_2, 0, keepdim) #3 -/// -/// NOTE: We iterate over, in reverse order, every dimension included in `dim` -/// of the `aten.amax` op and create an `aten.amax.dim` op. -/// Input tensor to the next `aten.amax.dim` op is thus the output of the -/// previous `aten.amax.dim` op. -class DecomposeAtenAmaxOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenAmaxOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) - - return rewriter.notifyMatchFailure(op, - "non-const dim parameter unsupported"); - - bool keepDim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) - return rewriter.notifyMatchFailure( - op, "Expected a constant boolean value for keepDim"); - - Value input = op.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy || !inputTy.hasSizes()) { - return rewriter.notifyMatchFailure(op, - "Expected input type having sizes"); - } - // For every dimension included in `dim` of the op, iterated over in - // reverse order, we create a call to aten.max.dim. - std::sort(dims.rbegin(), dims.rend()); - for (int64_t dimInt : dims) { - int64_t inputRank = inputTy.getSizes().size(); - dimInt = toPositiveDim(dimInt, inputRank); - if (!isValidDim(dimInt, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); - // The input to the next invocation of aten.max.dim is the output of the - // previous aten.max.dim op. - input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); - } - rewriter.replaceOp(op, input); - return success(); - } -}; -} // end namespace - namespace { class DecomposeAtenTriuOp : public OpRewritePattern { public: @@ -1880,52 +1840,69 @@ class DecomposeAten_LogSoftmaxBackwardDataOp } // namespace namespace { -class DecomposeAtenAMinMaxOp : public OpRewritePattern { +/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the +/// number of dimensions across which the max needs to be computed. +/// Eg: +/// INPUT: +/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) +/// +/// OUTPUT: +/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 +/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 +/// final_output = aten.max.dim(input_2, 0, keepdim) #3 +/// +/// NOTE: We iterate over, in reverse order, every dimension included in `dim` +/// of the `aten.amax` op and create an `aten.amax.dim` op. +/// Input tensor to the next `aten.amax.dim` op is thus the output of the +/// previous `aten.amax.dim` op. +template +class DecomposeAtenAminAmaxOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Torch::AtenAminOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - llvm::SmallVector dimList; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { - return rewriter.notifyMatchFailure(op, "dims not foldable constants"); + Location loc = op.getLoc(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for keepDim"); + + Value input = op.getSelf(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); } - bool keepdim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { - return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + SmallVector dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure(op, + "non-const dim parameter unsupported"); + if (dims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getSizes().size())); } - auto loc = op.getLoc(); - std::sort(dimList.begin(), dimList.end(), std::greater()); - - Value reduction = op.getSelf(); - auto resultTy = cast(op.getType()); - auto reductionTy = cast(reduction.getType()); - llvm::SmallVector reductionShape(reductionTy.getSizes()); - - for (auto dim : dimList) { - auto dimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); - reductionShape[dim] = 1; - if (!keepdim) { - for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) - reductionShape[i] = reductionShape[i + 1]; - reductionShape.resize(reductionShape.size() - 1); + // For every dimension included in `dim` of the op, iterated over in + // reverse order, we create a call to aten.max.dim. + std::sort(dims.rbegin(), dims.rend()); + for (int64_t dimInt : dims) { + int64_t inputRank = inputTy.getSizes().size(); + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimInt)); + // The input to the next invocation of aten.max.dim is the output of the + // previous aten.max.dim op. + static_assert(std::is_same_v || + std::is_same_v); + if (std::is_same_v) { + input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); + } else if (std::is_same_v) { + input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim); } - - reductionTy = rewriter.getType( - reductionShape, resultTy.getOptionalDtype()); - auto idxTy = rewriter.getType( - reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); - llvm::SmallVector types{reductionTy, idxTy}; - - reduction = rewriter - .create(loc, types, reduction, - dimValue, op.getKeepdim()) - .getResult(0); } - - rewriter.replaceOp(op, reduction); + rewriter.replaceOp(op, input); return success(); } }; @@ -1987,6 +1964,36 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { }; } // namespace +// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp` +namespace { +class DecomposeAtenAminmaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAminmaxOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Torch::ListType listType = + rewriter.getType(rewriter.getType()); + Value dimList; + if (isa(op.getDim().getType())) { + dimList = rewriter.create(loc, listType, + ArrayRef{}); + } else { + dimList = rewriter.create( + loc, listType, ArrayRef{op.getDim()}); + } + + auto amin = rewriter.create( + loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim()); + auto amax = rewriter.create( + loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim()); + rewriter.replaceOp(op, {amin, amax}); + return success(); + } +}; +} // namespace + // Decompose `aten.bucketize` into the following op sequence: // // def aten_bucketize(input, boundaries, out_int32, right): @@ -8598,7 +8605,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8631,10 +8637,15 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAminAmaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAminAmaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -8707,7 +8718,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 15bebfc6439..5e83c585ae8 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -438,6 +438,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -502,7 +505,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6ac3ae099a7..8272bc4b069 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -830,6 +830,9 @@ } STABLEHLO_PASS_SET = { + "ReduceAminmaxSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAmaxEmptyDim_basic", "ReduceMinAlongDimNegative_basic", "ReduceMinAlongDim_basic", "ArgminModule_with_dim", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 9052c8cc205..6e4957e5889 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -722,6 +722,13 @@ def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = Fa def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: + if dim is None: + return [], [] + else: + reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) + return reduced_shape, reduced_shape + def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) @@ -4524,6 +4531,11 @@ def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: return aten〇min〡dtype(self_rank_dtype), torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index c3cb95dd7fb..8e6745ea4a5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -841,6 +841,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") + emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)") emit( "aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 347a1f8cc25..7cf6dd69445 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1204,6 +1204,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAmaxEmptyDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.amax(a, dim=()) + + +@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim()) +def ReduceAmaxEmptyDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceAmaxOutOfOrderDim(torch.nn.Module): def __init__(self): super().__init__() @@ -1273,6 +1296,52 @@ def ReduceAminSingleDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAminmaxSingleDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a, dim=1) + + +@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim()) +def ReduceAminmaxSingleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + +class ReduceAminmaxAllDims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a) + + +@register_test_case(module_factory=lambda: ReduceAminmaxAllDims()) +def ReduceAminmaxAllDims_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceMinFloatModule(torch.nn.Module): def __init__(self): super().__init__()