From 37782f76f5d7a91a59bb8dfdb358da0841c40afa Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Thu, 2 May 2024 20:53:13 +0530
Subject: [PATCH] [torch] Add Canonicalize Pattern for embedding op

Converts PrimConvertOp followed by Embedding -> Embedding followed by
PrimConvertOp. We don't need to cast the entire weight matrix; just the
output of the embedding op.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td |  1 +
 lib/Dialect/Torch/IR/TorchOps.cpp         | 27 +++++++++++++++++++
 .../build_tools/torch_ods_gen.py          |  2 +-
 test/Dialect/Torch/canonicalize.mlir      | 21 +++++++++++++++
 4 files changed, 50 insertions(+), 1 deletion(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 95d92af992b..d3da5ef9144 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10004,6 +10004,7 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 1d0ff41f784..62041d2a1b5 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -547,6 +547,33 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
     return success();
   });
 }
+//===----------------------------------------------------------------------===//
+// AtenEmbeddingOp
+//===----------------------------------------------------------------------===//
+//
+// Converts PrimsConvertElementTypeOp followed by AtenEmbeddingOp to
+// AtenEmbeddingOp followed by PrimsConvertElementTypeOp.
+void AtenEmbeddingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add(+[](AtenEmbeddingOp op, PatternRewriter &rewriter) {
+    auto convertOp =
+        op.getWeight().getDefiningOp<PrimsConvertElementTypeOp>();
+    if (!convertOp)
+      return failure();
+    Value getWeight = convertOp.getA();
+    auto weightType = getWeight.getType().cast<BaseTensorType>();
+    auto opType = op.getType().cast<BaseTensorType>();
+    auto updateType =
+        opType.getWithSizesAndDtype(opType.getSizes(), weightType.getDtype());
+    Value newEmbedding = rewriter.create<AtenEmbeddingOp>(
+        op.getLoc(), updateType, getWeight, op.getIndices(), op.getPaddingIdx(),
+        op.getScaleGradByFreq(), op.getSparse());
+    Value updateRes = rewriter.create<PrimsConvertElementTypeOp>(
+        op.getLoc(), op.getType(), newEmbedding, convertOp.getDtype());
+    rewriter.replaceOp(op, updateRes);
+    return success();
+  });
+}
 
 //===----------------------------------------------------------------------===//
 // RuntimeAssertOp
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index d4d547456c4..b9a083e3114 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -771,7 +771,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
     emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
-    emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
+    emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)", has_canonicalizer=True)
     emit(
         "aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)"
     )
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index e7605f66169..151e97f3c14 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3015,3 +3015,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
   %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
   return %result0 : !torch.vtensor<[10,64,56,56],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
+// CHECK-SAME:    %[[ARG0:.*]]: !torch.vtensor<[64,128],f16>,
+// CHECK-SAME:    %[[ARG1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,128],f32> {
+// CHECK:   %[[INT32:.*]] = torch.constant.int 6
+// CHECK:   %[[FALSE:.*]] = torch.constant.bool false
+// CHECK:   %[[NEGONE:.*]] = torch.constant.int -1
+// CHECK:   %[[EMBD:.*]] = torch.aten.embedding %[[ARG0]], %[[ARG1]], %[[NEGONE]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[64,128],f16>, !torch.vtensor<[4],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,128],f16>
+// CHECK:   %[[DTYPE:.*]] = torch.prims.convert_element_type %[[EMBD]], %[[INT32]] : !torch.vtensor<[4,128],f16>, !torch.int -> !torch.vtensor<[4,128],f32>
+// CHECK:   return %[[DTYPE]] : !torch.vtensor<[4,128],f32>
+// CHECK: }
+func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64,128],f16>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,128],f32> {
+  %int6 = torch.constant.int 6
+  %false = torch.constant.bool false
+  %int-1 = torch.constant.int -1
+  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64,128],f16>, !torch.int -> !torch.vtensor<[64,128],f32>
+  %1 = torch.aten.embedding %0, %arg1, %int-1, %false, %false : !torch.vtensor<[64,128],f32>, !torch.vtensor<[4],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,128],f32>
+  return %1 : !torch.vtensor<[4,128],f32>
+}
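
Note (illustration only, not part of the patch): the rewrite is sound because
aten.embedding is a pure row gather, so converting the weight's element type
before the lookup and converting the gathered rows afterwards yield the same
values, while the latter converts far fewer elements. A minimal PyTorch sketch
of that equivalence; the names, shapes, and dtypes below are arbitrary choices:

    import torch
    import torch.nn.functional as F

    weight = torch.randn(64, 128, dtype=torch.float16)  # f16 embedding table
    indices = torch.tensor([3, 7, 7, 42])                # si64 indices

    # Before canonicalization: cast the whole 64x128 table, then gather.
    before = F.embedding(weight.to(torch.float32), indices)
    # After canonicalization: gather from the f16 table, cast only the 4x128 result.
    after = F.embedding(weight, indices).to(torch.float32)

    # The f16 -> f32 conversion is exact, so the two orderings agree exactly.
    assert torch.equal(before, after)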