|
10 | 10 | #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
|
11 | 11 |
|
12 | 12 | #include "gc/Dialect/Linalgx/LinalgxOps.h"
|
| 13 | +#include "gc/Dialect/Linalgx/Utils.h" |
13 | 14 | #include "mlir/Dialect/DLTI/DLTI.h"
|
14 | 15 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
|
15 | 16 | #include "mlir/Interfaces/DataLayoutInterfaces.h"
|
@@ -55,13 +56,17 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
|
55 | 56 | SmallVector<DimType>{DimType::M, DimType::K},
|
56 | 57 | SmallVector<DimType>{DimType::K, DimType::N},
|
57 | 58 | SmallVector<DimType>{DimType::M, DimType::N}};
|
58 |
| - } else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) { |
| 59 | + } else if (linalgx::isGenericPackedMatmulOp( |
| 60 | + linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D) || |
| 61 | + llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) { |
59 | 62 | return SmallVector<SmallVector<DimType>>{
|
60 | 63 | SmallVector<DimType>{DimType::M, DimType::K},
|
61 | 64 | SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
|
62 | 65 | DimType::K},
|
63 | 66 | SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
|
64 |
| - } else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) { |
| 67 | + } else if (linalgx::isGenericPackedMatmulOp( |
| 68 | + linalgOp.getOperation(), linalgx::PackingType::VNNI_MM4D) || |
| 69 | + llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) { |
65 | 70 | return SmallVector<SmallVector<DimType>>{
|
66 | 71 | SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
|
67 | 72 | SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
|
@@ -92,18 +97,34 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
|
92 | 97 | SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
|
93 | 98 | SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
|
94 | 99 | SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
|
| 100 | + } else if (linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), |
| 101 | + linalgx::PackingType::MM4D)) { |
| 102 | + return SmallVector<SmallVector<DimType>>{ |
| 103 | + SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K}, |
| 104 | + SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N}, |
| 105 | + SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}}; |
95 | 106 | }
|
96 | 107 | return failure();
|
97 | 108 | }
|
98 | 109 |
|
99 | 110 | // The analysis to extract the matmul configuration from the given linalg op
|
100 | 111 | struct MatmulConfigAnalysis {
|
101 | 112 | public:
|
102 |
| - explicit MatmulConfigAnalysis(Operation *root); |
103 |
| - MatmulConfig getConfig() { return config; } |
| 113 | + // Extract the matmul configuration from the given linalg op |
| 114 | + MatmulConfigAnalysis(Operation *root) : root(root){}; |
| 115 | + |
| 116 | + // Get the matmul configuration |
| 117 | + MatmulConfig getConfig(); |
| 118 | + |
| 119 | + void setAllowIndivisibleInnerBlock(bool allow) { |
| 120 | + allowIndivisibleInnerBlock = allow; |
| 121 | + } |
104 | 122 |
|
105 | 123 | private:
|
106 |
| - MatmulConfig config; |
| 124 | + MatmulConfig config = MatmulConfig{1, 1, 1, 1, 1, 1, 1, 1, 1}; |
| 125 | + Operation *root; |
| 126 | + bool hasConfig = false; |
| 127 | + bool allowIndivisibleInnerBlock = true; |
107 | 128 | };
|
108 | 129 |
|
109 | 130 | } // namespace gc
|
|
0 commit comments