[MLIR][ONNX] Add OnnxToTorch support for atan and bitwise ops
This commit adds OnnxToTorch support for the Atan, BitShift, BitwiseAnd,
and BitwiseNot ops.
It also adds TorchToLinalg support for AtenBitwiseLeftShiftTensorOp.

Signed-off-by: vivekkhandelwal@nod-labs.com
vivekkhandelwal1 committed Nov 28, 2023
1 parent 53fc995 commit dc9ea08
Showing 9 changed files with 337 additions and 21 deletions.
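To make the new lowering concrete, here is a rough before/after sketch of the IR for the BitShift path, assuming 3x4 signed-64-bit tensors; the value names, types, and attribute values are illustrative, not taken from this commit's tests.

// Before OnnxToTorch (illustrative): the imported ONNX op carries its
// attributes under the "torch.onnx." prefix that OpBinder looks up.
%shifted = torch.operator "onnx.BitShift"(%lhs, %rhs) {torch.onnx.direction = "LEFT"}
    : (!torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64>
// After OnnxToTorch: a torch-dialect op that TorchToLinalg can now lower to
// a linalg.generic with an arith.shli payload.
%shifted = torch.aten.bitwise_left_shift.Tensor %lhs, %rhs
    : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64>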
17 changes: 17 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -17,6 +17,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include <string>

namespace mlir::torch::onnx_c {

@@ -103,6 +104,22 @@ struct OpBinder {
return failure();
}

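// Binds the string attribute "torch.onnx.<nameSuffix>" on the op into
// `value`, falling back to `defaultValue` when the attribute is absent;
// this fails only if the attribute is present but not a StringAttr.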
ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix,
std::string defaultValue = "") {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
if (!attr) {
value = defaultValue;
return success();
}
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
value = stringAttr.str();
return success();
}
return failure();
}

Torch::ValueTensorType toValidTensorType(Type t) {
auto tt = dyn_cast<Torch::ValueTensorType>(t);
if (tt && tt.hasSizes())
47 changes: 47 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2891,6 +2891,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
}];
}

def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$other
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
54 changes: 53 additions & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;
@@ -141,6 +142,57 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
// TODO: Asin unimplemented in torch-mlir
// TODO: Asinh unimplemented in torch-mlir
// TODO: Atanh unimplemented in torch-mlir
patterns.onOp("Atan", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenAtanOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp(
"BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
std::string direction;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(direction, "direction", ""))
return failure();
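          // Per the ONNX spec, `direction` is a required attribute taking
          // "LEFT" or "RIGHT"; as written, a missing or unrecognized value
          // falls through to the right-shift lowering below.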
if (direction == "LEFT") {
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseLeftShiftTensorOp>(
binder.op, resultType, lhs, rhs);
} else {
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseRightShiftTensorOp>(
binder.op, resultType, lhs, rhs);
}
return success();
});
patterns.onOp(
"BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
if (binder.tensorOperands(lhs, rhs) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseAndTensorOp>(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("BitwiseNot", 18,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
binder.op, resultType, operand);
return success();
});
}
55 changes: 35 additions & 20 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -366,6 +366,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::ShRSIOp>(loc, lhs, rhs);
}
if (auto bitwiseLeftShiftTensor =
dyn_cast<AtenBitwiseLeftShiftTensorOp>(op)) {
Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType())
.cast<RankedTensorType>()
.getElementType();
if (!dtype.isa<mlir::IntegerType>()) {
bitwiseLeftShiftTensor.emitError(
"Bitwise_Left_Shift op does not support non-integer input dtype.");
return nullptr;
}
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
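    // Note: arith.shli (like the arith.shrsi used for the right shift above)
    // produces a poison value when the shift amount is >= the operand bit
    // width, so callers must keep shift amounts in range.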
return b.create<arith::ShLIOp>(loc, lhs, rhs);
}
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
@@ -1252,16 +1266,17 @@ class ConvertElementwiseOp : public ConversionPattern {
AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp,
AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenAtanOp, AtenRealOp, AtenImagOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1900,16 +1915,16 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp, AtenLogicalNotOp,
AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp,
AtenImagOp>();
AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
12 changes: 12 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7680,6 +7680,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.bitwise_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -9560,6 +9564,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bitwise_left_shift.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bmm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -908,6 +908,9 @@ def aten〇bitwise_xor〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇bitwise_right_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_left_shift〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -2454,6 +2457,14 @@ def aten〇bitwise_right_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(_check_two_tensor_op())
def aten〇bitwise_left_shift〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) +
# Different width
@@ -317,6 +317,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::square : (Tensor) -> (Tensor)",
63 changes: 63 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -3649,6 +3649,69 @@ def ElementwiseBitwiseRightShiftInt8Module_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseBitwiseLeftShiftInt64Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int64, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt64Module())
def ElementwiseBitwiseLeftShiftInt64Module_basic(module, tu: TestUtils):
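    # Shift amounts are drawn from [0, 64) so they stay within the 64-bit width.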
module.forward(tu.randint(3, 4, low=-1000, high=1000), tu.randint(3, 4, low=0, high=64))


class ElementwiseBitwiseLeftShiftInt32Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, 4], torch.int32, True),
([-1, 1], torch.int32, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt32Module())
def ElementwiseBitwiseLeftShiftInt32Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-1000, high=1000).to(torch.int32), tu.randint(3, 1, low=0, high=32).to(torch.int32))


class ElementwiseBitwiseLeftShiftInt8Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int8, True),
([-1, -1], torch.int8, True),
])
def forward(self, lhs, rhs):
return torch.bitwise_left_shift(lhs, rhs)


@register_test_case(module_factory=lambda: ElementwiseBitwiseLeftShiftInt8Module())
def ElementwiseBitwiseLeftShiftInt8Module_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int8), tu.randint(3, 4, low=0, high=8).to(torch.int8))


# ==============================================================================


class ElementwiseBitwiseAndScalarInt64Module(torch.nn.Module):

def __init__(self):