From 2ded5cc7df7c639a59ee04e215eb5ff3d2cfb08f Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle Date: Tue, 9 Jan 2024 23:59:50 +0800 Subject: [PATCH 1/2] [Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 72 +++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 62 ++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 80 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 3 + projects/pt1/e2e_testing/xfail_sets.py | 7 ++ .../build_tools/abstract_interp_lib_gen.py | 29 +++++++ .../build_tools/torch_ods_gen.py | 3 + .../test_suite/elementwise.py | 79 ++++++++++++++++++ 8 files changed, 335 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 74a2e2327d1b..04e59017974a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8472,6 +8472,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [ }]; } +def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsneginfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsposinfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAllOp : Torch_Op<"aten.all", [ AllowsTypeRefinement, HasValueSemantics, @@ -10402,6 +10448,32 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ }]; } +def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalFloatType:$nan, + AnyTorchOptionalFloatType:$posinf, + AnyTorchOptionalFloatType:$neginf + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNanToNumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 55b9638dd0cc..d3219bbf7ee4 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6678,6 +6678,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7850,6 +7858,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" @@ -9681,6 +9693,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -10684,6 +10742,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 0a3ce2ea7797..79e3f839d7c2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -932,6 +932,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenIsneginfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsneginfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenIsposinfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsposinfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(f64Type, + APFloat::getInf(f64Type.getFloatSemantics()))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: @@ -2471,6 +2505,49 @@ class DecomposeAtenWhereScalarSelfOp }; } // namespace +namespace { +class DecomposeAtenNanToNumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNanToNumOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + mlir::FloatType f64Type = rewriter.getF64Type(); + Value lan = op.getNan(); + Value posinf = op.getPosinf(); + Value neginf = op.getNeginf(); + auto baseType = + ValueTensorType::getWithLeastStaticInformation(op.getContext()); + if (dyn_cast_or_null(lan.getDefiningOp())) + lan = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(posinf.getDefiningOp())) + posinf = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(neginf.getDefiningOp())) + neginf = rewriter.create( + loc, + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + Value isNan = + rewriter.create(loc, baseType, op.getSelf()); + Value where = rewriter.create( + loc, baseType, isNan, lan, op.getSelf()); + Value isposinf = + rewriter.create(loc, baseType, where); + where = rewriter.create( + loc, baseType, isposinf, posinf, where); + Value isneginf = + rewriter.create(loc, baseType, where); + rewriter.replaceOpWithNewOp( + op, op.getType(), isneginf, neginf, where); + return success(); + } +}; +} // namespace + // Decompose aten.masked_fill.Scalar into aten.where.self op. namespace { class DecomposeAtenMaskedFillScalarOp @@ -6393,6 +6470,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6448,6 +6526,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 933140d3013d..e76adb9b89dc 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -431,8 +431,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ddb4865ec535..465a061d168e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -469,6 +469,7 @@ "ElementwiseAtenWhereSelfModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarSelfStaticModule_basic", + "ElementwiseNanToNumModule_Basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", @@ -1035,6 +1036,8 @@ "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -1074,6 +1077,8 @@ "ElementwiseGtIntTensorModule_basic", "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIsinfModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseIsnanModule_basic", "ElementwiseLeFloatTensorModule_basic", "ElementwiseLeIntTensorModule_basic", @@ -1130,6 +1135,7 @@ "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseNanToNumModule_Basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "FlattenRank0Module_basic", @@ -1494,4 +1500,5 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseNanToNumModule_Basic" } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a16d778c79a7..fdc607b3187c 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -319,6 +319,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]: def aten〇isinf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇isneginf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇isposinf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -1040,6 +1046,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(condition, other) +def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight)) @@ -2493,6 +2502,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + @check_dtype_function(_check_two_tensor_op()) def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @@ -3224,6 +3247,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel dtypes = [get_dtype_of_scalar(self), other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 9c0a0759b443..8a6bb13caf2a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -565,6 +565,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::isinf : (Tensor) -> (Tensor)") + emit("aten::isneginf : (Tensor) -> (Tensor)") + emit("aten::isposinf : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") @@ -639,6 +641,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") + emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c18c9103d888..997445ac3f3f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -339,6 +339,33 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True) + ]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + +@register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) +def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): + module.forward(torch.tensor( + [ + [float('nan'), 0.0, float('nan'), 0.0], + [float('inf'), 0.0, float('inf'), 0.0], + [float('-inf'), 0.0, float('-inf'), 0.0] + ] + )) + + +# ============================================================================== + + # Addition is an interesting special case of a binary op, because under the hood # it carries a third scalar "alpha" parameter, which needs special handling. class ElementwiseAddModule(torch.nn.Module): @@ -3441,6 +3468,58 @@ def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtenIsneginfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isneginf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsneginfOpModule()) +def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('-inf'), 2, float('-inf'), float('nan')], + [1, float('-inf'), float('-inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + +class ElementwiseAtenIsposinfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isposinf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule()) +def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('inf'), 2, float('inf'), float('nan')], + [1, float('inf'), float('inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module): def __init__(self): super().__init__() From 19b970c35c08563bf670b3a187b2ae95499ce2d5 Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle Date: Sat, 13 Jan 2024 10:15:35 +0800 Subject: [PATCH 2/2] fix nit and tests. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 8 ++++---- .../python/torch_mlir_e2e_test/test_suite/elementwise.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 79e3f839d7c2..9c4776231cc8 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2513,13 +2513,13 @@ class DecomposeAtenNanToNumOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); mlir::FloatType f64Type = rewriter.getF64Type(); - Value lan = op.getNan(); + Value nan = op.getNan(); Value posinf = op.getPosinf(); Value neginf = op.getNeginf(); auto baseType = ValueTensorType::getWithLeastStaticInformation(op.getContext()); - if (dyn_cast_or_null(lan.getDefiningOp())) - lan = rewriter.create( + if (dyn_cast_or_null(nan.getDefiningOp())) + nan = rewriter.create( loc, rewriter.getFloatAttr( f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); if (dyn_cast_or_null(posinf.getDefiningOp())) @@ -2534,7 +2534,7 @@ class DecomposeAtenNanToNumOp : public OpRewritePattern { Value isNan = rewriter.create(loc, baseType, op.getSelf()); Value where = rewriter.create( - loc, baseType, isNan, lan, op.getSelf()); + loc, baseType, isNan, nan, op.getSelf()); Value isposinf = rewriter.create(loc, baseType, where); where = rewriter.create( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 29dd497536d1..62759e6b0a52 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3506,8 +3506,8 @@ def forward(self, x): def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils): test_input = torch.tensor( [ - [1, float('-inf'), 2, float('-inf'), float('nan')], - [1, float('-inf'), float('-inf'), float('nan'), 3], + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], ] ) module.forward(test_input) @@ -3532,8 +3532,8 @@ def forward(self, x): def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): test_input = torch.tensor( [ - [1, float('inf'), 2, float('inf'), float('nan')], - [1, float('inf'), float('inf'), float('nan'), 3], + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], ] ) module.forward(test_input)