diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index b9b60e7748b4..3b47162711cb 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -10,30 +10,39 @@
 #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::onnx_c;
 
 static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
-  int64_t dtypeIntTorch;
   // TODO: Add complete mapping.
-  switch (dtypeIntOnnx) {
-  case 1:
-    dtypeIntTorch = 6; // float
-    break;
-  case 10:
-    dtypeIntTorch = 5; // half
-    break;
-  case 11:
-    dtypeIntTorch = 7; // double
-    break;
-  case 16:
-    dtypeIntTorch = 15; // bfloat16
-    break;
-  default:
-    dtypeIntTorch = -1; // No dtype
-  }
+  // Where are the ONNX and PyTorch dtype enums defined?
+  // ONNX:
+  //  https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto
+  // PyTorch:
+  //  https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
+
+  int64_t dtypeIntTorch = [dtypeIntOnnx]() {
+    switch (dtypeIntOnnx) {
+    case 1:
+      return 6; // float
+    case 7:
+      return 4; // int64
+    case 9:
+      return 11; // bool
+    case 10:
+      return 5; // half
+    case 11:
+      return 7; // double
+    case 16:
+      return 15; // bfloat16
+    default:
+      return -1; // No dtype
+    }
+  }();
+
   return dtypeIntTorch;
 }
 
@@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         }
         return success();
       });
-  patterns.onOp(
-      "BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value lhs, rhs;
-        std::string direction;
-        if (binder.tensorOperands(lhs, rhs) ||
-            binder.tensorResultType(resultType))
-          return failure();
-        rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
-            binder.op, resultType, lhs, rhs);
-        return success();
-      });
-  patterns.onOp(
-      "BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value lhs, rhs;
-        std::string direction;
-        if (binder.tensorOperands(lhs, rhs) ||
-            binder.tensorResultType(resultType))
-          return failure();
-        rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
-            binder.op, resultType, lhs, rhs);
-        return success();
-      });
+  patterns.onOp("BitwiseAnd", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value lhs, rhs;
+                  std::string direction;
+                  if (binder.tensorOperands(lhs, rhs) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
+                      binder.op, resultType, lhs, rhs);
+                  return success();
+                });
+  patterns.onOp("BitwiseOr", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value lhs, rhs;
+                  std::string direction;
+                  if (binder.tensorOperands(lhs, rhs) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
+                      binder.op, resultType, lhs, rhs);
+                  return success();
+                });
   patterns.onOp("BitwiseNot", 18,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
@@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                       binder.op, resultType, operand);
                   return success();
                 });
-  patterns.onOp(
-      "BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Torch::ValueTensorType resultType;
-        Value lhs, rhs;
-        std::string direction;
-        if (binder.tensorOperands(lhs, rhs) ||
-            binder.tensorResultType(resultType))
-          return failure();
-        rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
-            binder.op, resultType, lhs, rhs);
-        return success();
-      });
+  patterns.onOp("BitwiseXor", 18,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value lhs, rhs;
+                  std::string direction;
+                  if (binder.tensorOperands(lhs, rhs) ||
+                      binder.tensorResultType(resultType))
+                    return failure();
+                  rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
+                      binder.op, resultType, lhs, rhs);
+                  return success();
+                });
   patterns.onOp(
       "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
@@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 
         dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
         if (dtypeIntTorch == -1) {
-          return rewriter.notifyMatchFailure(
-              binder.op,
-              "unimplemented support for the given dtype conversion");
+          auto message = llvm::formatv("unimplemented support for the given "
+                                       "dtype conversion (onnx 'type' = {0})",
+                                       dtypeIntOnnx);
+          llvm::errs() << message << "\n";
+          auto y = rewriter.notifyMatchFailure(binder.op, message);
+
+          return y;
         }
         Value constDtype = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
 
         unsigned rank = *maybeRank;
         SmallVector<int64_t> padding, strides, dilations, outputPadding;
-        SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
+        SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
+            defaultOutputPadding;
         for (unsigned i = 0; i < rank - 2; i++) {
           defaultPadding.push_back(0);
           defaultStrides.push_back(1);
@@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
         Value rankVal = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
-            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                    rank));
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
         Value zero = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(0));
-
+
         Value axisScalar = rewriter.create<Torch::AtenItemOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
-        Value isNegative =
-            rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axisScalar, zero);
-        isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
-                                                           isNegative);
+        Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
+            binder.getLoc(), axisScalar, zero);
+        isNegative =
+            rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
         Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
             binder.getLoc(), isNegative, rankVal);
         Value dim = rewriter.create<Torch::AtenAddIntOp>(
             binder.getLoc(), axisScalar, finalOffset);
 
-        Torch::BaseTensorType resultTensorType = resultType.cast<Torch::BaseTensorType>();
+        Torch::BaseTensorType resultTensorType =
+            resultType.cast<Torch::BaseTensorType>();
         if (!resultTensorType.hasDtype()) {
           return rewriter.notifyMatchFailure(
               binder.op, "expected result type to have a dtype");
         }
         // resultTensorType.print(llvm::outs());
-        Value resultDType =
-            Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype());
+        Value resultDType = Torch::getDtypeIntValueForType(
+            rewriter, loc, resultTensorType.getDtype());
 
         rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
             binder.op, resultType, operand, dim, resultDType);
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index c2d3c12a7b92..bb02a29cb592 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -332,6 +332,16 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor
   return %0 : !torch.vtensor<[3,4],f64>
 }
 
+// CHECK-LABEL: @test_cast_FLOAT_to_BOOL
+func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT:.*]] = torch.constant.int 11
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1>
+  %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1>
+  return %0 : !torch.vtensor<[3,4],i1>
+}
+
 // CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT
 func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT:.*]] = torch.constant.int 6