From 1dd89d4a8fd050bc95e4e6717ded4210a51715b2 Mon Sep 17 00:00:00 2001 From: Prakalp Srivastava Date: Mon, 6 Jan 2020 16:27:51 -0800 Subject: [PATCH] Lower tf.OutfeedEnqueueTuple op to XLA HLO. OutfeedEnqueueTuple is lowered to HLO tuple, after_all and outfeed ops. after_all op is emitted to generate XLA token required by outfeed op. PiperOrigin-RevId: 288399352 Change-Id: If56424b044e631f64837b39c8758dce9999ed4ab --- .../mlir/tensorflow/ir/tf_generated_ops.td | 15 +++++++ .../compiler/mlir/xla/tests/legalize-tf.mlir | 14 +++++++ .../mlir/xla/transforms/legalize_tf.cc | 39 ++++++++++++++++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 9b3d749864c844..bc8b18671c9f72 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -3919,6 +3919,21 @@ output = }]; } +def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { + let summary = "Enqueue multiple Tensor values on the computation outfeed."; + + let description = [{ + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr dtypes = TF_DerivedOperandTypeListAttr<0>; +} + def TF_PackOp : TF_Op<"Pack", [NoSideEffect]> { let summary = [{ Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 7e743cacb2b1e3..da1dfbb9efe9c0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1253,6 +1253,20 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tenso return %result : tensor<3x5xf32> } +//===----------------------------------------------------------------------===// +// tf.OutfeedEnqueueTuple legalization +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @outfeed_enqueue_tuple +// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) +func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { +// CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple, tensor<4xf32>> +// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token +// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[AFTER_ALL]]) {outfeed_config = ""} : (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token + "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () + return +} + //===----------------------------------------------------------------------===// // Pack op legalizations. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 9c58b242460dc5..0c91c75c3b0307 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -2480,6 +2480,41 @@ class ConvertOneHotOp : public OpRewritePattern { } }; +// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops. +// +// XLA HLO outfeed op expects a token, which we generate by emitting an +// after_all op. +// +// For example the following IR: +// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> +// () +// +// would be lowered to +// +// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> +// tuple, tensor<4xf32>> +// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token +// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} : +// (tuple, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token +// +class ConvertOutfeedEnqueueTupleOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, + PatternRewriter &rewriter) const override { + auto token_type = xla_hlo::TokenType::get(rewriter.getContext()); + auto tuple = rewriter.create(op.getLoc(), op.inputs()); + auto afterall = + rewriter.create(op.getLoc(), token_type, ValueRange()); + rewriter.create(op.getLoc(), token_type, tuple, afterall, + /*outfeed_config=*/rewriter.getStringAttr("")); + rewriter.eraseOp(op); + return matchSuccess(); + } +}; + // Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant. // // tf.TopKV2 sorts along last dimension of the input tensor and then returns @@ -2770,8 +2805,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) { ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, ConvertMaxOp, ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp, - ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp, - ConvertSoftmaxOp, + ConvertOutfeedEnqueueTupleOp, ConvertRangeOp, ConvertSigmoidOp, + ConvertSizeOp, ConvertSoftmaxOp, ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,