diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index 60d631e82aa3..a5238c9b1211 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -341,10 +341,14 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc,
       isa<AtenLinalgVectorNormOp, AtenFrobeniusNormDimOp, AtenNormScalarOp>(
           op))
     return b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
 
-  if (isa<AtenAllDimOp>(op)) {
+  if (isa<AtenAllOp, AtenAllDimOp>(op)) {
     return b.create<arith::ConstantOp>(loc, b.getBoolAttr(true));
   }
 
+  if (isa<AtenAnyOp>(op)) {
+    return b.create<arith::ConstantOp>(loc, b.getBoolAttr(false));
+  }
+
   op->emitError("unimplemented lowering in createInitElementForReduceOp");
   return nullptr;
 }
@@ -439,11 +443,16 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
     auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType);
     auto pow = b.create<math::PowFOp>(loc, abs, ord);
     return b.create<arith::AddFOp>(loc, pow, result);
-  } else if (isa<AtenAllDimOp>(op)) {
+  } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
+    Value elem = payloadArgs[0];
+    Value result = payloadArgs[1];
+    Value self = convertScalarToDtype(b, loc, elem, resultElementType);
+    return b.create<arith::AndIOp>(loc, self, result);
+  } else if (isa<AtenAnyOp>(op)) {
     Value elem = payloadArgs[0];
     Value result = payloadArgs[1];
     Value self = convertScalarToDtype(b, loc, elem, resultElementType);
-    return b.create<arith::AndIOp>(loc, self, result);
+    return b.create<arith::OrIOp>(loc, self, result);
   }
   op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp");
   return nullptr;
@@ -510,13 +519,13 @@ class ConvertReductionOp : public ConversionPattern {
                       ConversionPatternRewriter &rewriter) const {
     auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
 
-    if (isa<AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp>(
-            op)) {
+    if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp,
+            AtenProdOp>(op)) {
       opInfo.tensorOperand = operands[0];
       auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
 
-      // `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and `AtenMinOp` each reduce
-      // along all the dimensions of the input tensor.
+      // `AtenAnyOp`, `AtenAllOp`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
+      // `AtenMinOp` each reduce along all the dimensions of the input tensor.
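+      // Populating dimSet with every dimension below collapses the whole
+      // input to a rank-0 (scalar) result.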
       for (int64_t i = 0; i < inputType.getRank(); i++)
         opInfo.dimSet.insert(i);
@@ -715,6 +724,8 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
   target.addIllegalOp<AtenMinDimOp>();
   patterns.add<ConvertAtenMinMaxDimOp<AtenMinDimOp>>(typeConverter, context);
   target.addIllegalOp<AtenSumOp>();
+  target.addIllegalOp<AtenAllOp>();
+  target.addIllegalOp<AtenAnyOp>();
   target.addIllegalOp<AtenProdOp>();
   target.addIllegalOp<AtenMaxOp>();
   target.addIllegalOp<AtenMinOp>();
diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp
index 1e494f4337c0..fee5cc01e4ae 100644
--- a/lib/Conversion/TorchToStablehlo/Reduction.cpp
+++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp
@@ -104,6 +104,18 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
+  if (isa<AtenAllOp>(op)) {
+    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)});
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
+  }
+
+  if (isa<AtenAnyOp>(op)) {
+    auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)});
+    return rewriter.create<stablehlo::ConstantOp>(op->getLoc(), constType,
+                                                  constAttr);
+  }
+
   op->emitError("unimplemented lowering in "
                 "createInitialValueForReduceOp");
   return nullptr;
@@ -463,6 +475,150 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
   }
 }
 } // namespace
 
+// AtenAllOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenAllOp>::matchAndRewrite(
+    AtenAllOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+
+  // Currently, (u)int8 dtype is not supported.
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenAllOp to StableHLO");
+  }
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
+
+  if (inputElemTy != outTy.getElementType()) {
+    // Use the output (bool) element type as the computation type.
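+    // The reduce body below combines elements with stablehlo.and on the i1
+    // element type, so non-bool inputs are converted before reducing.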
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+    inputElemTy = inputTy.getElementType();
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+  llvm::sort(dims.begin(), dims.end());
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
+      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value allResult = rewriter.create<stablehlo::AndOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), allResult);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()),
+      stablehloReduceOp.getResults());
+  return success();
+}
+} // namespace
+
+// AtenAnyOp
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenAnyOp>::matchAndRewrite(
+    AtenAnyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = input.getType().dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  auto inputElemTy = inputTy.getElementType();
+
+  // Currently, (u)int8 dtype is not supported.
+  if (isa<mlir::IntegerType>(inputElemTy) &&
+      inputElemTy.getIntOrFloatBitWidth() == 8) {
+    return rewriter.notifyMatchFailure(
+        op, "IntegerType with bitwidth 8 unsupported in conversion from "
+            "AtenAnyOp to StableHLO");
+  }
+  auto outTy = getTypeConverter()
+                   ->convertType(op.getType())
+                   .template dyn_cast<RankedTensorType>();
+
+  if (inputElemTy != outTy.getElementType()) {
+    // Use the output (bool) element type as the computation type.
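+    // This mirrors the AtenAllOp pattern above; only the init value (false)
+    // and the stablehlo.or combinator in the reduce body differ.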
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = input.getType().dyn_cast<RankedTensorType>();
+    inputElemTy = inputTy.getElementType();
+  }
+
+  SmallVector<int64_t> dims;
+  for (int64_t i = 0; i < inputTy.getRank(); i++) {
+    dims.push_back(i);
+  }
+
+  Value initValue =
+      createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter);
+  if (!initValue)
+    return failure();
+  llvm::sort(dims.begin(), dims.end());
+  auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>(
+      op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue,
+      rewriter.getDenseI64ArrayAttr(dims));
+
+  Block &block = stablehloReduceOp.getBody().emplaceBlock();
+  auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType());
+
+  block.addArgument(blockArgumentTy, op->getLoc());
+  block.addArgument(blockArgumentTy, op->getLoc());
+
+  auto *firstArgument = block.args_begin();
+  auto secondArgument = block.args_rbegin();
+
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(&block);
+    Value anyResult = rewriter.create<stablehlo::OrOp>(
+        op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), anyResult);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()),
+      stablehloReduceOp.getResults());
+  return success();
+}
+} // namespace
+
 // AtenProdOp
 namespace {
 template <>
@@ -1052,6 +1208,8 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1f16a25a9555..9f900bc80cad 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1227,6 +1227,12 @@
     "RandIntLowModule_basic",
     "RandIntModule_basic",
     "RandIntPinMemoryModule_basic",
+    "ReduceAllFloatModule_basic",
+    "ReduceAllIntModule_basic",
+    "ReduceAllBoolModule_basic",
+    "ReduceAnyFloatModule_basic",
+    "ReduceAnyIntModule_basic",
+    "ReduceAnyBoolModule_basic",
     "ReduceAmaxMultiDim_basic",
     "ReduceAmaxOutOfOrderDim_basic",
     "ReduceAmaxSingleDim_basic",
@@ -1809,6 +1815,8 @@
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "ReduceAllBoolModule_basic",
+    "ReduceAnyBoolModule_basic",
     "ReduceSumDimIntListFloatModule_basic",
     "ReduceSumDimIntListIntModule_basic",
     "ReduceSumDimIntListKeepDimFloatModule_basic",
@@ -2715,6 +2723,7 @@
     "MaskedFillTensorFloatValueModule_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
+    "ReduceAnyFloatModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
 }
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 5fe2db5ff441..076dd4e458a4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -124,6 +124,120 @@ def ReduceProdElementTypeBoolModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class ReduceAllFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllFloatModule())
+def ReduceAllFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ReduceAllIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllIntModule())
+def ReduceAllIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAllBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.all(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAllBoolModule())
+def ReduceAllBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=2).to(torch.bool))
+
+# ==============================================================================
+
+class ReduceAnyFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyFloatModule())
+def ReduceAnyFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+# ==============================================================================
+
+class ReduceAnyIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.int32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyIntModule())
+def ReduceAnyIntModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, high=2).to(torch.int32))
+
+# ==============================================================================
+
+class ReduceAnyBoolModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.any(a)
+
+
+@register_test_case(module_factory=lambda: ReduceAnyBoolModule())
+def ReduceAnyBoolModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, high=2).to(torch.bool))
+
+# ==============================================================================
+
 class ReduceSumDimIntListFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
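
As a sanity check on the semantics these e2e tests exercise (a minimal sketch against stock PyTorch, independent of the lowerings above): `aten.all` reduces every element with logical-and and `aten.any` with logical-or; nonzero elements of non-bool inputs count as true, and both return a rank-0 bool tensor.

    import torch

    a = torch.tensor([[1.0, 2.0], [3.0, 0.0]])
    print(torch.ops.aten.all(a))  # tensor(False): one element is zero
    print(torch.ops.aten.any(a))  # tensor(True): at least one element is nonzero

    b = torch.zeros(3, 4, dtype=torch.bool)
    print(torch.ops.aten.all(b))  # tensor(False)
    print(torch.ops.aten.any(b))  # tensor(False)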