From cee9a501c784efd3e753f140b41d504df2e16d77 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Mon, 29 Apr 2024 00:08:43 +0800 Subject: [PATCH 1/6] support rrelu checkpoint --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 147 ++++++++++++------ .../build_tools/abstract_interp_lib_gen.py | 9 ++ .../build_tools/torch_ods_gen.py | 3 +- 3 files changed, 111 insertions(+), 48 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4a234307d00..8690ad6f21b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -256,6 +256,106 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [ }]; } +def Torch_AtenRreluOp : Torch_Op<"aten.rrelu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRreluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_ : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRrelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRrelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, 
result, 2, 1); + } + void AtenCelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ AllowsTypeRefinement, HasValueSemantics, @@ -4810,53 +4910,6 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [ }]; } -def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenCeluOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$alpha - ); - let results = (outs - AnyTorchOptionalNonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenCelu_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - def Torch_AtenRealOp : Torch_Op<"aten.real", [ AllowsTypeRefinement, ReadOnly diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index da486fe4602..8e9075f5968 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -529,6 +529,9 @@ def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: def aten〇celu〡shape(self: List[int], alpha: float = 1.) 
-> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇rrelu〡shape(self: List[int], lower: float, upper: float, training: bool) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇selu〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -2660,6 +2663,12 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.complex128, *all_integer_dtypes()}))
+def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.333, training: bool = False) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert is_float_dtype(self_dtype)
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
 def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 6e449c27745..3d8d4cf929c 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -301,6 +301,8 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::relu : (Tensor) -> (Tensor)",
         "aten::relu6 : (Tensor) -> (Tensor)",
         "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
+        "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
+        "aten::celu : (Tensor, Scalar) -> (Tensor)",
         "aten::selu : (Tensor) -> (Tensor)",
         "aten::sigmoid : (Tensor) -> (Tensor)",
         "aten::sinh : (Tensor) -> (Tensor)",
@@ -472,7 +474,6 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
     emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
     emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
-    emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
     emit("aten::real : (Tensor) -> (Tensor)")
     emit("aten::imag : (Tensor) -> (Tensor)")
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")

From f4593cd23e8d8ee3c68b954e0f5073d0b03ad9a7 Mon Sep 17 00:00:00 2001
From: yangxinyu
Date: Sun, 28 Apr 2024 16:20:21 +0000
Subject: [PATCH 2/6] add rrelu definition

---
 .../Transforms/AbstractInterpLibrary.cpp | 24 +++++++++++++++++++
 .../build_tools/abstract_interp_lib_gen.py | 8 +++----
 2 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 553a8dc7441..e642ab0d753 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7002,6 +7002,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.list {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
+"    return %0 : !torch.list\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> 
!torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10488,6 +10492,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8e9075f5968..f22db1c4991 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -529,7 +529,7 @@ def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: def aten〇celu〡shape(self: List[int], alpha: float = 1.) 
-> List[int]: return upstream_shape_functions.unary(self) -def aten〇rrelu〡shape(self: List[int], lower: float, upper: float, training: bool) -> List[int]: +def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) def aten〇selu〡shape(self: List[int]) -> List[int]: @@ -2663,10 +2663,10 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.complex128, *all_integer_dtypes()})) -def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.333, training: bool = False) -> int: +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()})) +def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype - assert is_float_dtype(self_dtype) + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) From 49d4bbb8ac2c5a07161a39679c189af70a1336ee Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 16:57:50 +0000 Subject: [PATCH 3/6] fix and add test --- .../Torch/Transforms/DecomposeComplexOps.cpp | 73 +++++++++++++++ .../test_suite/elementwise.py | 92 +++++++++++++++++++ 2 files changed, 165 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 677ccc4f241..da71cf99cb2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2415,6 +2415,79 @@ class DecomposeAtenPreluOp : public OpRewritePattern { } // namespace +// rrelu = max(0, x) + min(0, alpha * (exp(x) - 1)) +// if in training mode, the alpha is sampled from uniform distribution (lower, +// upper) if in testing mode, the alpha is (lower + upper) / 2 +namespace { +class DecomposeAtenRreluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + Value alpha; + if (training) { + // Create a uniform random op with low and high set to `lower` and + // `upper`, respectively. 
+ Value none = rewriter.create(loc); + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value emptyTensor = rewriter.create( + loc, resType, self, zero, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, resType, lower, upper, + constantOneFloat); + alpha = rewriter.create(loc, resType, half, + constantTwoFloat); + } + + Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + Value expX = rewriter.create(loc, resType, self); + Value expXM1 = rewriter.create( + loc, resType, expX, constantOneFloat, constantOneFloat); + Value scaledExpXM1; + if (training) { + scaledExpXM1 = + rewriter.create(loc, resType, expXM1, alpha); + } else { + scaledExpXM1 = + rewriter.create(loc, resType, expXM1, alpha); + } + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledExpXM1); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, rreluOutput); + return success(); + } +}; +} // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d034e6d1f42..95f742f314e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1039,6 +1039,98 @@ def ElementwiseCeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRreluTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.4, 0.6, True) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainModule()) +def ElementwiseRreluTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseRreluTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.1, 0.9, True) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule()) +def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseRreluEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.4, 0.6, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalModule()) +def ElementwiseRreluEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class 
ElementwiseRreluEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.1, 0.9, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalStaticModule()) +def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__() From 3fa3589dfa20f575dca9c209bedef67fab52274e Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 17:24:17 +0000 Subject: [PATCH 4/6] finish --- .../Torch/Transforms/DecomposeComplexOps.cpp | 39 +++++++++---------- .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 4 ++ .../test_suite/elementwise.py | 12 +++--- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index da71cf99cb2..296c812637b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2415,7 +2415,7 @@ class DecomposeAtenPreluOp : public OpRewritePattern { } // namespace -// rrelu = max(0, x) + min(0, alpha * (exp(x) - 1)) +// rrelu = max(0, x) + min(0, alpha * x) // if in training mode, the alpha is sampled from uniform distribution (lower, // upper) if in testing mode, the alpha is (lower + upper) / 2 namespace { @@ -2438,48 +2438,46 @@ class DecomposeAtenRreluOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "training should be a constant"); } - Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value constantOneFloat = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); Value constantTwoFloat = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + Value alpha; if (training) { // Create a uniform random op with low and high set to `lower` and // `upper`, respectively. 
Value none = rewriter.create(loc); - Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); Value emptyTensor = rewriter.create( - loc, resType, self, zero, /*dtype=*/none, /*layout=*/none, + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); alpha = rewriter.create(loc, resType, emptyTensor, /*from=*/lower, /*to=*/upper, /*generator=*/none); } else { - Value half = rewriter.create(loc, resType, lower, upper, - constantOneFloat); - alpha = rewriter.create(loc, resType, half, - constantTwoFloat); + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); } - Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); Value positiveOutput = rewriter.create(loc, resType, zeroTensor, self); - Value expX = rewriter.create(loc, resType, self); - Value expXM1 = rewriter.create( - loc, resType, expX, constantOneFloat, constantOneFloat); - Value scaledExpXM1; + + Value scaledSelf; if (training) { - scaledExpXM1 = - rewriter.create(loc, resType, expXM1, alpha); + scaledSelf = rewriter.create(loc, resType, self, alpha); } else { - scaledExpXM1 = - rewriter.create(loc, resType, expXM1, alpha); + scaledSelf = rewriter.create(loc, resType, self, alpha); } + Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, scaledExpXM1); + rewriter.create(loc, resType, zeroTensor, scaledSelf); Value rreluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOneFloat); rewriter.replaceOp(op, rreluOutput); @@ -7822,6 +7820,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index e7bed646355..f9e2c2fd57d 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 276cc47c1cc..e31b79a88e3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1004,6 +1004,8 @@ "ElementwiseRemainderTensorModule_Float_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", @@ -1658,6 +1660,8 @@ "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", diff --git 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 95f742f314e..33e23c4fc91 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1051,12 +1051,13 @@ def __init__(self): ] ) def forward(self, x): - return torch.ops.aten.rrelu(x, 0.4, 0.6, True) + res = torch.ops.aten.rrelu(x, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluTrainModule()) def ElementwiseRreluTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1)) + module.forward(tu.rand(1024, 1536)) # ============================================================================== @@ -1070,16 +1071,17 @@ def __init__(self): @annotate_args( [ None, - ([5, 3], torch.float32, True), + ([1024, 1536], torch.float32, True), ] ) def forward(self, x): - return torch.ops.aten.rrelu(x, 0.1, 0.9, True) + res = torch.ops.aten.rrelu(x, 0.1, 0.9, True) + return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule()) def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(5, 3, low=-1, high=1)) + module.forward(tu.rand(1024, 1536)) # ============================================================================== From 3fcde0c970b9f623a2e718da43ef92f012b27576 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 10:11:25 +0000 Subject: [PATCH 5/6] repair make_fx_tosa --- projects/pt1/e2e_testing/xfail_sets.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4c4513fde69..5526331189c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1943,6 +1943,9 @@ "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", "ElementwiseLogSigmoidModule_basic", + # failed to legalize operation 'torch.aten.rrelu_with_noise' + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", # Shape Related failures "PrimListUnpackNumMismatchModule_basic", "ReshapeExpandModule_basic", From becff9d573040a8c8c0f901c4ef911a591b44f83 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Mon, 29 Apr 2024 12:36:44 +0000 Subject: [PATCH 6/6] fix --- projects/pt1/e2e_testing/xfail_sets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 5526331189c..6a3ef5d2c8c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -383,6 +383,10 @@ "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "EqIntModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic",