Lower tf.OutfeedEnqueueTuple op to XLA HLO.
OutfeedEnqueueTuple is lowered to HLO tuple, after_all, and outfeed ops. The after_all op is emitted to generate the XLA token required by the outfeed op.

PiperOrigin-RevId: 288399352
Change-Id: If56424b044e631f64837b39c8758dce9999ed4ab
Prakalp Srivastava authored and tensorflower-gardener committed Jan 7, 2020
1 parent 6ebd3bb commit 1dd89d4
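
For illustration, a minimal before/after sketch of the new lowering, mirroring the example in the pattern's code comment and the test case added below (value names follow the test; the exact CHECK lines are in legalize-tf.mlir):

// Input TF dialect op:
"tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()

// Lowered XLA HLO dialect ops:
%tuple = "xla_hlo.tuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
%token = "xla_hlo.after_all"() : () -> !xla_hlo.token
%outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token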
Showing 3 changed files with 66 additions and 2 deletions.
15 changes: 15 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -3919,6 +3919,21 @@ output =
}];
}

def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> {
let summary = "Enqueue multiple Tensor values on the computation outfeed.";

let description = [{
}];

let arguments = (ins
Variadic<TF_Tensor>:$inputs
);

let results = (outs);

TF_DerivedOperandTypeListAttr dtypes = TF_DerivedOperandTypeListAttr<0>;
}

def TF_PackOp : TF_Op<"Pack", [NoSideEffect]> {
let summary = [{
Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
14 changes: 14 additions & 0 deletions tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -1253,6 +1253,20 @@ func @one_hot(%indices: tensor<3xi32>, %on_value: tensor<f32>, %off_value: tenso
return %result : tensor<3x5xf32>
}

//===----------------------------------------------------------------------===//
// tf.OutfeedEnqueueTuple legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @outfeed_enqueue_tuple
// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>)
func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () {
// CHECK: [[TUPLE:%.*]] = "xla_hlo.tuple"([[VAL_0]], [[VAL_1]]) : (tensor<3xi32>, tensor<4xf32>) -> tuple<tensor<3xi32>, tensor<4xf32>>
// CHECK: [[AFTER_ALL:%.*]] = "xla_hlo.after_all"() : () -> !xla_hlo.token
// CHECK: "xla_hlo.outfeed"([[TUPLE]], [[AFTER_ALL]]) {outfeed_config = ""} : (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
"tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> ()
return
}

//===----------------------------------------------------------------------===//
// Pack op legalizations.
//===----------------------------------------------------------------------===//
39 changes: 37 additions & 2 deletions tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -2480,6 +2480,41 @@ class ConvertOneHotOp : public OpRewritePattern<TF::OneHotOp> {
}
};

// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, after_all and outfeed ops.
//
// XLA HLO outfeed op expects a token, which we generate by emitting an
// after_all op.
//
// For example the following IR:
// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
// ()
//
// would be lowered to
//
// %tuple = "xla_hlo.tuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) ->
// tuple<tensor<3xi32>, tensor<4xf32>>
// %token = "xla_hlo.after_all"() : () -> !xla_hlo.token
// %outfeed_token = "xla_hlo.outfeed"(%tuple, %token) {outfeed_config = ""} :
// (tuple<tensor<3xi32>, tensor<4xf32>>, !xla_hlo.token) -> !xla_hlo.token
//
class ConvertOutfeedEnqueueTupleOp
: public OpRewritePattern<TF::OutfeedEnqueueTupleOp> {
public:
using OpRewritePattern::OpRewritePattern;

PatternMatchResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op,
PatternRewriter &rewriter) const override {
auto token_type = xla_hlo::TokenType::get(rewriter.getContext());
auto tuple = rewriter.create<TupleOp>(op.getLoc(), op.inputs());
auto afterall =
rewriter.create<AfterAllOp>(op.getLoc(), token_type, ValueRange());
rewriter.create<OutfeedOp>(op.getLoc(), token_type, tuple, afterall,
/*outfeed_config=*/rewriter.getStringAttr(""));
rewriter.eraseOp(op);
return matchSuccess();
}
};

// Converts tf.TopKV2 to XLA HLO iota, sort, and slice ops when k is a constant.
//
// tf.TopKV2 sorts along last dimension of the input tensor and then returns
@@ -2770,8 +2805,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion) {
ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, ConvertMaxOp,
ConvertMaxPoolOp, ConvertMaxPoolGradOp, ConvertMeanOp, ConvertOneHotOp,
ConvertRangeOp, ConvertSigmoidOp, ConvertSizeOp,
ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertOutfeedEnqueueTupleOp, ConvertRangeOp, ConvertSigmoidOp,
ConvertSizeOp, ConvertSoftmaxOp<TF::LogSoftmaxOp, true>,
ConvertSoftmaxOp<TF::SoftmaxOp, false>, ConvertSplitOp, ConvertSplitVOp,
ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,