1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -953,6 +953,7 @@
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseFloorModule_basic",
"ElementwiseFloorIntModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ElementwiseMinimumModule_basic",
91 changes: 46 additions & 45 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -1023,51 +1023,6 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
}];
}

def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFloorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFloor_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -3657,6 +3612,52 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [
}];
}

def Torch_AtenFloorOp : Torch_Op<"aten.floor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::floor : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFloorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::floor_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFloor_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFloor_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics,
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1117,6 +1117,22 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenFloorOp
//===----------------------------------------------------------------------===//
void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) {
auto outputTy = op.getType().dyn_cast<ValueTensorType>();
if (outputTy && outputTy.hasDtype() &&
outputTy.getDtype().isa<mlir::IntegerType>()) {
rewriter.replaceOp(op, op.getSelf());
return success();
}
return failure();
});
}

//===----------------------------------------------------------------------===//
// AtenMulScalarOp
//===----------------------------------------------------------------------===//
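For intuition, here is a Python analogue of the pattern above. This is a sketch only, not the MLIR rewrite itself, and it assumes a PyTorch build where `torch.floor` accepts integer tensors (the same behavior the new e2e test below relies on):

import torch

def fold_floor(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the guard in the C++ pattern: when the dtype is an integer
    # type, floor is the identity, so "replace the op with its operand"
    # by returning the input unchanged.
    if not x.dtype.is_floating_point and not x.dtype.is_complex:
        return x
    return torch.floor(x)

a = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
assert torch.equal(fold_floor(a), a)                 # folded away on int32
assert torch.equal(fold_floor(torch.tensor([1.5, -1.5])),
                   torch.tensor([1.0, -2.0]))        # floats still floored
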
2 changes: 1 addition & 1 deletion python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -273,7 +273,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::atan : (Tensor) -> (Tensor)",
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",
"aten::ceil : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
@@ -333,6 +332,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)

emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
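The single registration line above drives both generated TableGen defs shown earlier: `emit_with_mutating_variants` emits the functional op and its trailing-underscore in-place twin from one JIT signature, and per the generated file, `has_canonicalizer=True` attaches `let hasCanonicalizer = 1;` only to the functional `AtenFloorOp`, not to `AtenFloor_Op`. A hypothetical helper sketching how the two op names derive from the registration key (illustrative names, not the generator's real internals):

def variant_names(key: str) -> tuple:
    # 'aten::floor : (Tensor) -> (Tensor)' -> ('aten.floor', 'aten.floor_')
    qualified = key.split(" : ")[0]       # strip the type signature
    base = qualified.replace("::", ".")   # JIT schema name -> dialect op name
    return base, base + "_"

assert variant_names("aten::floor : (Tensor) -> (Tensor)") == ("aten.floor", "aten.floor_")
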
18 changes: 18 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1420,6 +1420,24 @@ def forward(self, a):
def ElementwiseFloorModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))

class ElementwiseFloorIntModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, a):
return torch.floor(a)


@register_test_case(module_factory=lambda: ElementwiseFloorIntModule())
def ElementwiseFloorIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10).to(torch.int32))
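A usage note: the e2e harness runs `forward` eagerly to produce the golden value and compares each backend's compiled output against it, so this test passes exactly when the lowered module is also the identity on int32 input. A quick standalone check of the same premise, assuming `torch` and the module above are in scope:

m = ElementwiseFloorIntModule()
x = torch.randint(-10, 10, (3, 4), dtype=torch.int32)
assert torch.equal(m.forward(x), x)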


# ==============================================================================
