Skip to content

Commit

Permalink
[Codegen][GPU] Make operand promotion pattern work with generics (#17650
Browse files Browse the repository at this point in the history
)

The pattern was previously using the `isMatmulOrBatchMatmul` helper that
only looked for named ops. Change the logic to use inferred contraction
dims and look at the static bounds of the op to filter out matvec cases.
  • Loading branch information
qedawkins committed Jun 12, 2024
1 parent abf0087 commit 0a561c4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,42 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
op->setOperand(index, copy.getResult(0));
}

bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
SmallVector<int64_t, 4> bounds = linalgOp.getStaticLoopRanges();
FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
mlir::linalg::inferContractionDims(linalgOp);
if (failed(contractionDims)) {
return false;
}

if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
contractionDims->n.size() < 1) {
return false;
}

auto getElementCount = [&](ArrayRef<unsigned> dims) {
int64_t acc = 1;
for (auto mDim : dims) {
int64_t size = bounds[mDim];
if (ShapedType::isDynamic(size)) {
return size;
}
acc *= size;
}
return acc;
};
return getElementCount(contractionDims->m) != 1 &&
getElementCount(contractionDims->n) != 1;
}

struct GPUPromoteMatmulOperandsPass final
: impl::GPUPromoteMatmulOperandsPassBase<GPUPromoteMatmulOperandsPass> {
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

OpBuilder builder(funcOp);
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (!isMatmulOrBatchMatmul(linalgOp)) {
if (!isNonMatvecContraction(linalgOp)) {
return;
}
builder.setInsertionPoint(linalgOp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,29 @@ func.func @matvec(%a: tensor<1x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<1x
// CHECK-LABEL: func.func @matvec
// CHECK-NOT: linalg.copy
// CHECK: return

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @generic_matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<32x128xf32> {
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<32x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
%mm = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%a, %b : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<32x128xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.mulf %in, %in_0 : f32
%8 = arith.addf %out, %7 : f32
linalg.yield %8 : f32
} -> tensor<32x128xf32>
return %mm : tensor<32x128xf32>
}

// CHECK-LABEL: func.func @generic_matmul
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x1024xf32>
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<1024x128xf32>
// CHECK-DAG: %[[PA:.+]] = linalg.copy {{.*}} ins(%[[A]] : tensor<32x1024xf32>)
// CHECK-DAG: %[[PB:.+]] = linalg.copy {{.*}} ins(%[[B]] : tensor<1024x128xf32>)
// CHECK: linalg.generic {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>)

0 comments on commit 0a561c4

Please sign in to comment.