2 changes: 1 addition & 1 deletion include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -34,7 +34,7 @@ namespace mlir::torch::onnx_c {

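// ArrayRef accepts any contiguous int64_t sequence (SmallVector, std::vector,
// initializer list) without a copy, making it a more flexible parameter type
// than SmallVector taken by value.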
Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
-                            SmallVector<int64_t> cstInput);
+                            ArrayRef<int64_t> cstInput);

Type getQTorchTypeFromTorchIntType(Type ty);

115 changes: 115 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1445,6 +1445,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand, constAlpha);
return success();
});
patterns.onOp(
"LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
int64_t size;
float alpha, beta, bias;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(size, "size", 2) ||
binder.f32FloatAttr(alpha, "alpha", 0.0001f) ||
binder.f32FloatAttr(beta, "beta", 0.75f) ||
binder.f32FloatAttr(bias, "bias", 1.0f))
return failure();
Type dtype = resultType.getOptionalDtype();
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
Value constBeta = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(beta));
Value constBias = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(bias));
            // See the ONNX operator description for details on this lowering:
            // https://onnx.ai/onnx/operators/onnx__LRN.html
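            // LRN computes y = x / (bias + (alpha / size) * sum(x^2 over the
            // channel window))^beta; the steps below assemble that expression.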

// squared = operand^2
Location loc = binder.getLoc();
Torch::ValueTensorType inTy =
cast<Torch::ValueTensorType>(operand.getType());
Value sqOperand = rewriter.create<Torch::AtenMulTensorOp>(
loc, inTy, operand, operand);
            // View the squared input as N x 1 x C x D0 x -1 (any remaining
            // dims flattened into the last).
if (!inTy.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expected input to have sizes");
}
ArrayRef<int64_t> inTyShape = inTy.getSizes();
if (inTyShape.size() < 3) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: the input dimensions should be >= 3");
}
if (inTyShape[1] == Torch::kUnknownSize) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported: the second dimension size must be "
"statically known");
}
SmallVector<int64_t, 5> viewShapeInt{inTyShape[0], 1, inTyShape[1],
inTyShape[2], Torch::kUnknownSize};
Torch::ValueTensorType reshapeType =
rewriter.getType<Torch::ValueTensorType>(viewShapeInt, dtype);
Value viewShapeListVal =
createConstantIntList(binder, rewriter, viewShapeInt);
auto view = rewriter.create<Torch::AtenViewOp>(
loc, reshapeType, sqOperand, viewShapeListVal);
            // Pad the channel dimension (now dim 2) with zeros so every output
            // position sees a full window of `size` channels; the (size - 1)
            // extra entries are split across both sides (symmetric for odd
            // size).
int64_t highPad = (size - 1) / 2;
int64_t lowPad = (size - 1) - highPad;
SmallVector<int64_t> paddingInt{0, 0, 0, 0, lowPad, highPad};
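            // constant_pad_nd takes (before, after) pairs starting from the
            // last dimension, so this list pads only dim 2.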
auto constPadVal = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(0.0));
Value paddingListVal =
createConstantIntList(binder, rewriter, paddingInt);
SmallVector<int64_t, 5> paddedShapeInt = viewShapeInt;
paddedShapeInt[2] += size - 1;
Torch::ValueTensorType paddedType =
rewriter.getType<Torch::ValueTensorType>(paddedShapeInt, dtype);
auto padded = rewriter.create<Torch::AtenConstantPadNdOp>(
loc, paddedType, view, paddingListVal, constPadVal);
// avg_pool3d
SmallVector<int64_t, 3> kernelSize{size, 1, 1};
Value kernelSizeList =
createConstantIntList(binder, rewriter, kernelSize);
SmallVector<int64_t, 3> strides{1, 1, 1};
Value stridesList = createConstantIntList(binder, rewriter, strides);
SmallVector<int64_t, 3> padding{0, 0, 0};
Value paddingList = createConstantIntList(binder, rewriter, padding);
auto cstCeilMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
auto cstCountIncludeMode =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
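            // count_include_pad=true makes avg_pool3d divide by the full
            // kernel volume (size * 1 * 1), so the pooled value is
            // sum(x^2) / size; multiplying by alpha below then gives the
            // (alpha / size) * sum(x^2) term of the LRN formula.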
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
            // The pooled output has the same type as the reshaped (view) input
            // because the pooled dimension was padded to preserve its size.
auto pool = rewriter.create<Torch::AtenAvgPool3dOp>(
loc, reshapeType, padded, kernelSizeList, stridesList, paddingList,
cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone);
            // Squeeze away the singleton dim 1 that was inserted for pooling.
auto one = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<int64_t, 5> squeezeShapeInt{
viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]};
Torch::ValueTensorType squeezeType =
rewriter.getType<Torch::ValueTensorType>(squeezeShapeInt, dtype);
auto squeeze = rewriter.create<Torch::AtenSqueezeDimOp>(
loc, squeezeType, pool, one);
            // View the result back as the original input type.
Value intTyShapeList =
createConstantIntList(binder, rewriter, inTyShape);
auto viewAsInput = rewriter.create<Torch::AtenViewOp>(
loc, inTy, squeeze, intTyShapeList);
            // Assemble the denominator and divide:
            // x / (bias + alpha * avg)^beta.
auto mul = rewriter.create<Torch::AtenMulScalarOp>(
loc, resultType, viewAsInput, constAlpha);
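            // add.Scalar computes self + alpha * other; passing `one` as the
            // alpha operand makes this a plain addition of `bias`.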
auto add = rewriter.create<Torch::AtenAddScalarOp>(loc, resultType, mul,
constBias, one);
auto pow = rewriter.create<Torch::AtenPowTensorScalarOp>(
loc, resultType, add, constBeta);

rewriter.replaceOpWithNewOp<Torch::AtenDivTensorOp>(
binder.op, resultType, operand, pow);
return success();
});
patterns.onOp(
"Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -16,7 +16,7 @@ using namespace mlir::torch::onnx_c;

Value mlir::torch::onnx_c::createConstantIntList(
OpBinder binder, ConversionPatternRewriter &rewriter,
-    SmallVector<int64_t> cstInput) {
+    ArrayRef<int64_t> cstInput) {
SmallVector<Value> cstValue;
for (int64_t i : cstInput) {
cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
131 changes: 131 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -310,6 +310,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor

// -----

// CHECK-LABEL: func.func @test_lrn_default
func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5
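// (The default alpha of 0.0001f, widened to f64, prints as
// 9.9999997473787516E-5.)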
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

// CHECK-DAG: %[[I20:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I10:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3:.+]] = torch.constant.int 3
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]]

// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]]

// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]]

// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

// CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20
// CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10
// CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3
// CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]]

// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32>
return %0 : !torch.vtensor<[20,10,3,50],f32>
}

// -----

// CHECK-LABEL: func.func @test_lrn_with_optionals
func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} {
// CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true
// CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false
// CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026
// CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209
// CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00
// CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0

// CHECK-DAG: %[[I13:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I19:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100:.+]] = torch.constant.int 100
// CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1
// CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]]

// CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]]

// CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2
// CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]]

// CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]]

// CHECK-DAG: %[[I5:.+]] = torch.constant.int 5
// CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1
// CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]]

// CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1
// CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1
// CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]]

// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0
// CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0
// CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]]

// CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]]
// CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]]

// CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13
// CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19
// CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100
// CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200
// CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]]

// CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]]
// CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]]
// CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]]
// CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]]
// CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]]
// CHECK: return %[[OUTPUT]]
%none = torch.constant.none
%0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32>
return %0 : !torch.vtensor<[13,19,100,200],f32>
}

// -----

// CHECK-LABEL: @test_matmul_2d
func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32>