From 53baf69a614458160e2e425ae0d52bc9f3e2dac7 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 24 Jan 2022 18:58:27 -0500 Subject: [PATCH] Minor fixes --- .../Dialect/Torch/IR/GeneratedPrimOps.td | 2 +- lib/Dialect/Torch/IR/TorchOps.cpp | 14 -------------- .../importer/jit_ir/build_tools/torch_ods_gen.py | 2 +- .../importer/jit_ir/csrc/torch_to_mlir_utils.cpp | 1 + 4 files changed, 3 insertions(+), 16 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td index 57a9f9334465..609c53b20401 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedPrimOps.td @@ -175,6 +175,7 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [ } def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ + NoSideEffect, AllowsTypeRefinement, HasValueSemantics ]> { @@ -185,7 +186,6 @@ def Torch_PrimUninitializedOp : Torch_Op<"prim.Uninitialized", [ AnyTorchType:$result ); let assemblyFormat = " attr-dict `:` qualified(type($result))"; - let hasCanonicalizer = 1; } def Torch_PrimUncheckedCastOp : Torch_Op<"prim.unchecked_cast", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index f43eaefe4bdf..2593a58fa465 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1037,20 +1037,6 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } -//===----------------------------------------------------------------------===// -// PrimUninitializedOp -//===----------------------------------------------------------------------===// - -void PrimUninitializedOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](PrimUninitializedOp op, PatternRewriter &rewriter) { - if (!op.use_empty()) - return failure(); - rewriter.eraseOp(op); - return success(); - }); -} - //===----------------------------------------------------------------------===// // PrimTupleUnpackOp //===----------------------------------------------------------------------===// diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index acd8c0763c6f..06eafc7e19b1 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -414,7 +414,7 @@ def emit(key, **kwargs): emit("prim::max.self_int : (int[]) -> (int)") emit("prim::max.int : (int, int) -> (int)") emit("prim::RaiseException : (str) -> ()") - emit("prim::Uninitialized : () -> (Any)", has_canonicalizer=True) + emit("prim::Uninitialized : () -> (Any)", traits=["NoSideEffect"]) emit("prim::unchecked_cast : (t) -> (t)", traits=["DeclareOpInterfaceMethods"]) emit("prim::Print : (...) -> ()") diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp index 2d6b8aabd7fd..30b1cbd2d225 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/torch_to_mlir_utils.cpp @@ -11,6 +11,7 @@ #include "function_importer.h" #include "ivalue_importer.h" +#include #include #include "mlir_utils.h"