diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9525f9f9ffa6..821ae33b6876 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8519,6 +8519,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [
   }];
 }
 
+def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsneginfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsposinfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAllOp : Torch_Op<"aten.all", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -10449,6 +10495,32 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   }];
 }
 
+def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalFloatType:$nan,
+    AnyTorchOptionalFloatType:$posinf,
+    AnyTorchOptionalFloatType:$neginf
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenNanToNumOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
     AllowsTypeRefinement,
     ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index c286168080e4..25f5d4618420 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6702,6 +6702,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list) -> !torch.list {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
+"    return %0 : !torch.list\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list) -> !torch.list {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
+"    return %0 : !torch.list\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
@@ -7874,6 +7882,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
+"    return %0 : !torch.list\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n"
@@ -9710,6 +9722,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int10 = torch.constant.int 10\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int10 = torch.constant.int 10\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
@@ -10713,6 +10771,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0a3ce2ea7797..9c4776231cc8 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -932,6 +932,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenIsneginfOp : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIsneginfOp op,
+                                PatternRewriter &rewriter) const override {
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value inf = rewriter.create(
+        op.getLoc(),
+        rewriter.getFloatAttr(
+            f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
+    rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(),
+                                inf);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenIsposinfOp : public OpRewritePattern {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIsposinfOp op,
+                                PatternRewriter &rewriter) const override {
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value inf = rewriter.create(
+        op.getLoc(),
+        rewriter.getFloatAttr(f64Type,
+                              APFloat::getInf(f64Type.getFloatSemantics())));
+    rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(),
+                                inf);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenReshapeOp : public OpRewritePattern {
 public:
@@ -2471,6 +2505,49 @@ class DecomposeAtenWhereScalarSelfOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNanToNumOp : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNanToNumOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value nan = op.getNan();
+    Value posinf = op.getPosinf();
+    Value neginf = op.getNeginf();
+    auto baseType =
+        ValueTensorType::getWithLeastStaticInformation(op.getContext());
+    if (dyn_cast_or_null(nan.getDefiningOp()))
+      nan = rewriter.create(
+          loc, rewriter.getFloatAttr(
+                   f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
+    if (dyn_cast_or_null(posinf.getDefiningOp()))
+      posinf = rewriter.create(
+          loc, rewriter.getFloatAttr(
+                   f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
+    if (dyn_cast_or_null(neginf.getDefiningOp()))
+      neginf = rewriter.create(
+          loc,
+          rewriter.getFloatAttr(
+              f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
+    Value isNan =
+        rewriter.create(loc, baseType, op.getSelf());
+    Value where = rewriter.create(
+        loc, baseType, isNan, nan, op.getSelf());
+    Value isposinf =
+        rewriter.create(loc, baseType, where);
+    where = rewriter.create(
+        loc, baseType, isposinf, posinf, where);
+    Value isneginf =
+        rewriter.create(loc, baseType, where);
+    rewriter.replaceOpWithNewOp(
+        op, op.getType(), isneginf, neginf, where);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.masked_fill.Scalar into aten.where.self op.
 namespace {
 class DecomposeAtenMaskedFillScalarOp
@@ -6393,6 +6470,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
@@ -6448,6 +6526,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
    addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 933140d3013d..e76adb9b89dc 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -431,8 +431,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp();
+  target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e04657df4d2c..8a440c16b882 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -473,6 +473,7 @@
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
     "ElementwiseWhereScalarSelfStaticModule_basic",
+    "ElementwiseNanToNumModule_Basic",
     "ElementwiseBitwiseAndStaticShapeModule_basic",
     "ElementwiseBitwiseNotInt64Module_basic",
     "ElementwiseBitwiseNotInt32Module_basic",
@@ -1039,6 +1040,8 @@
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
     "ElementwiseAtenIsinfOpModule_basic",
+    "ElementwiseAtenIsneginfOpModule_basic",
+    "ElementwiseAtenIsposinfOpModule_basic",
     "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
     "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
     "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
@@ -1090,6 +1093,8 @@
     "ElementwiseGtIntTensorModule_basic",
     "ElementwiseGtMixed2ScalarModule_basic",
     "ElementwiseIsinfModule_basic",
+    "ElementwiseAtenIsneginfOpModule_basic",
+    "ElementwiseAtenIsposinfOpModule_basic",
     "ElementwiseIsnanModule_basic",
     "ElementwiseLeFloatTensorModule_basic",
     "ElementwiseLeIntTensorModule_basic",
@@ -1146,6 +1151,7 @@
     "ElementwiseUnaryModule_basic",
     "ElementwiseUnsqueezeBroadcastModule_basic",
     "ElementwiseWhereScalarModule_basic",
+    "ElementwiseNanToNumModule_Basic",
     "EmbeddingModule1DIndices_basic",
     "EmbeddingModuleI32Static_basic",
     "FlattenRank0Module_basic",
@@ -1511,6 +1517,7 @@
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "ElementwiseNanToNumModule_Basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseDequantizePerTensorModule_basic"
 }
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6a8fbf34e911..efddfeed1d78 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -337,6 +337,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]:
 def aten〇isinf〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇isneginf〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇isposinf〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
@@ -1058,6 +1064,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight))
 
@@ -2516,6 +2525,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
 def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
+def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.complex128 and self_dtype != torch.complex64
+    return torch.bool
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
+def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.complex128 and self_dtype != torch.complex64
+    return torch.bool
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
@@ -3247,6 +3270,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64),
                 TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0),
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 249c25628a82..3637e679695d 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -566,6 +566,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)")
     emit("aten::isnan : (Tensor) -> (Tensor)")
emit("aten::isinf : (Tensor) -> (Tensor)") + emit("aten::isneginf : (Tensor) -> (Tensor)") + emit("aten::isposinf : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") @@ -640,6 +642,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") + emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 23a22142c4d5..9b857839db2b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -339,6 +339,33 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True) + ]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + +@register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) +def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): + module.forward(torch.tensor( + [ + [float('nan'), 0.0, float('nan'), 0.0], + [float('inf'), 0.0, float('inf'), 0.0], + [float('-inf'), 0.0, float('-inf'), 0.0] + ] + )) + + +# ============================================================================== + + # Addition is an interesting special case of a binary op, because under the hood # it carries a third scalar "alpha" parameter, which needs special handling. 
 class ElementwiseAddModule(torch.nn.Module):
@@ -3463,6 +3490,58 @@ def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtenIsneginfOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.isneginf(x)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenIsneginfOpModule())
+def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils):
+    test_input = torch.tensor(
+        [
+            [1, float('-inf'), 2, float('inf'), float('nan')],
+            [1, float('-inf'), float('inf'), float('nan'), 3],
+        ]
+    )
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenIsposinfOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 5], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.isposinf(x)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule())
+def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils):
+    test_input = torch.tensor(
+        [
+            [1, float('-inf'), 2, float('inf'), float('nan')],
+            [1, float('-inf'), float('inf'), float('nan'), 3],
+        ]
+    )
+    module.forward(test_input)
+
+
+# ==============================================================================
+
+
 class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
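
Editorial note (not part of the patch): the new DecomposeAtenNanToNumOp pattern rewrites aten.nan_to_num as a chain of aten.isnan / aten.isposinf / aten.isneginf with aten.where.self, and it appears to materialize f64 constants of 0.0, +inf, and -inf for whichever of the optional nan/posinf/neginf operands arrive as constant None. A rough PyTorch-level sketch of that decomposition follows; nan_to_num_sketch is a hypothetical helper written only to illustrate the rewrite and does not exist in this patch or in the torch-mlir tree.

import torch

def nan_to_num_sketch(x, nan=None, posinf=None, neginf=None):
    # Mirror the defaults chosen by the decomposition above (0.0 / +inf / -inf);
    # note that stock torch.nan_to_num documents finfo max/min as the posinf/neginf
    # defaults, which is why the e2e test passes explicit values.
    nan = 0.0 if nan is None else nan
    posinf = float("inf") if posinf is None else posinf
    neginf = float("-inf") if neginf is None else neginf
    # Same where-chain the pattern emits: NaNs first, then +inf, then -inf.
    out = torch.where(torch.isnan(x), torch.tensor(nan, dtype=x.dtype), x)
    out = torch.where(torch.isposinf(out), torch.tensor(posinf, dtype=x.dtype), out)
    out = torch.where(torch.isneginf(out), torch.tensor(neginf, dtype=x.dtype), out)
    return out

# Example, matching the values used by ElementwiseNanToNumModule above:
#   nan_to_num_sketch(torch.tensor([float("nan"), float("inf"), float("-inf"), 2.0]), 0.0, 1.0, -1.0)
#   -> tensor([ 0.,  1., -1.,  2.])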