Add support for aten.glu.

This patch registers `aten::glu : (Tensor, int) -> (Tensor)` in
torch_ods_gen.py and GeneratedTorchOps.td, adds its shape and dtype functions
(abstract_interp_lib_gen.py plus the generated AbstractInterpLibrary.cpp),
adds the DecomposeAtenGluOp pattern, marks the op illegal in
LowerToBackendContract.cpp, and adds the GluStaticModule e2e test together
with an entry in e2e_testing/xfail_sets.py.

The decomposition asserts that the size of `dim` is a multiple of 2, takes
two pieces of that dimension (a: start 0, b: start size/2, each of length
size/2) and produces a⊗σ(b).

NOTE(review): angle-bracketed template/type arguments (for example after
`OpRewritePattern`, `rewriter.create`, `dyn_cast`, `addIllegalOp`,
`!torch.list` and `!torch.tuple`) appear to be missing from this copy of the
patch. Confirm against the original commit before applying.
NOTE(review): the error-message wording was corrected ("multiple of 2") in
the Python shape function, its generated C++ mirror and the decomposition;
the `index` blob hashes below were not recomputed.

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index be785cbe6ae7..8a0071d6f7c3 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -963,6 +963,7 @@
     "ElementwiseMaximumIntModule_basic",
     "ElementwiseMaxOtherIntModule_basic",
     "ElementwiseMaxOtherModule_basic",
+    "GluStaticModule_basic",
     "ViewDoubleMergeStaticModule_basic",
     "ViewCollapseOnesMiddleModule_basic",
     "ViewFiveTestStaticModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9ac354ed8cd3..c22a252de588 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4189,6 +4189,30 @@ def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
   }];
 }
 
+def Torch_AtenGluOp : Torch_Op<"aten.glu", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::glu : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenGluOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenGluOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 47d76219f199..1a72c4d78d00 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6382,6 +6382,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.glu\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: glu's dim size must be multiple of 2\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %13 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n"
+"      %14 = torch.aten.add.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n"
+"      torch.prim.If.yield %14 : !torch.int\n"
+"    } else {\n"
+"      torch.prim.If.yield %arg1 : !torch.int\n"
+"    }\n"
+"    %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.remainder.int %2, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n"
+"    %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n"
+"    %8 = torch.prim.ListConstruct %7 : (!torch.int) -> !torch.list\n"
+"    %9 = torch.aten.add.t %5, %8 : !torch.list, !torch.list -> !torch.list\n"
+"    %10 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %11 = torch.aten.slice.t %arg0, %10, %none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list\n"
+"    %12 = torch.aten.add.t %9, %11 : !torch.list, !torch.list -> !torch.list\n"
+"    return %12 : !torch.list\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten._softmax\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n"
 "    return %0 : !torch.list\n"
@@ -8863,6 +8896,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.glu\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.scatter_reduce.two\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.str, %arg5: !torch.bool) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0bdfca26ddc1..2cf3cc742f92 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -361,6 +361,48 @@ class DecomposeAtenNarrowTensorOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenGluOp : public OpRewritePattern {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenGluOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value dim = op.getDim();
+
+    auto outputTy = op.getType().dyn_cast();
+    if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected output type having sizes and dtype");
+    }
+
+    Value zero =
+        rewriter.create(loc, rewriter.getI64IntegerAttr(0));
+    Value dimSize = rewriter.create(loc, self, dim);
+    Value two =
+        rewriter.create(loc, rewriter.getI64IntegerAttr(2));
+
+    Value remainder = rewriter.create(loc, dimSize, two);
+    Value eqOrNot = rewriter.create(loc, remainder, zero);
+    rewriter.create(
+        loc, eqOrNot,
+        rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2"));
+
+    Value splitLength = rewriter.create(loc, dimSize, two);
+    Value a = rewriter.create(loc, outputTy, self, dim, zero,
+                              splitLength);
+    Value b = rewriter.create(loc, outputTy, self, dim,
+                              splitLength, splitLength);
+    // a⊗σ(b)
+    Value sigmoidB = rewriter.create(loc, outputTy, b);
+    Value result = rewriter.create(loc, outputTy, a, sigmoidB);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenZeroOp
     : public OpRewritePattern {
@@ -5289,6 +5331,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index aadd27cb9462..d36488fa799f 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
index 958df70d575a..4b41c0a363a4 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -167,6 +167,12 @@ def aten〇relu6〡shape(self: List[int]) -> List[int]:
 def aten〇round〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]:
+    if dim < 0:
+        dim += len(self)
+    assert self[dim] % 2 == 0, "glu's dim size must be multiple of 2"
+    return self[:dim] + [self[dim] // 2] + self[dim+1:]
+
 def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -1932,6 +1938,11 @@ def aten〇round〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(100,)], dim=0))
+def aten〇glu〡dtype(self_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype), "sum") for dtype in _SORTED_TORCH_TYPES])
 def aten〇scatter_reduce〇two〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int], reduce: str, include_self: bool = True) -> int:
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 12e929fbea99..1a39989f84a5 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -354,6 +354,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::view_as_complex : (Tensor) -> (Tensor)")
     emit("aten::view_as_real : (Tensor) -> (Tensor)")
     emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
+    emit("aten::glu : (Tensor, int) -> (Tensor)")
 
     # Ops with dynamic number of outputs
     emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 3b2997c3e482..918a22b86f78 100644
--- a/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3685,3 +3685,21 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: ElementwiseBitwiseAndScalarInt8Module())
 def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils):
     module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int8))
+
+# ==============================================================================
+
+class GluStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 24, 5], torch.float32, True)
+    ])
+    def forward(self, x):
+        return torch.ops.aten.glu(x, dim=1)
+
+@register_test_case(module_factory=lambda: GluStaticModule())
+def GluStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 24, 5))