From 20ebed15e0bee204e3fbe3eca53d8f31bfc46c7c Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Fri, 18 Oct 2024 11:28:38 -0400 Subject: [PATCH 1/4] Support convolution with `valid` padding. --- lib/Conversion/TorchToLinalg/Linear.cpp | 6 +++ lib/Conversion/TorchToStablehlo/Linear.cpp | 5 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 ++ projects/pt1/e2e_testing/xfail_sets.py | 3 ++ .../torch_mlir_e2e_test/test_suite/conv.py | 52 +++++++++++++++++++ 5 files changed, 69 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9c914690bbf4..e6f9b81b8436 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -832,6 +832,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> { op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); + if (paddingIntValues.size() == 1) { + for (size_t iDim = 1; iDim < numSpatialDims; iDim++) { + paddingIntValues.push_back(paddingIntValues[0]); + } + } + SmallVector<Value> outputPaddingIntValues; if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues)) diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index b42ed7cc7722..88617f139c96 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -750,6 +750,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> { return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); } + if (padding.size() == 1) { + for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) { + padding.push_back(padding[0]); + } + } SmallVector<int64_t> dilation; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) { return rewriter.notifyMatchFailure(op, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e75d358b068d..64e003d39dcc 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2187,6 +2187,9 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + if (padding_2d.size() == 1) { + padding_2d.push_back(padding_2d[0]); + } // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e0011b9a347e..0770c51b5180 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -33,6 +33,9 @@ # if a dimension is specified in all expand lists, and not in sumdim list. # This is a bug in the implementation of _trilinear in PyTorch. 
"Aten_TrilinearModuleZerodDimBug_basic", + # TorchScript to the backend contract fails for conv.padding specified as str + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index e6332579d575..147885b442c7 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -191,6 +191,58 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) +class Conv2dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) +def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + +class Conv2dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d( + 1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1 + ) + self.train(False) + + @export + @annotate_args( + [ + None, + ([1, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) +def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(1, 5, 6) + module.forward(t) + + # ============================================================================== From 21089d3d6f2f908c9c6d922466076c330d5cb183 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 12 Nov 2024 17:01:38 -0500 Subject: [PATCH 2/4] Add conv2d_padding op to torch. 
--- .../Dialect/Torch/IR/GeneratedTorchOps.td | 29 +++++++ .../Transforms/AbstractInterpLibrary.cpp | 42 ++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 78 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 3 - .../build_tools/abstract_interp_lib_gen.py | 15 ++++ .../build_tools/torch_ods_gen.py | 9 +++ .../torch_mlir_e2e_test/test_suite/conv.py | 63 ++++++++------- python/torch_mlir/fx.py | 1 + 8 files changed, 209 insertions(+), 31 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 28764009a393..79b30c7d66e2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6713,6 +6713,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ }]; } +def Torch_AtenConv2dPaddingOp : Torch_Op<"aten.conv2d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv2dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv2dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1cc02a48f37f..4258026c948d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10024,6 +10024,48 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n" " return %0 : !torch.list<int>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n" +" %true = torch.constant.bool true\n" +" %int-1 = torch.constant.int -1\n" +" %str = torch.constant.str \"same\"\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n" +" %4 = torch.operator \"aten.mul.left_t\"(%3, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int> \n" +" %5 = torch.aten.eq.str %arg4, %str : !torch.str, !torch.str -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" %7 = torch.aten.len.t %arg5 : 
!torch.list -> !torch.int\n" +" %8 = torch.aten.__range_length %int1, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %int2, %8 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg7: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg5, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %2, %arg7 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__derive_index %arg7, %int1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.mul.int %11, %14 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.floordiv.int %15, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.mul.int %int2, %13 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten._set_item.t %4, %17, %16 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %19 = torch.aten.sub.int %15, %16 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.mul.int %int2, %13 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten._set_item.t %4, %21, %19 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %6 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2f276b1a296f..753f3b4ce90f 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5169,6 +5169,83 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv2d.padding to aten.convolution +namespace { +class DecomposeAtenConv2dPaddingOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConv2dPaddingOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + + Value weight = op.getWeight(); + std::optional maybeRank = getTensorRank(weight); + if (!maybeRank) { + return rewriter.notifyMatchFailure(op, "expected weight to have a rank"); + } + unsigned rank = *maybeRank; + if (rank != 4) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D convolutions supported."); + + std::string padding_str; + if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str))) + return rewriter.notifyMatchFailure(op, + "padding must be a constant string"); + + Value zero = 
rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(0)); + + SmallVector<Value> paddingValues; + if (padding_str == "valid") { + for (unsigned iRank = 0; iRank < rank; iRank++) + paddingValues.push_back(zero); + } else { + + SmallVector<Value> dilation; + getListConstructElements(op.getDilation(), dilation); + + Value one = + rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); + Value two = + rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2)); + for (unsigned iRank = 2; iRank < rank; iRank++) { + Value dim = rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(iRank)); + Value kernelSize = + rewriter.create<AtenSizeIntOp>(loc, weight, dim); + Value kernelSizeMinusOne = + rewriter.create<AtenSubIntOp>(loc, kernelSize, one); + Value totalPadding = rewriter.create<AtenMulIntOp>( + loc, dilation[iRank - 2], kernelSizeMinusOne); + Value leftPadding = + rewriter.create<AtenFloordivIntOp>(loc, totalPadding, two); + Value rightPadding = + rewriter.create<AtenSubIntOp>(loc, totalPadding, leftPadding); + paddingValues.push_back(leftPadding); + paddingValues.push_back(rightPadding); + } + } + + Value emptyList = rewriter.create<PrimListConstructOp>( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector<Value>()); + Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false); + Value padding = rewriter.create<PrimListConstructOp>( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + paddingValues); + rewriter.replaceOpWithNewOp<AtenConvolutionOp>( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), padding, op.getDilation(), cstFalse, emptyList, + op.getGroups()); + + return success(); + } +}; +} // namespace + // Decompose aten.conv3d to aten.convolution namespace { class DecomposeAtenConv3dOp : public OpRewritePattern<AtenConv3dOp> { @@ -10940,6 +11017,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + // addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dPaddingOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0770c51b5180..e0011b9a347e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -33,9 +33,6 @@ # if a dimension is specified in all expand lists, and not in sumdim list. # This is a bug in the implementation of _trilinear in PyTorch. 
"Aten_TrilinearModuleZerodDimBug_basic", - # TorchScript to the backend contract fails for conv.padding specified as str - "Conv2dWithValidPaddingModule_basic", - "Conv2dWithSamePaddingModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8dfacca3238b..20b123594732 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1839,6 +1839,21 @@ def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weigh def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv2d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: str = "valid", dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: + kernel_size = [weight[2], weight[3]] + padding_int = [0, 0] * len(kernel_size) + if padding == "same": + for d, k, i in zip( + dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + ): + total_padding = d * (k - 1) + left_pad = total_padding // 2 + padding_int[2 * i] = left_pad + padding_int[2 * i + 1] = ( + total_padding - left_pad + ) + return upstream_shape_functions.conv2d(input, weight, bias, stride, padding_int, dilation, groups) + def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 31916f7fe896..4ca5082c90b5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -574,12 +574,21 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" + ) emit( "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" + ) emit( "aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 147885b442c7..93633d7b47a9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -191,32 +191,6 @@ def 
Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) -class Conv2dWithValidPaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - torch.manual_seed(0) - self.conv = torch.nn.Conv2d( - 1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1 - ) - self.train(False) - - @export - @annotate_args( - [ - None, - ([1, 5, 6], torch.float32, True), - ] - ) - def forward(self, x): - return self.conv(x) - - -@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) -def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): - t = tu.rand(1, 5, 6) - module.forward(t) - - class Conv2dWithSamePaddingModule(torch.nn.Module): def __init__(self): super().__init__() @@ -230,7 +204,7 @@ def __init__(self): @annotate_args( [ None, - ([1, 5, 6], torch.float32, True), + ([1, 1, 5, 6], torch.float32, True), ] ) def forward(self, x): @@ -239,7 +213,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): - t = tu.rand(1, 5, 6) + t = tu.rand(1, 1, 5, 6) module.forward(t) @@ -1212,6 +1186,39 @@ def Conv3dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv3dWithSingletonPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding=[1], + dilation=[1, 1, 1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv3dWithSingletonPaddingModule()) +def Conv3dWithSingletonPaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class ConvTbcModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index cfe873480370..31c576c7d1f8 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -76,6 +76,7 @@ def export_and_import( prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes) else: prog = torch.export.export(f, args, kwargs) + breakpoint() if decomposition_table is None: decomposition_table = get_decomposition_table() if decomposition_table: From d09d568553f376266b0ac68610f098dd27349657 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 19 Nov 2024 10:28:43 -0500 Subject: [PATCH 3/4] Add tests and update xfail_sets.py --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 58 ++++++++ lib/Conversion/TorchToLinalg/Linear.cpp | 6 - lib/Conversion/TorchToStablehlo/Linear.cpp | 5 - lib/Conversion/TorchToTosa/TorchToTosa.cpp | 3 - .../Transforms/AbstractInterpLibrary.cpp | 71 ++++++---- .../Torch/Transforms/DecomposeComplexOps.cpp | 37 ++--- projects/pt1/e2e_testing/xfail_sets.py | 20 +++ .../build_tools/abstract_interp_lib_gen.py | 34 +++-- .../build_tools/torch_ods_gen.py | 2 +- .../torch_mlir_e2e_test/test_suite/conv.py | 130 ++++++++++++++++-- python/torch_mlir/fx.py | 1 - 11 files changed, 290 insertions(+), 77 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 79b30c7d66e2..b38c03f644bb 100644 --- 
a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6684,6 +6684,35 @@ def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ }]; } +def Torch_AtenConv3dPaddingOp : Torch_Op<"aten.conv3d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv3dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv3dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6771,6 +6800,35 @@ def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ }]; } +def Torch_AtenConv1dPaddingOp : Torch_Op<"aten.conv1d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv1dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv1dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index e6f9b81b8436..9c914690bbf4 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -832,12 +832,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> { op, "only support padding from a list construct"); paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(), paddingIntValues); - if (paddingIntValues.size() == 1) { - for (size_t iDim = 1; iDim < numSpatialDims; iDim++) { - paddingIntValues.push_back(paddingIntValues[0]); - } - } - SmallVector<Value> outputPaddingIntValues; if (!getListConstructElements(op.getOutputPadding(), outputPaddingIntValues)) diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 88617f139c96..b42ed7cc7722 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -750,11 +750,6 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> { return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); } - if (padding.size() == 1) { - for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) { - padding.push_back(padding[0]); - } - } SmallVector<int64_t> 
dilation; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) { return rewriter.notifyMatchFailure(op, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 64e003d39dcc..e75d358b068d 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2187,9 +2187,6 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); - if (padding_2d.size() == 1) { - padding_2d.push_back(padding_2d[0]); - } // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4258026c948d..d9c68082a5c7 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10025,51 +10025,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.list<int>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n" +" return %1 : !torch.list<int>\n" +" }\n" +" func.func @__torch__._conv_padding(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str) -> !torch.list<int> {\n" " %true = torch.constant.bool true\n" " %int-1 = torch.constant.int -1\n" " %str = torch.constant.str \"same\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n" " %int2 = torch.constant.int 2\n" -" %int3 = torch.constant.int 3\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" -" %0 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list<int>, !torch.int -> !torch.int\n" -" %1 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list<int>, !torch.int -> !torch.int\n" -" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n" -" %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n" -" %4 = torch.operator \"aten.mul.left_t\"(%3, %int2) : (!torch.list<int>, !torch.int) -> !torch.list<int> \n" -" %5 = torch.aten.eq.str %arg4, %str : !torch.str, !torch.str -> !torch.bool\n" +" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n" +" %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list<int>\n" +" %4 = torch.aten.mul.left_t %3, %2 : !torch.list<int>, !torch.int -> !torch.list<int>\n" +" %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n" " torch.prim.If %5 -> () {\n" -" %7 = torch.aten.len.t %arg5 
: !torch.list -> !torch.int\n" -" %8 = torch.aten.__range_length %int1, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %9 = torch.prim.ListConstruct %7, %int2, %8 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" " %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" " torch.prim.Loop %10, %true, init() {\n" -" ^bb0(%arg7: !torch.int):\n" -" %11 = torch.aten.__getitem__.t %arg5, %arg7 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__getitem__.t %2, %arg7 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.aten.__derive_index %arg7, %int1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %14 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %15 = torch.aten.mul.int %11, %14 : !torch.int, !torch.int -> !torch.int\n" -" %16 = torch.aten.floordiv.int %15, %int2 : !torch.int, !torch.int -> !torch.int\n" -" %17 = torch.aten.mul.int %int2, %13 : !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten._set_item.t %4, %17, %16 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" %19 = torch.aten.sub.int %15, %16 : !torch.int, !torch.int -> !torch.int\n" -" %20 = torch.aten.mul.int %int2, %13 : !torch.int, !torch.int -> !torch.int\n" -" %21 = torch.aten.add.int %20, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %22 = torch.aten._set_item.t %4, %21, %19 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" ^bb0(%arg3: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list, !torch.int, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.If.yield\n" " }\n" -" %6 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" -" return %6 : !torch.list\n" +" return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: 
!torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n" +" return %1 : !torch.list<int>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n" " %0 = torch.derefine %arg3 : !torch.list<int> to !torch.optional<list<int>>\n" " %1 = torch.derefine %arg4 : !torch.list<int> to !torch.optional<list<int>>\n" @@ -10139,6 +10152,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n" " return %1 : !torch.list<int>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.str, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.str) -> !torch.list<int>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n" +" %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n" +" return %2 : !torch.list<int>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int, %arg7: !torch.list<int>) -> !torch.list<int> {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.list<int>\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 753f3b4ce90f..c690984c1c8c 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5169,13 +5169,13 @@ class DecomposeAtenConv2dOp : public OpRewritePattern<AtenConv2dOp> { }; } // namespace -// Decompose aten.conv2d.padding to aten.convolution +// Decompose aten.conv(1/2/3)d.padding to aten.convolution namespace { -class DecomposeAtenConv2dPaddingOp - : public OpRewritePattern<AtenConv2dPaddingOp> { +template <typename ConvPaddingOp> +class DecomposeAtenConvPaddingOp : public OpRewritePattern<ConvPaddingOp> { public: - using OpRewritePattern<AtenConv2dPaddingOp>::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConv2dPaddingOp op, + using OpRewritePattern<ConvPaddingOp>::OpRewritePattern; + LogicalResult matchAndRewrite(ConvPaddingOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -5186,9 +5186,11 @@ class DecomposeAtenConv2dPaddingOp 
return rewriter.notifyMatchFailure(op, "expected weight to have a rank"); } unsigned rank = *maybeRank; - if (rank != 4) + // first 2 dimensions of weight corresponds to out_channels and in_channels / \ + groups + if (rank < 3) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolutions supported."); + op, "ConvPaddingOp weight must be at least 3 dimensional."); std::string padding_str; if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str))) return rewriter.notifyMatchFailure(op, "padding must be a constant string"); @@ -5200,8 +5202,10 @@ class DecomposeAtenConv2dPaddingOp SmallVector<Value> paddingValues; if (padding_str == "valid") { - for (unsigned iRank = 0; iRank < rank; iRank++) + // valid means no padding + for (unsigned iRank = 2; iRank < rank; iRank++) { paddingValues.push_back(zero); + } } else { SmallVector<Value> dilation; @@ -5218,14 +5222,10 @@ class DecomposeAtenConv2dPaddingOp rewriter.create<AtenSizeIntOp>(loc, weight, dim); Value kernelSizeMinusOne = rewriter.create<AtenSubIntOp>(loc, kernelSize, one); - Value totalPadding = rewriter.create<AtenMulIntOp>( + Value padding = rewriter.create<AtenMulIntOp>( loc, dilation[iRank - 2], kernelSizeMinusOne); - Value leftPadding = - rewriter.create<AtenFloordivIntOp>(loc, totalPadding, two); - Value rightPadding = - rewriter.create<AtenSubIntOp>(loc, totalPadding, leftPadding); - paddingValues.push_back(leftPadding); - paddingValues.push_back(rightPadding); + padding = rewriter.create<AtenFloordivIntOp>(loc, padding, two); + paddingValues.push_back(padding); } } @@ -11017,8 +11017,13 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - // addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dPaddingOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp<AtenConv1dPaddingOp>>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp<AtenConv2dPaddingOp>>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp<AtenConv3dPaddingOp>>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e0011b9a347e..d9551162fca9 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2055,6 +2055,8 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", @@ -2545,6 +2547,8 @@ "Conv2dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", # failed to legalize operation 'torch.operator' "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", @@ -2872,6 +2876,8 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", "Conv2dQInt8PerChannelModule_basic", "Conv2dQInt8PerChannelModule_depthwise", "Conv2dQInt8PerChannelModule_grouped", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -3572,6 +3582,8 @@ 
"ContainsIntList_True", "Conv1dModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -3582,6 +3594,8 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -4193,6 +4207,8 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", @@ -4208,7 +4224,11 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv2dWithPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 20b123594732..643651a7b1f8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1839,24 +1839,32 @@ def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weigh def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) -def aten〇conv2d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: str = "valid", dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: - kernel_size = [weight[2], weight[3]] - padding_int = [0, 0] * len(kernel_size) +def _conv_padding(weight: List[int], dilation: List[int], padding: str): + rank = len(weight) + # first 2 dimensions of weight corresponds to out_channels and in_channels/groups + num_unpadded_dims = 2 + assert rank > num_unpadded_dims, "conv: weight must be at least 3 dimensional." 
+ num_kernel_elems = rank - num_unpadded_dims + padding_int = [0] * num_kernel_elems if padding == "same": - for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + for d, i in zip( + dilation, range(num_kernel_elems - 1, -1, -1) ): - total_padding = d * (k - 1) - left_pad = total_padding // 2 - padding_int[2 * i] = left_pad - padding_int[2 * i + 1] = ( - total_padding - left_pad - ) + padding_val = d * (weight[num_unpadded_dims+i] - 1) + padding_int[i] = padding_val // 2 + return padding_int + +def aten〇conv2d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: str = "valid", dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) return upstream_shape_functions.conv2d(input, weight, bias, stride, padding_int, dilation, groups) def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv3d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: str = "valid", dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) + return upstream_shape_functions.conv3d(input, weight, bias, stride, padding_int, dilation, groups) + def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) @@ -1898,6 +1906,10 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) +def aten〇conv1d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: str = "valid", dilation: List[int] = (1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding_int, dilation, transposed=False, output_padding=[], groups=1) + def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 4ca5082c90b5..a3a2383b509a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ 
b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -575,7 +575,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) emit( - "aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + "aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 93633d7b47a9..7a45dd7fc0ce 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -195,16 +195,14 @@ class Conv2dWithSamePaddingModule(torch.nn.Module): def __init__(self): super().__init__() torch.manual_seed(0) - self.conv = torch.nn.Conv2d( - 1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1 - ) + self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="same") self.train(False) @export @annotate_args( [ None, - ([1, 1, 5, 6], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), ] ) def forward(self, x): @@ -213,7 +211,31 @@ def forward(self, x): @register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): - t = tu.rand(1, 1, 5, 6) + t = tu.rand(5, 2, 10, 20) + module.forward(t) + + +class Conv2dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="valid") + self.train(False) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) +def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10, 20) module.forward(t) @@ -1120,6 +1142,63 @@ def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestU module.forward(inputVec, weight) +class Conv1dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv1d(2, 10, 3, bias=False, padding="same") + self.train(False) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv1dWithSamePaddingModule()) +def Conv1dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10) + module.forward(t) + + +class Conv1dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, + weight, + bias=bias, + stride=[1], + padding="valid", + dilation=[1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv1dWithValidPaddingModule()) +def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1186,7 +1265,7 @@ def 
Conv3dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) -class Conv3dWithSingletonPaddingModule(torch.nn.Module): +class Conv3dWithSamePaddingModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1205,14 +1284,47 @@ def forward(self, inputVec, weight, bias): weight, bias=bias, stride=[1, 1, 1], - padding=[1], + padding="same", + dilation=[1, 1, 1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv3dWithSamePaddingModule()) +def Conv3dWithSamePaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + +class Conv3dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding="valid", dilation=[1, 1, 1], groups=1, ) -@register_test_case(module_factory=lambda: Conv3dWithSingletonPaddingModule()) -def Conv3dWithSingletonPaddingModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: Conv3dWithValidPaddingModule()) +def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils): inputVec = tu.rand(2, 2, 6, 6, 6) weight = torch.randn(8, 2, 3, 3, 3) bias = torch.randn(8) module.forward(inputVec, weight, bias) diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 31c576c7d1f8..cfe873480370 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -76,7 +76,6 @@ def export_and_import( prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes) else: prog = torch.export.export(f, args, kwargs) - breakpoint() if decomposition_table is None: decomposition_table = get_decomposition_table() if decomposition_table: From cec8af37be46bd7011ca1b4d71f41f31620228dc Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 19 Nov 2024 12:22:22 -0500 Subject: [PATCH 4/4] Fix multiline comment. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c690984c1c8c..2d24d3e2c20d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5186,8 +5186,7 @@ class DecomposeAtenConvPaddingOp : public OpRewritePattern<ConvPaddingOp> { return rewriter.notifyMatchFailure(op, "expected weight to have a rank"); } unsigned rank = *maybeRank; - // first 2 dimensions of weight corresponds to out_channels and in_channels / \ - groups + // first 2 dimensions of weight are out_channels and in_channels / groups if (rank < 3) return rewriter.notifyMatchFailure( op, "ConvPaddingOp weight must be at least 3 dimensional.");
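
Series note (illustration, not part of the patches): the decomposition maps "valid" to zero padding and "same" to dilation * (kernel_size - 1) // 2 per spatial dimension. This matches PyTorch whenever dilation * (kernel_size - 1) is even; PyTorch pads "same" asymmetrically when that product is odd, which a single per-dimension padding value cannot express. A quick eager-mode sanity check of the arithmetic, using only stock PyTorch:

    import torch

    def same_padding(kernel_size, dilation):
        # per spatial dim: floor(dilation * (kernel_size - 1) / 2)
        return [d * (k - 1) // 2 for k, d in zip(kernel_size, dilation)]

    x = torch.rand(1, 2, 9, 9)
    w = torch.rand(4, 2, 3, 3)
    pad = same_padding([3, 3], [2, 2])  # -> [2, 2]
    ref = torch.nn.functional.conv2d(x, w, padding="same", dilation=[2, 2])
    out = torch.nn.functional.conv2d(x, w, padding=pad, dilation=[2, 2])
    assert torch.allclose(ref, out)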