From 663608615eea1a92c5ebd7a20a39fe7e0cdb388a Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Wed, 5 Jun 2024 07:12:19 -0500
Subject: [PATCH] Add ONNX LRN op lowering

This commit adds support for lowering the ONNX LRN op to aten.
---
 .../Conversion/TorchOnnxToTorch/Utils.h       |   2 +-
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 115 +++++++++++++++
 lib/Conversion/TorchOnnxToTorch/Utils.cpp     |   2 +-
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 131 ++++++++++++++++++
 4 files changed, 248 insertions(+), 2 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
index d8d2534f9a0c..fe67676e0868 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -34,7 +34,7 @@ namespace mlir::torch::onnx_c {
 
 Value createConstantIntList(OpBinder binder,
                             ConversionPatternRewriter &rewriter,
-                            SmallVector<int64_t> cstInput);
+                            ArrayRef<int64_t> cstInput);
 
 Type getQTorchTypeFromTorchIntType(Type ty);
 
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index c7e41a7a097c..157744ae9a6d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1445,6 +1445,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
             binder.op, resultType, operand, constAlpha);
         return success();
       });
+  patterns.onOp(
+      "LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value operand;
+        int64_t size;
+        float alpha, beta, bias;
+        if (binder.tensorOperand(operand) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(size, "size", 2) ||
+            binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
+            binder.f32FloatAttr(beta, "beta", 0.75f) ||
+            binder.f32FloatAttr(bias, "bias", 1.0f))
+          return failure();
+        Type dtype = resultType.getOptionalDtype();
+        Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(alpha));
+        Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(beta));
+        Value constBias = rewriter.create<Torch::ConstantFloatOp>(
+            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(bias));
+        // Please refer to the operator description for more info on the
+        // lowering: https://onnx.ai/onnx/operators/onnx__LRN.html
+
+        // squared = operand^2
+        Location loc = binder.getLoc();
+        Torch::ValueTensorType inTy =
+            cast<Torch::ValueTensorType>(operand.getType());
+        Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
+            loc, inTy, operand, operand);
+        // view it as n x 1 x c x d0 x d..
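+        // The reshape below exposes the channel dimension to a 3-d pooling
+        // window: ONNX LRN sums x^2 over a window of `size` channels, so the
+        // squared input is viewed as N x 1 x C x d0 x (rest) and the window
+        // sum is later computed with avg_pool3d along the C axis.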
+        if (!inTy.hasSizes()) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expected input to have sizes");
+        }
+        ArrayRef<int64_t> inTyShape = inTy.getSizes();
+        if (inTyShape.size() < 3) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported: the input dimensions should be >= 3");
+        }
+        if (inTyShape[1] == Torch::kUnknownSize) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Unsupported: the second dimension size must be "
+                         "statically known");
+        }
+        SmallVector<int64_t> viewShapeInt{inTyShape[0], 1, inTyShape[1],
+                                          inTyShape[2], Torch::kUnknownSize};
+        Torch::ValueTensorType reshapeType =
+            rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
+        Value viewShapeListVal =
+            createConstantIntList(binder, rewriter, viewShapeInt);
+        auto view = rewriter.create<Torch::AtenViewOp>(
+            loc, reshapeType, sqOperand, viewShapeListVal);
+        // padding
+        int64_t highPad = (size - 1) / 2;
+        int64_t lowPad = (size - 1) - highPad;
+        SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
+        auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getType<Torch::FloatType>(),
+            rewriter.getF64FloatAttr(0.0));
+        Value paddingListVal =
+            createConstantIntList(binder, rewriter, paddingInt);
+        SmallVector<int64_t> paddedShapeInt = viewShapeInt;
+        paddedShapeInt[2] += size - 1;
+        Torch::ValueTensorType paddedType =
+            rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
+        auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
+            loc, paddedType, view, paddingListVal, constPadVal);
+        // avg_pool3d
+        SmallVector<int64_t> kernelSize{size, 1, 1};
+        Value kernelSizeList =
+            createConstantIntList(binder, rewriter, kernelSize);
+        SmallVector<int64_t> strides{1, 1, 1};
+        Value stridesList = createConstantIntList(binder, rewriter, strides);
+        SmallVector<int64_t> padding{0, 0, 0};
+        Value paddingList = createConstantIntList(binder, rewriter, padding);
+        auto cstCeilMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        auto cstCountIncludeMode =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        Value cstNone =
+            rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        // The pooling output has the same type as reshape(view) because of
+        // the padding done on the dimension being pooled.
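+        // Note: with count_include_pad = true and no divisor_override, the
+        // average divides by the full kernel volume (size * 1 * 1), so the
+        // pool produces square_sum / size; the later multiplication by alpha
+        // yields the (alpha / size) * square_sum term of the ONNX LRN formula
+        // y = x / (bias + (alpha / size) * square_sum)^beta.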
+        auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
+            loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
+            cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
+        // squeeze
+        auto one = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        SmallVector<int64_t> squeezeShapeInt{
+            viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
+        Torch::ValueTensorType squeezeType =
+            rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
+        auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
+            loc, squeezeType, pool, one);
+        // view as the input type
+        Value intTyShapeList =
+            createConstantIntList(binder, rewriter, inTyShape);
+        auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
+            loc, inTy, squeeze, intTyShapeList);
+        // mul + add + pow + div
+        auto mul = rewriter.create<Torch::AtenMulScalarOp>(
+            loc, resultType, viewAsInput, constAlpha);
+        auto add = rewriter.create<Torch::AtenAddScalarOp>(
+            loc, resultType, mul, constBias, one);
+        auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
+            loc, resultType, add, constBeta);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
+            binder.op, resultType, operand, pow);
+        return success();
+      });
   patterns.onOp(
       "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
index e7baf2e243fc..fde1cb3a7bdc 100644
--- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -16,7 +16,7 @@ using namespace mlir::torch::onnx_c;
 
 Value mlir::torch::onnx_c::createConstantIntList(
     OpBinder binder, ConversionPatternRewriter &rewriter,
-    SmallVector<int64_t> cstInput) {
+    ArrayRef<int64_t> cstInput) {
   SmallVector<Value> cstValue;
   for (int64_t i : cstInput) {
     cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 227eac7d9665..83136099599c 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -310,6 +310,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor
 
 // -----
 
+// CHECK-LABEL: func.func @test_lrn_default
+func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
+  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
+  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
+  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
+
+  // CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
+  // CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
+  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]
+
+  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
+
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]
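+  // With the default size = 3, the lowering pads only the pooled channel
+  // axis: lowPad = highPad = 1, hence the two trailing 1s in the padding
+  // list above.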
+
+  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
+
+  // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]
+
+  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
+
+  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
+
+  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
+  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
+
+  // CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
+  // CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
+  // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
+  // CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
+  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]
+
+  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
+  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
+  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
+  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
+  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
+  // CHECK: return %[[OUTPUT]]
+  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
+  return %0 : !torch.vtensor<[20,10,3,50],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lrn_with_optionals
+func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
+  // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
+  // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
+  // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
+  // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0
+
+  // CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
+  // CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
+  // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
+  // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]
+
+  // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]
+
+  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
+  // CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
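+  // Here size = 5, so the channel axis is padded by lowPad = highPad = 2,
+  // which accounts for the two trailing 2s in the padding list below.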
+  // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]
+
+  // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]
+
+  // CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
+  // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]
+
+  // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]
+
+  // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
+  // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]
+
+  // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
+  // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]
+
+  // CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
+  // CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
+  // CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
+  // CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
+  // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]
+
+  // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
+  // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
+  // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
+  // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
+  // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
+  // CHECK: return %[[OUTPUT]]
+  %none = torch.constant.none
+  %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
+  return %0 : !torch.vtensor<[13,19,100,200],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_matmul_2d
 func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>