Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -931,4 +931,33 @@ def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
}

// To handle runtime assertions, torchscript provides us `torch._assert` operation.
// But TS compiler introduces control flow for `torch._assert` operation. The
// `torch._assert` would introduce control flow like:
//
// %cond = "torch.aten.Bool.Tensor"(%0) : (!torch.tensor) -> !torch.bool
// "torch.prim.If"(%cond) ({
// "torch.prim.If.yield"() : () -> ()
// }, {
// "torch.prim.RaiseException"(%msg) : (!torch.str) -> ()
// "torch.prim.If.yield"() : () -> ()
// }) : (!torch.bool) -> ()
//
// This new operation `torch.runtime.assert` is added to simplify the IR control
// flow by avoiding unnecessary branches. It also makes insertion of the runtime
// assert in the source code easier.
def Torch_RuntimeAssertOp: Torch_Op<"runtime.assert", [
AllowsTypeRefinement,
HasValueSemantics,
]> {
let summary = "Runtime Assertion";
let arguments = (ins
Torch_BoolType:$condition,
StrAttr:$message
);
let results = (outs
);
let assemblyFormat = "$condition `,` $message attr-dict";
}

#endif // TORCH_OPS
16 changes: 16 additions & 0 deletions lib/Conversion/TorchToStd/TorchToStd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,20 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
};
} // namespace

namespace {
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AssertOp>(op, adaptor.condition(),
adaptor.message());
return success();
}
};
} // namespace

namespace {
template <typename AtenOp, typename BinOp>
class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
Expand Down Expand Up @@ -173,6 +187,8 @@ class ConvertTorchToStd : public ConvertTorchToStdBase<ConvertTorchToStd> {
RewritePatternSet patterns(context);
target.addIllegalOp<AtenDimOp>();
patterns.add<ConvertAtenDimOp>(typeConverter, context);
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
target.addIllegalOp<AtenNeIntOp>();
patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
target.addIllegalOp<AtenGtIntOp>();
Expand Down
14 changes: 14 additions & 0 deletions test/Conversion/TorchToStd/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int {
return %0 : !torch.int
}

// CHECK-LABEL: func @torch.runtime.assert(
// CHECK-SAME: %[[X:.*]]: !torch.int,
// CHECK-SAME: %[[Y:.*]]: !torch.int) {
// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]]
// CHECK: %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]]
// CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64
// CHECK: assert %[[CMP]], "x must not be equal to y"
// CHECK: return
func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) {
%0 = torch.aten.ne.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool
torch.runtime.assert %0, "x must not be equal to y"
return
}

// CHECK-LABEL: func @torch.aten.ne.int(
// CHECK-SAME: %[[LHS:.*]]: !torch.int,
// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool {
Expand Down