148 changes: 81 additions & 67 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -10,30 +10,39 @@
#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
int64_t dtypeIntTorch;
// TODO: Add complete mapping.
switch (dtypeIntOnnx) {
case 1:
dtypeIntTorch = 6; // float
break;
case 10:
dtypeIntTorch = 5; // half
break;
case 11:
dtypeIntTorch = 7; // double
break;
case 16:
dtypeIntTorch = 15; // bfloat16
break;
default:
dtypeIntTorch = -1; // No dtype
}
// The ONNX and PyTorch dtype enums are defined here:
// ONNX (TensorProto.DataType):
// https://github.com/onnx/onnx/blob/main/onnx/onnx-ml.proto
// PyTorch (ScalarType):
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88
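// For reference, the values used below are: ONNX FLOAT=1, INT64=7, BOOL=9,
// FLOAT16=10, DOUBLE=11, BFLOAT16=16; torch Long=4, Half=5, Float=6,
// Double=7, Bool=11, BFloat16=15.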

int64_t dtypeIntTorch = [dtypeIntOnnx]() {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 7:
return 4; // int64 (torch Long is 4; 5 is Half)
case 9:
return 11; // bool
case 10:
return 5; // half
case 11:
return 7; // double
case 16:
return 15; // bfloat16
default:
return -1; // No dtype
}
}();

return dtypeIntTorch;
}
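
// A minimal sanity check of the mapping above (a sketch, not part of this
// patch; checkDtypeMapping is a hypothetical helper). Expected values follow
// the two enum definitions linked in the comment above.
#include <cassert>

static void checkDtypeMapping() {
  assert(onnxDtypeIntToTorchDtypeInt(1) == 6);   // FLOAT    -> Float
  assert(onnxDtypeIntToTorchDtypeInt(7) == 4);   // INT64    -> Long
  assert(onnxDtypeIntToTorchDtypeInt(9) == 11);  // BOOL     -> Bool
  assert(onnxDtypeIntToTorchDtypeInt(10) == 5);  // FLOAT16  -> Half
  assert(onnxDtypeIntToTorchDtypeInt(16) == 15); // BFLOAT16 -> BFloat16
  assert(onnxDtypeIntToTorchDtypeInt(8) == -1);  // STRING: unmapped sentinel
}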

@@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}
return success();
});
patterns.onOp(
"BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseAnd", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseOr", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseOrTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseNot", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseXor", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseXorTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(

dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
auto message = llvm::formatv("unimplemented support for the given "
"dtype conversion (onnx 'type' = {0})",
dtypeIntOnnx);
return rewriter.notifyMatchFailure(binder.op, message);
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
unsigned rank = *maybeRank;

SmallVector<int64_t> padding, strides, dilations, outputPadding;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding;
SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
defaultOutputPadding;
for (unsigned i = 0; i < rank - 2; i++) {
defaultPadding.push_back(0);
defaultStrides.push_back(1);
@@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
rank));
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));

Value axisScalar = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
Value isNegative =
rewriter.create<Torch::AtenLtIntOp>(binder.getLoc(), axisScalar, zero);
isNegative = rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(),
isNegative);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), axisScalar, zero);
isNegative =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
Value dim = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), axisScalar, finalOffset);

Torch::BaseTensorType resultTensorType = resultType.cast<Torch::BaseTensorType>();
Torch::BaseTensorType resultTensorType =
resultType.cast<Torch::BaseTensorType>();
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
}
Value resultDType =
Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype());
Value resultDType = Torch::getDtypeIntValueForType(
rewriter, loc, resultTensorType.getDtype());

rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
binder.op, resultType, operand, dim, resultDType);
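
// Aside (a sketch, not part of this patch): the AtenLtIntOp / AtenIntBoolOp /
// AtenMulIntOp / AtenAddIntOp sequence above performs the standard ONNX
// negative-axis normalization. The same arithmetic in plain C++, with a
// hypothetical helper name:
#include <cstdint>

static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  int64_t isNegative = (axis < 0) ? 1 : 0; // AtenLtIntOp + AtenIntBoolOp
  return axis + isNegative * rank;         // AtenMulIntOp + AtenAddIntOp
}
// e.g. normalizeAxis(-1, 4) == 3 and normalizeAxis(2, 4) == 2.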
10 changes: 10 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -332,6 +332,16 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f64>
return %0 : !torch.vtensor<[3,4],f64>
}

// CHECK-LABEL: @test_cast_FLOAT_to_BOOL
func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 11
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1>
%0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1>
return %0 : !torch.vtensor<[3,4],i1>
}
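// Note (added for clarity): onnx "to = 9" is TensorProto BOOL, which the
// converter maps to torch dtype 11 (bool); hence the torch.constant.int 11
// expected by the CHECK lines above.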

// CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT
func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 6