-
Notifications
You must be signed in to change notification settings - Fork 10.8k
/
VectorDistribution.h
84 lines (71 loc) · 3.24 KB
/
VectorDistribution.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
//===- VectorDistribution.h - Vector distribution patterns --*- C++------*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
#define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
#include "mlir/Dialect/Vector/IR/VectorOps.h"
namespace mlir {
class RewritePatternSet;
namespace vector {
struct WarpExecuteOnLane0LoweringOptions {
/// Lambda function to let users allocate memory needed for the lowering of
/// WarpExecuteOnLane0Op.
/// The function needs to return an allocation that the lowering can use as
/// temporary memory. The allocation needs to match the shape of the type (the
/// type may be VectorType or a scalar) and be available for the current warp.
/// If there are several warps running in parallel the allocation needs to be
/// split so that each warp has its own allocation.
using WarpAllocationFn =
std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>;
WarpAllocationFn warpAllocationFn = nullptr;
/// Lambda function to let users emit an operation that synchronizes all the
/// threads within a warp. After this operation all the threads can see any
/// memory written before the operation.
using WarpSyncronizationFn =
std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>;
WarpSyncronizationFn warpSyncronizationFn = nullptr;
};
/// Collect the pattern lowering WarpExecuteOnLane0Op to scf ops, using the
/// callbacks in `options` to allocate the temporary memory and synchronize the
/// threads of the warp during the lowering.
void populateWarpExecuteOnLane0OpToScfForPattern(
RewritePatternSet &patterns,
const WarpExecuteOnLane0LoweringOptions &options);
/// Lambda signature mapping a transfer_write op to the affine map describing
/// how its vector is distributed across the lanes of the warp.
using DistributionMapFn = std::function<AffineMap(vector::TransferWriteOp)>;
/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r#0, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
void populateDistributeTransferWriteOpPatterns(
RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn);
/// Move scalar operations with no dependency on the warp op outside of the
/// region. Such operations are uniform across all lanes, so they can be hoisted
/// out of `op` and executed by every thread directly.
void moveScalarUniformCode(WarpExecuteOnLane0Op op);
/// Collect patterns to propagate warp distribution through the operations
/// nested inside a WarpExecuteOnLane0Op region.
void populatePropagateWarpVectorDistributionPatterns(
RewritePatternSet &pattern);
/// Lambda signature to compute a reduction of a distributed value for the given
/// reduction kind and size.
using DistributedReductionFn =
std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>;
/// Collect patterns to distribute vector reduction ops, using the given lambda
/// to distribute each reduction op.
void populateDistributeReduction(RewritePatternSet &pattern,
DistributedReductionFn distributedReductionFn);
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_