From 98973466fb26a53ae74628f88a06643b902d5415 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 2 Jan 2024 17:54:10 -0800 Subject: [PATCH 1/2] [linalg] Added `aten.clamp` support with integers to `torch-to-linalg` The lowering for `aten.clamp` did not support integer types. Added support for integer types including a signed integer test. --- .../TorchToLinalg/Uncategorized.cpp | 72 +++++++++++++------ .../Conversion/TorchToLinalg/elementwise.mlir | 17 +++++ 2 files changed, 69 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e947ae73ace0..e7ae424f0792 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1001,13 +1001,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { - Type dtype = converter->convertType(clamp.getType()) - .cast() - .getElementType(); - if (!dtype.isa()) { - clamp.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); @@ -1016,20 +1009,59 @@ static Value createLinalgPayloadCalculationForElementwiseOp( clamp.emitError("unimplemented: runtime optional type"); return nullptr; } - auto result = payloadArgs[0]; - if (!min.getType().isa()) { - auto minPromoted = convertScalarToDtype(b, loc, min, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::ULT, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); - } - if (!max.getType().isa()) { - auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::UGT, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); + + Type dtype = converter->convertType(clamp.getType()) + .cast() + .getElementType(); + if (dtype.isa()) { + auto result = payloadArgs[0]; + if (!min.getType().isa()) { + auto minPromoted = convertScalarToDtype(b, loc, min, dtype); + auto pred = b.create(loc, arith::CmpFPredicate::ULT, + result, minPromoted); + result = b.create(loc, pred, minPromoted, result); + } + if (!max.getType().isa()) { + auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); + auto pred = b.create(loc, arith::CmpFPredicate::UGT, + result, maxPromoted); + result = b.create(loc, pred, maxPromoted, result); + } + return result; + } + + if (auto intTy = dtype.dyn_cast()) { + auto result = payloadArgs[0]; + + if (!min.getType().isa()) { + auto minPromoted = + convertScalarToDtype(b, loc, min, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/b.getI64Type()); + auto pred = b.create(loc, + intTy.isUnsigned() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt, + result, minPromoted); + result = b.create(loc, pred, minPromoted, result); + } + if (!max.getType().isa()) { + auto maxPromoted = + convertScalarToDtype(b, loc, max, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/b.getI64Type()); + auto pred = b.create(loc, + intTy.isUnsigned() + ? arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt, + result, maxPromoted); + result = b.create(loc, pred, maxPromoted, result); + } + return result; } - return result; + + clamp.emitError("unimplemented: non-floating point dtype"); + return nullptr; } if (auto clampTensor = dyn_cast(op)) { AtenClampTensorOp::Adaptor adaptor(operands); diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index bed94f98da2b..b7e57cfc2ad6 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -102,3 +102,20 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } + +// ----- + +// CHECK-LABEL: func.func @elementwise_clamp +// CHECK: linalg.generic +// CHECK: arith.trunci +// CHECK: arith.cmpi slt +// CHECK: arith.select +// CHECK: arith.trunci +// CHECK: arith.cmpi sgt +// CHECK: arith.select +func.func @elementwise_clamp(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],si8> { + %int-128 = torch.constant.int -128 + %int127 = torch.constant.int 127 + %0 = torch.aten.clamp %arg0, %int-128, %int127 : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8> + return %0 : !torch.vtensor<[1,3,8,8],si8> +} From 0fb8eeaefaa360d27aa841c64df1aa4b33851dd9 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 4 Jan 2024 15:43:24 -0800 Subject: [PATCH 2/2] fix up for proper elementwise --- .../TorchToLinalg/Uncategorized.cpp | 81 ++++++++----------- .../test_suite/elementwise.py | 28 +++++++ .../Conversion/TorchToLinalg/elementwise.mlir | 17 ---- 3 files changed, 62 insertions(+), 64 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e7ae424f0792..69c5e120e0f9 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1013,55 +1013,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(clamp.getType()) .cast() .getElementType(); - if (dtype.isa()) { - auto result = payloadArgs[0]; - if (!min.getType().isa()) { - auto minPromoted = convertScalarToDtype(b, loc, min, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::ULT, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); - } - if (!max.getType().isa()) { - auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::UGT, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); - } - return result; - } - - if (auto intTy = dtype.dyn_cast()) { - auto result = payloadArgs[0]; - - if (!min.getType().isa()) { - auto minPromoted = - convertScalarToDtype(b, loc, min, dtype, - /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/b.getI64Type()); - auto pred = b.create(loc, - intTy.isUnsigned() - ? arith::CmpIPredicate::ult - : arith::CmpIPredicate::slt, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); - } - if (!max.getType().isa()) { - auto maxPromoted = - convertScalarToDtype(b, loc, max, dtype, - /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/b.getI64Type()); - auto pred = b.create(loc, - intTy.isUnsigned() - ? arith::CmpIPredicate::ugt - : arith::CmpIPredicate::sgt, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); - } - return result; + if (!dtype.isa()) { + clamp.emitError("unimplement type for clamp"); + return nullptr; } - clamp.emitError("unimplemented: non-floating point dtype"); - return nullptr; + Type dstOriginalDtype = clamp.getType().cast().getDtype(); + bool isUnsigned = isa(dstOriginalDtype); + if (auto intTy = dstOriginalDtype.dyn_cast()) { + isUnsigned = intTy.isUnsigned(); + } + auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { + clamp = convertScalarToDtype(b, loc, clamp, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/dstOriginalDtype); + + Value pred; + if (dtype.isa()) { + auto cmp = + getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; + pred = b.create(loc, cmp, input, clamp); + } else if (dtype.isa()) { + auto cmp = + isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; + if (getMax) + cmp = arith::invertPredicate(cmp); + pred = b.create(loc, cmp, input, clamp); + } + return b.create(loc, pred, clamp, input); + }; + + auto result = payloadArgs[0]; + if (!min.getType().isa()) + result = cmpSelect(result, min, /*getMax=*/false); + if (!max.getType().isa()) + result = cmpSelect(result, max, /*getMax=*/true); + return result; } if (auto clampTensor = dyn_cast(op)) { AtenClampTensorOp::Adaptor adaptor(operands); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 33c420a1c517..2f9d840606a4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True) + ]) + def forward(self, x): + min = -5 + max = 5 + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module()) +def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8)) + + +# ============================================================================== + + + class ElementwiseClampMinTensorFloatModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index b7e57cfc2ad6..bed94f98da2b 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -102,20 +102,3 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } - -// ----- - -// CHECK-LABEL: func.func @elementwise_clamp -// CHECK: linalg.generic -// CHECK: arith.trunci -// CHECK: arith.cmpi slt -// CHECK: arith.select -// CHECK: arith.trunci -// CHECK: arith.cmpi sgt -// CHECK: arith.select -func.func @elementwise_clamp(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],si8> { - %int-128 = torch.constant.int -128 - %int127 = torch.constant.int 127 - %0 = torch.aten.clamp %arg0, %int-128, %int127 : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8> - return %0 : !torch.vtensor<[1,3,8,8],si8> -}