72 changes: 72 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8519,6 +8519,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [
}];
}

def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIsneginfOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenIsposinfOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
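For orientation, here is how the two ops behave at the PyTorch level (a minimal sketch, assuming any recent torch release): each takes a single tensor and yields a same-shaped boolean mask, and NaN matches neither check.

```python
import torch

x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")])

# aten::isneginf is true only for -inf; NaN and +inf do not match.
print(torch.isneginf(x))  # tensor([False, False,  True, False])

# aten::isposinf is true only for +inf.
print(torch.isposinf(x))  # tensor([False,  True, False, False])
```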

def Torch_AtenAllOp : Torch_Op<"aten.all", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -10449,6 +10495,32 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
}];
}

def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalFloatType:$nan,
AnyTorchOptionalFloatType:$posinf,
AnyTorchOptionalFloatType:$neginf
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenNanToNumOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
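The registered signature matches torch.nan_to_num, where all three replacement values are optional. A short usage sketch (per the PyTorch documentation, an unset nan defaults to 0.0, and unset posinf/neginf default to the dtype's largest/smallest finite value):

```python
import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 2.0])

# Defaults: nan -> 0.0, posinf -> finfo(dtype).max, neginf -> finfo(dtype).min.
print(torch.nan_to_num(x))
# tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  2.0000e+00])

# All three replacements can be overridden explicitly.
print(torch.nan_to_num(x, nan=1.0, posinf=100.0, neginf=-100.0))
# tensor([   1.,  100., -100.,    2.])
```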

def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
AllowsTypeRefinement,
ReadOnly
62 changes: 62 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6702,6 +6702,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -7874,6 +7882,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
@@ -9710,6 +9722,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int9 = torch.constant.int 9\n"
" %int10 = torch.constant.int 10\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %3 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %int9 = torch.constant.int 9\n"
" %int10 = torch.constant.int 10\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %3 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
@@ -10713,6 +10771,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
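The bare integer constants in the dtype functions above are torch ScalarType codes. As a reading aid (values taken from PyTorch's c10::ScalarType enum), the codes used here are:

```python
import torch

# ScalarType codes appearing in the dtype functions above.
SCALAR_TYPE_CODES = {
    9: torch.complex64,    # rejected by the isneginf/isposinf assertions
    10: torch.complex128,  # rejected by the isneginf/isposinf assertions
    11: torch.bool,        # the result dtype both functions always return
}
```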
80 changes: 80 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -932,6 +932,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern<AtenIsinfOp> {
};
} // namespace

namespace {
class DecomposeAtenIsneginfOp : public OpRewritePattern<AtenIsneginfOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIsneginfOp op,
PatternRewriter &rewriter) const override {
mlir::FloatType f64Type = rewriter.getF64Type();
Value inf = rewriter.create<ConstantFloatOp>(
op.getLoc(),
rewriter.getFloatAttr(
f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
inf);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIsposinfOp op,
PatternRewriter &rewriter) const override {
mlir::FloatType f64Type = rewriter.getF64Type();
Value inf = rewriter.create<ConstantFloatOp>(
op.getLoc(),
rewriter.getFloatAttr(f64Type,
APFloat::getInf(f64Type.getFloatSemantics())));
rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
inf);
return success();
}
};
} // namespace
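Both patterns lower the check to a single scalar equality against the signed infinity, i.e. aten.eq.Scalar(self, -inf) or aten.eq.Scalar(self, +inf). A Python sketch of the same reduction (just the semantics, not the emitted IR):

```python
import torch

def isneginf_decomposed(x: torch.Tensor) -> torch.Tensor:
    # aten.isneginf -> aten.eq.Scalar(self, -inf)
    return x == float("-inf")

def isposinf_decomposed(x: torch.Tensor) -> torch.Tensor:
    # aten.isposinf -> aten.eq.Scalar(self, +inf)
    return x == float("inf")

x = torch.tensor([float("-inf"), 0.0, float("inf"), float("nan")])
assert torch.equal(isneginf_decomposed(x), torch.isneginf(x))
assert torch.equal(isposinf_decomposed(x), torch.isposinf(x))
```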

namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
@@ -2471,6 +2505,49 @@ class DecomposeAtenWhereScalarSelfOp
};
} // namespace

namespace {
class DecomposeAtenNanToNumOp : public OpRewritePattern<AtenNanToNumOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNanToNumOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
mlir::FloatType f64Type = rewriter.getF64Type();
Value nan = op.getNan();
Value posinf = op.getPosinf();
Value neginf = op.getNeginf();
auto baseType =
ValueTensorType::getWithLeastStaticInformation(op.getContext());
if (dyn_cast_or_null<ConstantNoneOp>(nan.getDefiningOp()))
nan = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(
f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
if (dyn_cast_or_null<ConstantNoneOp>(posinf.getDefiningOp()))
posinf = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(
f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
if (dyn_cast_or_null<ConstantNoneOp>(neginf.getDefiningOp()))
neginf = rewriter.create<ConstantFloatOp>(
loc,
rewriter.getFloatAttr(
f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
Value isNan =
rewriter.create<Torch::AtenIsnanOp>(loc, baseType, op.getSelf());
Value where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
loc, baseType, isNan, nan, op.getSelf());
Value isposinf =
rewriter.create<Torch::AtenIsposinfOp>(loc, baseType, where);
where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
loc, baseType, isposinf, posinf, where);
Value isneginf =
rewriter.create<Torch::AtenIsneginfOp>(loc, baseType, where);
rewriter.replaceOpWithNewOp<Torch::AtenWhereScalarSelfOp>(
op, op.getType(), isneginf, neginf, where);
return success();
}
};
} // namespace
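The pattern chains isnan, isposinf, and isneginf with three aten.where.ScalarSelf ops. A behavioral sketch in Python (assuming a torch version whose torch.where accepts Python scalars); note that the pattern fills an unset posinf/neginf with +inf/-inf, so those cases pass infinities through unchanged rather than clamping to the dtype's finite extremes as torch.nan_to_num's documented defaults do:

```python
import torch

def nan_to_num_decomposed(x, nan=None, posinf=None, neginf=None):
    # Mirror DecomposeAtenNanToNumOp: materialize defaults, then apply
    # three where-ops in sequence.
    nan = 0.0 if nan is None else nan
    posinf = float("inf") if posinf is None else posinf
    neginf = float("-inf") if neginf is None else neginf
    out = torch.where(torch.isnan(x), nan, x)
    out = torch.where(torch.isposinf(out), posinf, out)
    out = torch.where(torch.isneginf(out), neginf, out)
    return out

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 2.0])
print(nan_to_num_decomposed(x, nan=0.0, posinf=1e9, neginf=-1e9))
# tensor([ 0.0000e+00,  1.0000e+09, -1.0000e+09,  2.0000e+00])
```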

// Decompose aten.masked_fill.Scalar into aten.where.self op.
namespace {
class DecomposeAtenMaskedFillScalarOp
@@ -6393,6 +6470,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
@@ -6448,6 +6526,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
3 changes: 3 additions & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -431,8 +431,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenZeroOp>();
target.addIllegalOp<AtenEyeOp>();
target.addIllegalOp<AtenEyeMOp>();
target.addIllegalOp<AtenNanToNumOp>();
target.addIllegalOp<AtenIsnanOp>();
target.addIllegalOp<AtenIsinfOp>();
target.addIllegalOp<AtenIsneginfOp>();
target.addIllegalOp<AtenIsposinfOp>();
target.addIllegalOp<AtenRandLikeOp>();
target.addIllegalOp<AtenHardsigmoidOp>();
target.addIllegalOp<AtenRelu6Op>();
7 changes: 7 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -473,6 +473,7 @@
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"ElementwiseWhereScalarSelfStaticModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseNotInt32Module_basic",
@@ -1039,6 +1040,8 @@
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseAtenIsinfOpModule_basic",
"ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseAtenLogicalOrOpBrodcastModule_basic",
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
"ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
@@ -1090,6 +1093,8 @@
"ElementwiseGtIntTensorModule_basic",
"ElementwiseGtMixed2ScalarModule_basic",
"ElementwiseIsinfModule_basic",
"ElementwiseAtenIsneginfOpModule_basic",
"ElementwiseAtenIsposinfOpModule_basic",
"ElementwiseIsnanModule_basic",
"ElementwiseLeFloatTensorModule_basic",
"ElementwiseLeIntTensorModule_basic",
@@ -1146,6 +1151,7 @@
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseWhereScalarModule_basic",
"ElementwiseNanToNumModule_Basic",
"EmbeddingModule1DIndices_basic",
"EmbeddingModuleI32Static_basic",
"FlattenRank0Module_basic",
@@ -1511,6 +1517,7 @@
"ElementwiseBitwiseAndScalarInt64Module_basic",
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseDequantizePerTensorModule_basic"
}
@@ -337,6 +337,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]:
def aten〇isinf〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇isneginf〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇isposinf〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
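Both new shape functions delegate to upstream_shape_functions.unary which, to the best of my reading of torch/jit/_shape_functions.py, simply copies the input sizes, since elementwise ops preserve shape. A sketch of that contract:

```python
from typing import List

def unary(self: List[int]) -> List[int]:
    # Elementwise ops preserve shape: return a fresh copy of the sizes.
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out

assert unary([2, 3, 5]) == [2, 3, 5]
```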

def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

@@ -1058,6 +1064,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(condition, other)

def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight))

@@ -2516,6 +2525,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.complex128 and self_dtype != torch.complex64
return torch.bool

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.complex128 and self_dtype != torch.complex64
return torch.bool

@check_dtype_function(_check_two_tensor_op())
def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
return torch.bool
@@ -3247,6 +3270,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
dtypes = [get_dtype_of_scalar(self), other_dtype]
return promote_dtypes(ranks, dtypes)

@check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
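The dtype rule simply propagates the input dtype, which matches torch behavior; a one-line check (a sketch, assuming any recent torch):

```python
import torch

# nan_to_num never changes dtype: the result dtype equals the input's,
# which is exactly what the dtype function above returns.
x16 = torch.tensor([float("nan")], dtype=torch.float16)
assert torch.nan_to_num(x16).dtype == torch.float16
```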

@check_dtype_function(
[Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64),
TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0),