Skip to content

Commit

Permalink
[mlir][spirv] Legalize subviewop when used with vector transfer
Browse files Browse the repository at this point in the history
Subview operations are not natively supported downstream in the spirv path.
This change allows removing subview when used by vector transfer the same way
we already do it when they are used by LoadOp/StoreOp

Differential Revision: https://reviews.llvm.org/D82106
  • Loading branch information
ThomasRaoux committed Jun 20, 2020
1 parent e4bc08f commit 670455c
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 19 deletions.
88 changes: 69 additions & 19 deletions mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
Expand Up @@ -15,28 +15,41 @@
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

namespace {
/// Merges subview operation with load operation.
class LoadOpOfSubViewFolder final : public OpRewritePattern<LoadOp> {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<LoadOp>::OpRewritePattern;
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(LoadOp loadOp,
LogicalResult matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const override;

private:
void replaceOp(OpTy loadOp, SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const;
};

/// Merges subview operation with store operation.
class StoreOpOfSubViewFolder final : public OpRewritePattern<StoreOp> {
/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<StoreOp>::OpRewritePattern;
using OpRewritePattern<OpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(StoreOp storeOp,
LogicalResult matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const override;

private:
void replaceOp(OpTy StoreOp, SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const;
};
} // namespace

Expand Down Expand Up @@ -85,13 +98,14 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
}

//===----------------------------------------------------------------------===//
// Folding SubViewOp and LoadOp.
// Folding SubViewOp and LoadOp/TransferReadOp.
//===----------------------------------------------------------------------===//

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp = loadOp.memref().template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
Expand All @@ -100,19 +114,36 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
loadOp.indices(), sourceIndices)))
return failure();

replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
return success();
}

template <>
void LoadOpOfSubViewFolder<LoadOp>::replaceOp(LoadOp loadOp,
SubViewOp subViewOp,
ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<LoadOp>(loadOp, subViewOp.source(),
sourceIndices);
return success();
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
vector::TransferReadOp loadOp, SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices);
}

//===----------------------------------------------------------------------===//
// Folding SubViewOp and StoreOp.
// Folding SubViewOp and StoreOp/TransferWriteOp.
//===----------------------------------------------------------------------===//

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
PatternRewriter &rewriter) const {
auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp = storeOp.memref().template getDefiningOp<SubViewOp>();
if (!subViewOp) {
return failure();
}
Expand All @@ -121,9 +152,25 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
storeOp.indices(), sourceIndices)))
return failure();

replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
return success();
}

template <>
void StoreOpOfSubViewFolder<StoreOp>::replaceOp(
StoreOp storeOp, SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<StoreOp>(storeOp, storeOp.value(),
subViewOp.source(), sourceIndices);
return success();
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
vector::TransferWriteOp tranferWriteOp, SubViewOp subViewOp,
ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
tranferWriteOp, tranferWriteOp.vector(), subViewOp.source(),
sourceIndices);
}

//===----------------------------------------------------------------------===//
Expand All @@ -132,7 +179,10 @@ StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,

void mlir::populateStdLegalizationPatternsForSPIRVLowering(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<LoadOpOfSubViewFolder, StoreOpOfSubViewFolder>(context);
patterns.insert<LoadOpOfSubViewFolder<LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<StoreOp>,
StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
34 changes: 34 additions & 0 deletions mlir/test/Conversion/StandardToSPIRV/legalization.mlir
Expand Up @@ -62,3 +62,37 @@ func @fold_dynamic_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 :
store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]>
return
}

// CHECK-LABEL: @fold_static_stride_subview_with_transfer_read
// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index
func @fold_static_stride_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> vector<4xf32> {
// CHECK-NOT: subview
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
// CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: vector.transfer_read [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%f0 = constant 0.0 : f32
%0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
%1 = vector.transfer_read %0[%arg3, %arg4], %f0 : memref<4x4xf32, offset:?, strides: [64, 3]>, vector<4xf32>
return %1 : vector<4xf32>
}

// CHECK-LABEL: @fold_static_stride_subview_with_transfer_write
// CHECK-SAME: [[ARG0:%.*]]: memref<12x32xf32>, [[ARG1:%.*]]: index, [[ARG2:%.*]]: index, [[ARG3:%.*]]: index, [[ARG4:%.*]]: index, [[ARG5:%.*]]: vector<4xf32>
func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : vector<4xf32>) {
// CHECK-NOT: subview
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[STRIDE1:%.*]] = muli [[ARG3]], [[C2]] : index
// CHECK: [[INDEX1:%.*]] = addi [[ARG1]], [[STRIDE1]] : index
// CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index
// CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index
// CHECK: vector.transfer_write [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}}
%0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] :
memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
vector.transfer_write %arg5, %0[%arg3, %arg4] : vector<4xf32>, memref<4x4xf32, offset:?, strides: [64, 3]>
return
}

0 comments on commit 670455c

Please sign in to comment.