Skip to content

Commit

Permalink
[Tcp] Handle int inputs in sqrt (#2467)
Browse files Browse the repository at this point in the history
This PR adds support for integer inputs in `tcp.sqrt`.
  • Loading branch information
navahgar committed Sep 16, 2023
1 parent 4d90ab4 commit 0ea235e
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class Tcp_UnaryElementwiseOp<string mnemonic, list<Trait> traits = []> :
Tcp_Op<mnemonic, !listconcat(traits, [
Pure,
Elementwise,
SameOperandsAndResultShape,
SameOperandsAndResultElementType])> {
SameOperandsAndResultShape])> {
}

class Tcp_BinaryElementwiseOp<string mnemonic, list<Trait> traits = []> :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ include "torch-mlir-dialects/Dialect/Tcp/IR/TcpEnums.td"

include "mlir/IR/OpBase.td"

def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> {
def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
let summary = "Computes tanh of input, elementwise";

let description = [{
Expand All @@ -33,7 +33,7 @@ def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> {
def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp", [SameOperandsAndResultElementType]> {
let summary = "Clamps input tensor to the given min and/or max";

let description = [{
Expand Down Expand Up @@ -65,7 +65,7 @@ def Tcp_ClampOp : Tcp_UnaryElementwiseOp<"clamp"> {
let hasVerifier = 1;
}

def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid"> {
def Tcp_SigmoidOp : Tcp_UnaryElementwiseOp<"sigmoid", [SameOperandsAndResultElementType]> {
let summary = "Computes sigmoid of input, elementwise";

let description = [{
Expand Down Expand Up @@ -312,19 +312,19 @@ def Tcp_IsolatedGroupOp : Tcp_Op<"isolated_group", [
let hasVerifier = 1;
}

def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt", [SameOperandsAndResultElementType]> {
def Tcp_SqrtOp : Tcp_UnaryElementwiseOp<"sqrt"> {
let summary = "Computes square root of input, elementwise";

let description = [{
Computes elementwise square root of the input tensor.
}];

let arguments = (ins
Tcp_FloatOrComplexTensor:$in
Tcp_FloatOrIntTensor:$in
);

let results = (outs
Tcp_FloatOrComplexTensor:$out
Tcp_FloatTensor:$out
);

let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
Expand All @@ -351,7 +351,7 @@ def Tcp_ConcatOp : Tcp_Op<"concat", [SameOperandsAndResultElementType]> {
let hasVerifier = 1;
}

def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> {
def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil", [SameOperandsAndResultElementType]> {
let summary = "Computes ceil of input, elementwise";

let description = [{
Expand All @@ -369,7 +369,7 @@ def Tcp_CeilOp : Tcp_UnaryElementwiseOp<"ceil"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> {
def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor", [SameOperandsAndResultElementType]> {
let summary = "Computes floor of input, elementwise";

let description = [{
Expand All @@ -387,7 +387,7 @@ def Tcp_FloorOp : Tcp_UnaryElementwiseOp<"floor"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> {
def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos", [SameOperandsAndResultElementType]> {
let summary = "Computes cosine of input, elementwise";

let description = [{
Expand All @@ -405,7 +405,7 @@ def Tcp_CosOp : Tcp_UnaryElementwiseOp<"cos"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> {
def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin", [SameOperandsAndResultElementType]> {
let summary = "Computes sine of input, elementwise";

let description = [{
Expand All @@ -423,7 +423,7 @@ def Tcp_SinOp : Tcp_UnaryElementwiseOp<"sin"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> {
def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs", [SameOperandsAndResultElementType]> {
let summary = "Computes absolute of input, elementwise";

let description = [{
Expand All @@ -441,7 +441,7 @@ def Tcp_AbsOp : Tcp_UnaryElementwiseOp<"abs"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> {
def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log", [SameOperandsAndResultElementType]> {
let summary = "Computes natural logarithm of input, elementwise";

let description = [{
Expand All @@ -459,7 +459,7 @@ def Tcp_LogOp : Tcp_UnaryElementwiseOp<"log"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> {
def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg", [SameOperandsAndResultElementType]> {
let summary = "Computes the negation of input, elementwise";

let description = [{
Expand All @@ -477,7 +477,7 @@ def Tcp_NegOp : Tcp_UnaryElementwiseOp<"neg"> {
let assemblyFormat = "$in attr-dict `:` type($in) `->` type($out)";
}

def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan"> {
def Tcp_AtanOp : Tcp_UnaryElementwiseOp<"atan", [SameOperandsAndResultElementType]> {
let summary = "Computes the arcus tangent value of input, elementwise";

let description = [{
Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
return rewriter.notifyMatchFailure(
catOp, "aten.cat operands must be a list of tensors");

SmallVector tensorInputs = getTypeConvertedValues(
rewriter, catOp->getLoc(), getTypeConverter(), inputs);
auto tensorInputs = getTypeConvertedValues(rewriter, catOp->getLoc(),
getTypeConverter(), inputs);

int64_t dim;
if (!matchPattern(catOp.getDim(), m_TorchConstantInt(&dim)))
Expand Down
60 changes: 34 additions & 26 deletions lib/Conversion/TorchToTcp/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,14 @@ class ConvertAtenReluOp : public OpConversionPattern<AtenReluOp> {
}
};

class ConvertAtenAbsOp : public OpConversionPattern<AtenAbsOp> {
template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryIntOrFpOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenAbsOp>::OpConversionPattern;
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult
matchAndRewrite(AtenAbsOp op, OpAdaptor adaptor,
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value input = adaptor.getSelf();
RankedTensorType inputType = input.getType().dyn_cast<RankedTensorType>();
Expand All @@ -464,13 +466,18 @@ class ConvertAtenAbsOp : public OpConversionPattern<AtenAbsOp> {
return rewriter.notifyMatchFailure(
op, "Abs input tensor must have integer or floating-point datatype");

rewriter.replaceOpWithNewOp<tcp::AbsOp>(op, inputType, input);
RankedTensorType resultType =
OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();

rewriter.replaceOpWithNewOp<TcpOpT>(op, resultType, input);
return success();
}
};

template <typename AtenOpT, typename TcpOpT>
class ConvertAtenUnaryOp : public OpConversionPattern<AtenOpT> {
class ConvertAtenUnaryFpOnlyOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
Expand Down Expand Up @@ -680,37 +687,38 @@ void torch_to_tcp::populateElementwisePatternsAndLegality(

target.addIllegalOp<AtenCeilOp>();
target.addIllegalOp<AtenFloorOp>();
target.addIllegalOp<AtenSqrtOp>();
target.addIllegalOp<AtenSigmoidOp>();
target.addIllegalOp<AtenTanhOp>();
target.addIllegalOp<AtenSinOp>();
target.addIllegalOp<AtenCosOp>();
target.addIllegalOp<AtenLogOp>();
target.addIllegalOp<AtenNegOp>();
target.addIllegalOp<AtenAtanOp>();
patterns.add<ConvertAtenUnaryOp<AtenFloorOp, tcp::FloorOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenCeilOp, tcp::CeilOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSqrtOp, tcp::SqrtOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSigmoidOp, tcp::SigmoidOp>>(typeConverter,
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenFloorOp, tcp::FloorOp>>(
typeConverter, context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenCeilOp, tcp::CeilOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenSigmoidOp, tcp::SigmoidOp>>(
typeConverter, context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenTanhOp, tcp::TanhOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenSinOp, tcp::SinOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenCosOp, tcp::CosOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenLogOp, tcp::LogOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenNegOp, tcp::NegOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryFpOnlyOp<AtenAtanOp, tcp::AtanOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenTanhOp, tcp::TanhOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenSinOp, tcp::SinOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenCosOp, tcp::CosOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenLogOp, tcp::LogOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenNegOp, tcp::NegOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryOp<AtenAtanOp, tcp::AtanOp>>(typeConverter,
context);

target.addIllegalOp<AtenAbsOp>();
patterns.add<ConvertAtenAbsOp>(typeConverter, context);
target.addIllegalOp<AtenSqrtOp>();
patterns.add<ConvertAtenUnaryIntOrFpOp<AtenAbsOp, tcp::AbsOp>>(typeConverter,
context);
patterns.add<ConvertAtenUnaryIntOrFpOp<AtenSqrtOp, tcp::SqrtOp>>(
typeConverter, context);

target.addIllegalOp<AtenBatchNormOp>();
patterns.add<ConvertAtenBatchNormOp>(typeConverter, context);
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchToTcp/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,19 @@ func.func @torch.aten.sqrt(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[

// -----

// CHECK-LABEL: func.func @torch.aten.sqrt_int(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T1:.*]] = tcp.sqrt %[[T0]] : tensor<?x?xi32> -> tensor<?x?xf32>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.sqrt_int(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.sqrt %arg0 : !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.ceil(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
Expand Down

0 comments on commit 0ea235e

Please sign in to comment.