Skip to content

Commit

Permalink
[mlir][Linalg] Add a test for a fused Linalg pass based on DRR to go …
Browse files Browse the repository at this point in the history
…from matmul to vectors

This revision builds a simple "fused pass" consisting of 2 levels of tiling, memory promotion and vectorization using linalg transformations written as composable pattern rewrites.
  • Loading branch information
Nicolas Vasilache committed Apr 8, 2020
1 parent c6e917d commit 6fb6a4d
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Linalg/matmul-to-vector.mlir
@@ -0,0 +1,16 @@
// RUN: mlir-opt %s -linalg-matmul-to-vector | FileCheck %s

func @matmul_perm(%A: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%B: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
%C: memref<1584x1584xf32, offset: 0, strides: [1584, 1]>) {
linalg.matmul(%A, %B, %C) {__internal_linalg_transform__ = "__with_perm__"} :
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>,
memref<1584x1584xf32, offset: 0, strides: [1584, 1]>
return
}

// CHECK-LABEL:func @matmul_perm
// CHECK: vector.contract
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: : vector<8x16xf32>, vector<16x12xf32> into vector<8x12xf32>
4 changes: 4 additions & 0 deletions mlir/test/lib/DeclarativeTransforms/CMakeLists.txt
Expand Up @@ -5,3 +5,7 @@ add_public_tablegen_target(MLIRTestLinalgTransformPatternsIncGen)
set(LLVM_TARGET_DEFINITIONS TestVectorTransformPatterns.td)
mlir_tablegen(TestVectorTransformPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestVectorTransformPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS TestLinalgMatmulToVectorPatterns.td)
mlir_tablegen(TestLinalgMatmulToVectorPatterns.h.inc -gen-rewriters)
add_public_tablegen_target(MLIRTestLinalgMatmulToVectorPatternsIncGen)
@@ -0,0 +1,43 @@
//===- TestLinalgMatmulToVectorPatterns.td - Test patterns -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the pattern definition file for declarative Linalg transformations
// tests.
//
//===----------------------------------------------------------------------===//

#ifndef TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
#define TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS

include "mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td"
include "mlir/Dialect/Vector/VectorTransformPatterns.td"

//===----------------------------------------------------------------------===//
// Linalg tiling and permutation patterns.
//===----------------------------------------------------------------------===//
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[768, 264, 768], "L2__with_perm__", [1, 2, 0]>),
[(Constraint<HasLinalgTransformMarker<"__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(TileLinalgOp<[8, 12, 16], "L1__with_perm__", [1, 0, 2]>),
[(Constraint<HasLinalgTransformMarker<"L2__with_perm__">>)]>;
def : Pat<(MatmulOp:$op $_, $_, $_),
(PromoteSubviewsLinalgOp),
[(Constraint<HasOperandsOfType<"SubViewOp">>),
(Constraint<HasLinalgTransformMarker<"L1__with_perm__">>)]>;

//===----------------------------------------------------------------------===//
// Linalg to vector contraction patterns.
//===----------------------------------------------------------------------===//
def : Pattern<(MatmulOp:$op $_, $_, $_),
[(VectorizeLinalgOp)],
[(Constraint<And<[
HasLinalgTransformMarker<"L1__with_perm__">,
PreconditionVectorizeLinalgOp]>>)]>;

#endif // TEST_LINALG_MATMUL_TO_VECTOR_PATTERNS
2 changes: 2 additions & 0 deletions mlir/test/lib/Transforms/CMakeLists.txt
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
TestGpuMemoryPromotion.cpp
TestGpuParallelLoopMapping.cpp
TestInlining.cpp
TestLinalgMatmulToVector.cpp
TestLinalgTransforms.cpp
TestLiveness.cpp
TestLoopMapping.cpp
Expand All @@ -24,6 +25,7 @@ add_llvm_library(MLIRTestTransforms

DEPENDS
MLIRStandardOpsIncGen
MLIRTestLinalgMatmulToVectorPatternsIncGen
MLIRTestLinalgTransformPatternsIncGen
MLIRTestVectorTransformPatternsIncGen
)
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/lib/Transforms/TestLinalgMatmulToVector.cpp
@@ -0,0 +1,51 @@
//===- TestLinalgMatmulToVector.cpp - Test VectorTransfers lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;

namespace {
#include "TestLinalgMatmulToVectorPatterns.h.inc"

struct DeclarativeTransforms
: public PassWrapper<DeclarativeTransforms, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto *context = &getContext();
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
AffineMinOp::getCanonicalizationPatterns(patterns, context);
AffineMaxOp::getCanonicalizationPatterns(patterns, context);
AllocOp::getCanonicalizationPatterns(patterns, context);
SubViewOp::getCanonicalizationPatterns(patterns, context);
ViewOp::getCanonicalizationPatterns(patterns, context);
populateWithGenerated(context, &patterns);
applyPatternsGreedily(getFunction(), patterns);
}
};
} // end anonymous namespace

namespace mlir {
void registerTestLinalgMatmulToVectorPass() {
PassRegistration<DeclarativeTransforms> pass(
"linalg-matmul-to-vector",
"Test declarative transform patterns for matmul 3-D tiling + promotion"
" + vectorization");
}
} // namespace mlir
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Expand Up @@ -39,6 +39,7 @@ void registerSimpleParametricTilingPass();
void registerSymbolTestPasses();
void registerTestAffineDataCopyPass();
void registerTestAllReduceLoweringPass();
void registerTestLinalgMatmulToVectorPass();
void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
Expand Down Expand Up @@ -101,6 +102,7 @@ void registerTestPasses() {
registerSymbolTestPasses();
registerTestAffineDataCopyPass();
registerTestAllReduceLoweringPass();
registerTestLinalgMatmulToVectorPass();
registerTestLoopPermutationPass();
registerTestCallGraphPass();
registerTestConstantFold();
Expand Down

0 comments on commit 6fb6a4d

Please sign in to comment.