Skip to content

Commit

Permalink
Revert "Fold standalone linalg.fill ops into flow.tensor.splat ops"
Browse files Browse the repository at this point in the history
This reverts commit 0cc653f.
  • Loading branch information
antiagainst committed Jun 11, 2021
1 parent 91d46f4 commit d6567bd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,6 @@ namespace Flow {
static unsigned kNumMaxParallelDims = 3;

namespace {

/// Returns SSA values for each dynamic dimension of `value`, which is assumed
/// to have a shaped type. One memref.dim op (created or folded) is produced
/// per dynamic dimension, in increasing dimension order.
SmallVector<Value, 4> getDynamicDims(OpBuilder &builder, Location loc,
                                     Value value) {
  SmallVector<Value, 4> dynamicDims;
  ArrayRef<int64_t> shape = value.getType().cast<ShapedType>().getShape();
  for (unsigned i = 0, e = shape.size(); i != e; ++i) {
    // Static extents need no SSA value; only dynamic ones are materialized.
    if (shape[i] == ShapedType::kDynamicSize) {
      dynamicDims.push_back(builder.createOrFold<memref::DimOp>(loc, value, i));
    }
  }
  return dynamicDims;
}

/// PatternRewriter that allows replacing only a subset of uses.
/// Since this only adds a method, it can just be static_cast'ed to when
/// applying a rewrite.
Expand Down Expand Up @@ -906,18 +891,6 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern {
return failure();
}

// If this is a standalone fill op, we don't need to create a dispatch
// region for it; just use flow.tensor.splat so we can leverage DMA
// functionalities.
Location loc = op->getLoc();
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
SmallVector<Value, 4> dynamicDims =
getDynamicDims(rewriter, loc, fillOp.output());
rewriter.replaceOpWithNewOp<TensorSplatOp>(op, fillOp.output().getType(),
fillOp.value(), dynamicDims);
return success();
}

// The workgroup count is based on the result shape.
Optional<SmallVector<SmallVector<Value>>> resultShapesOpt =
getResultShapes(rewriter, op);
Expand All @@ -932,6 +905,7 @@ struct MakeDispatchWorkgroupsOp : public RewritePattern {
// the flow has three elements of workload size (x, y, z) by linearizing the
// workloads for all higher dimensions greater than or equal to
// kNumMaxParallelDims.
Location loc = op->getLoc();
SmallVector<Value, 4> count(resultShapes[0].begin(), resultShapes[0].end());
if (count.size() > kNumMaxParallelDims) {
unsigned numSymbols = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,12 +346,18 @@ func @fuse_tensor_update_with_fill(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[VAL:.+]] = tensor.extract %[[ARG1]][]
// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[RD0:.+]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG4]], %[[D0]]]
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG5]], %[[D1]]]
// CHECK: %[[RESULT:.+]] = flow.tensor.splat %[[VAL]] : tensor<?x?xf32>{%[[RD0]], %[[RD1]]}
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK-SAME: [%[[RD1]], %[[RD0]], %[[C1]]]
// CHECK-SAME: (%[[ARG1]], %[[RD0]], %[[RD1]])
// CHECK-DAG: %[[VAL:.+]] = tensor.extract
// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
// CHECK: %[[RETURN:.+]] = linalg.fill(%[[INIT]], %[[VAL]])
// CHECK: flow.dispatch.tensor.store %[[RETURN]], {{.*}}
// CHECK-NEXT: flow.return
// CHECK: flow.tensor.update %[[ARG0]], %[[RESULT]]

// -----
Expand Down Expand Up @@ -486,7 +492,12 @@ func @subtensor_insert(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x225x225x3xf32
// CHECK: func @subtensor_insert
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x224x224x3xf32>)
//
// CHECK: %[[FILL:.+]] = constant dense<0.000000e+00> : tensor<1x225x225x3xf32>
// CHECK: %[[FILL:.+]] = flow.dispatch.workgroups[{{.+}}]() : () -> tensor<1x225x225x3xf32> =
// CHECK-NEXT: (%[[OUTPUT:.+]]: !flow.dispatch.tensor<writeonly:1x225x225x3xf32>) {
// CHECK: linalg.init_tensor
// CHECK-NEXT: %[[TENSOR:.+]] = linalg.fill
// CHECK-NEXT: flow.dispatch.tensor.store %[[TENSOR]], %[[OUTPUT]], {{.*}}
// CHECK-NEXT: flow.return
//
// CHECK: %[[PAD:.+]] = flow.dispatch.workgroups[{{.+}}](%[[INPUT]], %[[FILL]]) : (tensor<1x224x224x3xf32>, tensor<1x225x225x3xf32>) -> %[[FILL]] =
// CHECK-NEXT: (%[[SRC:.+]]: !flow.dispatch.tensor<readonly:1x224x224x3xf32>, %[[DST:.+]]: !flow.dispatch.tensor<readwrite:1x225x225x3xf32>) {
Expand Down

0 comments on commit d6567bd

Please sign in to comment.