[torch] Add Canonicalize Pattern for embedding op
Converts PrimsConvertElementTypeOp followed by AtenEmbeddingOp into
AtenEmbeddingOp followed by PrimsConvertElementTypeOp. We don't need to
cast the entire weight matrix; casting just the output of the embedding
op suffices.
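
A minimal PyTorch sketch of the equivalence this rewrite relies on (the tensor names and sizes here are illustrative, not from the commit): embedding is a row gather, so an elementwise cast commutes with it exactly, and casting after the gather touches only the looked-up rows instead of the whole table.

import torch
import torch.nn.functional as F

weight = torch.randn(64, 128, dtype=torch.float16)  # embedding table
indices = torch.tensor([3, 17, 42, 7])              # 4 lookups

# Before the rewrite: cast all 64x128 weights, then gather 4 rows.
cast_then_gather = F.embedding(indices, weight.to(torch.float32))

# After the rewrite: gather 4 rows in f16, then cast only 4x128 values.
gather_then_cast = F.embedding(indices, weight).to(torch.float32)

# f16 -> f32 widening is exact, so the results match bit-for-bit.
assert torch.equal(cast_then_gather, gather_then_cast)

The same commutation holds for any elementwise dtype conversion, since each gathered element is converted identically either way; only the number of converted elements changes.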
pashu123 committed May 2, 2024
1 parent 11cd7cd commit 37782f7
Showing 4 changed files with 50 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10004,6 +10004,7 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
      printDefaultTorchOp(printer, *this, 5, 1);
    }
  }];
  let hasCanonicalizer = 1;
}

def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
27 changes: 27 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -547,6 +547,33 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenEmbeddingOp
//===----------------------------------------------------------------------===//
//
// Converts PrimsConvertElementTypeOp followed by AtenEmbeddingOp to
// AtenEmbeddingOp followed by PrimsConvertElementTypeOp.
void AtenEmbeddingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add(+[](AtenEmbeddingOp op, PatternRewriter &rewriter) {
    auto convertOp =
        op.getWeight().getDefiningOp<Torch::PrimsConvertElementTypeOp>();
    if (!convertOp)
      return failure();
    // Look up into the pre-conversion weight tensor, so that only the
    // gathered rows are cast instead of the whole embedding table.
    Value weight = convertOp.getA();
    auto weightType = weight.getType().cast<BaseTensorType>();
    auto opType = op.getType().cast<BaseTensorType>();
    auto updatedType =
        opType.getWithSizesAndDtype(opType.getSizes(), weightType.getDtype());
    Value newEmbedding = rewriter.create<AtenEmbeddingOp>(
        op.getLoc(), updatedType, weight, op.getIndices(), op.getPaddingIdx(),
        op.getScaleGradByFreq(), op.getSparse());
    Value updatedRes = rewriter.create<Torch::PrimsConvertElementTypeOp>(
        op.getLoc(), op.getType(), newEmbedding, convertOp.getDtype());
    rewriter.replaceOp(op, updatedRes);
    return success();
  });
}

//===----------------------------------------------------------------------===//
// RuntimeAssertOp
2 changes: 1 addition & 1 deletion projects/pt1/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -771,7 +771,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::detach : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::device.with_index : (str, int) -> (Device)", has_canonicalizer=True)
emit("aten::cuda : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)", has_canonicalizer=True)
emit(
"aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)"
)
21 changes: 21 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
@@ -3015,3 +3015,24 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
  %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
  return %result0 : !torch.vtensor<[10,64,56,56],f32>
}

// -----

// CHECK-LABEL: func.func @torch.prims.convert_element_type$fold(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[64,128],f16>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,128],f32> {
// CHECK: %[[INT32:.*]] = torch.constant.int 6
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NEGONE:.*]] = torch.constant.int -1
// CHECK: %[[EMBD:.*]] = torch.aten.embedding %[[ARG0]], %[[ARG1]], %[[NEGONE]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[64,128],f16>, !torch.vtensor<[4],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,128],f16>
// CHECK: %[[DTYPE:.*]] = torch.prims.convert_element_type %[[EMBD]], %[[INT32]] : !torch.vtensor<[4,128],f16>, !torch.int -> !torch.vtensor<[4,128],f32>
// CHECK: return %[[DTYPE]] : !torch.vtensor<[4,128],f32>
// CHECK: }
func.func @torch.prims.convert_element_type$fold(%arg0: !torch.vtensor<[64,128],f16>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,128],f32> {
  %int6 = torch.constant.int 6
  %false = torch.constant.bool false
  %int-1 = torch.constant.int -1
  %0 = torch.prims.convert_element_type %arg0, %int6 : !torch.vtensor<[64,128],f16>, !torch.int -> !torch.vtensor<[64,128],f32>
  %1 = torch.aten.embedding %0, %arg1, %int-1, %false, %false : !torch.vtensor<[64,128],f32>, !torch.vtensor<[4],si64>, !torch.int, !torch.bool, !torch.bool -> !torch.vtensor<[4,128],f32>
  return %1 : !torch.vtensor<[4,128],f32>
}
