Skip to content

Commit

Permalink
[mlir][vector] Add scalable vectors to tests for vector.contract
Browse files Browse the repository at this point in the history
Update the remaining tests for matrix multiplication (_matmul_) in:

  * vector-contract-to-outerproduct-transforms.mlir

with cases for scalable vectors.

Note that in order for the "vector.contract -> vector.outerproduct"
patterns to work, only the non-reduction dimension can be scalable (*).
For Matmul operations that is set to be the N dimension (i.e. rows of
the output matrix), which matches how matrix multiplications are normally
implemented for e.g. Arm's SVE. However, making the M dimension scalable
(i.e. columns of the output matrix) should work as well.

Making both parallel dimensions scalable is left as a TODO for when
support for 2-D scalable vectors is more established (this is
work-in-progress as part of the effort to support Arm's SME in MLIR).

The change in:

  * `UnrolledOuterProductGenerator`

is a "bug fix" to make sure that the conversion pattern correctly
propagates scalability when creating `arith.extf` operations.

(*) The conversion tested in this file unrolls along the reduction
dimension, which is not supported for scalable vectors.
  • Loading branch information
banach-space committed Oct 27, 2023
1 parent ea1909f commit 49c7d2e
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
return v;
Type promotedType = dstElementType;
if (vecType)
promotedType = VectorType::get(vecType.getShape(), promotedType);
promotedType = vecType.clone(promotedType);
if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
return %0 : vector<2x3xf32>
}

// Scalable-vector test case with operands 2x4 (LHS), 4x[3] (RHS) and
// 2x[3] (accumulator/result): only the trailing dimension of the RHS and of
// the result is scalable; the dimension of size 4 is fixed-width.
// The directives below expect the LHS to be transposed once, followed by a
// chain of four vector.outerproduct ops (indices 0..3), each one accumulating
// into the result of the previous one, starting from the accumulator argument.
// NOTE(review): #matmat_trait is defined outside this view - presumably the
// plain (m, k) x (k, n) -> (m, n) maps; confirm.
// CHECK-LABEL: func @matmul_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
//
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
//
// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
//
// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
//
// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
//
// CHECK: return %[[c3]] : vector<2x[3]xf32>
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
%arg1: vector<4x[3]xf32>,
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

// CHECK-LABEL: func @matmul_0
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
Expand All @@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}

// Scalable-vector test case with operands 2x1 (LHS), 1x[3] (RHS) and
// 2x[3] (accumulator/result); only the trailing dimension of the RHS and of
// the result is scalable.
// The directives below expect one transpose of the LHS and a single
// vector.outerproduct that accumulates directly into the accumulator argument.
// NOTE(review): #matmat_trait_0 is defined outside this view - presumably the
// unit dimension is the reduction dimension; confirm.
// CHECK-LABEL: func @matmul_0_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
-> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

// CHECK-LABEL: func @matmul_0_mixed
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
Expand All @@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
return %0 : vector<2x3xf32>
}

// Mixed-precision scalable-vector test case: f16 operands (2x1 and 1x[3])
// contracted into an f32 accumulator/result (2x[3]).
// The directives below expect both extracted 1-D operands to be widened with
// arith.extf before the single vector.outerproduct; the widened RHS must keep
// its scalable shape (vector<[3]xf16> to vector<[3]xf32>).
// NOTE(review): #matmat_trait_0 is defined outside this view; confirm its maps.
// CHECK-LABEL: func @matmul_0_mixed_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
-> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

#matmat_accesses_1 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
Expand Down Expand Up @@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}

// Scalable-vector test case using #matmat_trait_1 with operands 2x1 (LHS),
// [3]x1 (RHS, scalable in its leading dimension) and 2x[3]
// (accumulator/result).
// The directives below expect both the LHS and the RHS to be transposed
// before a single vector.outerproduct into the accumulator argument.
// NOTE(review): only part of #matmat_accesses_1 is visible in this view
// (LHS -> (m, k), RHS -> (n, k)); confirm the result map.
// CHECK-LABEL: func @matmul_1_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
-> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

#matmat_accesses_2 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (k, n)>,
Expand All @@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}

// Scalable-vector test case using #matmat_trait_2 with operands 1x2 (LHS),
// 1x[3] (RHS, scalable in its trailing dimension) and 2x[3]
// (accumulator/result).
// The directives below expect no transposes: row 0 is extracted from each
// operand as-is and fed to a single vector.outerproduct.
// NOTE(review): only part of #matmat_accesses_2 is visible in this view
// (LHS -> (k, m), RHS -> (k, n)); confirm the result map.
// CHECK-LABEL: func @matmul_2_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
-> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

#matmat_accesses_3 = [
affine_map<(m, n, k) -> (k, m)>,
affine_map<(m, n, k) -> (n, k)>,
Expand Down Expand Up @@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
return %0 : vector<2x3xf32>
}

// Scalable-vector test case using #matmat_trait_3 with operands 1x2 (LHS),
// [3]x1 (RHS, scalable in its leading dimension) and 2x[3]
// (accumulator/result).
// The directives below expect only the RHS to be transposed; the LHS row is
// extracted directly, then a single vector.outerproduct is formed.
// NOTE(review): only part of #matmat_accesses_3 is visible in this view
// (LHS -> (k, m), RHS -> (n, k)); confirm the result map.
// CHECK-LABEL: func @matmul_3_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
// CHECK: return %[[c0]] : vector<2x[3]xf32>
func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
-> vector<2x[3]xf32>
{
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
: vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
return %0 : vector<2x[3]xf32>
}

#matmat_accesses_4 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
Expand Down Expand Up @@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
return %0 : vector<3x2xf32>
}

// Scalable-vector test case using #matmat_trait_4. Unlike the other scalable
// cases visible in this file, the scalable dimension here comes from the LHS:
// operands are [2]x1 (LHS), 1x3 (RHS) and 3x[2] (accumulator/result).
// The directives below expect the LHS to be transposed and the outerproduct
// operands to be swapped (RHS row first, then the scalable LHS row).
// NOTE(review): only part of #matmat_accesses_4 is visible in this view
// (LHS -> (m, k), RHS -> (k, n)); the 3x[2] result suggests an (n, m) result
// map - confirm.
// CHECK-LABEL: func @matmul_4_scalable
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
// CHECK: return %[[c0]] : vector<3x[2]xf32>
func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
-> vector<3x[2]xf32>
{
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
: vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
return %0 : vector<3x[2]xf32>
}

// Contraction trait over the iteration space (m, n, k): the three maps are,
// in order, for the LHS (m, k), the RHS (k, n) and the accumulator/result
// (n, m), i.e. the result is indexed with m and n swapped. m and n are
// parallel iterators, k is the reduction iterator.
// NOTE(review): no use of #matmat_trait_5 is visible in this view, and the
// visible part of #matmat_accesses_4 starts with the same two maps - confirm
// this alias is referenced and is not an accidental duplicate.
#matmat_accesses_5 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
affine_map<(m, n, k) -> (n, m)>
]
#matmat_trait_5 = {
indexing_maps = #matmat_accesses_5,
iterator_types = ["parallel", "parallel", "reduction"]
}

// CHECK-LABEL: @masked_matvec_mk_k_m
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
Expand Down

0 comments on commit 49c7d2e

Please sign in to comment.