Skip to content

Commit

Permalink
[mlir] [VectorOps] Replace zero fma with mult for vector.contract
Browse files Browse the repository at this point in the history
More efficient implementation of the multiply-reduce pair: the base case
now emits a plain multiply followed by the reduction, so there is no need
to add in a zero vector through an fma. Microbenchmarking on AVX2 yields
the following difference in vector.contract speedup (over strict-order
scalar reduction):

SPEEDUP   SIMD-fma   SIMD-mul
4x4       1.45       2.00
8x8       1.40       1.90
32x32     5.32       5.80

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D82833
  • Loading branch information
aartbik committed Jun 30, 2020
1 parent 69b2d9f commit 63b3933
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 24 deletions.
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/Vector/VectorTransforms.cpp
Expand Up @@ -1291,8 +1291,8 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
Value b = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value m;
if (acc) {
Value z = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), z);
Value e = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
m = rewriter.create<vector::FMAOp>(loc, b, op.rhs(), e);
} else {
m = rewriter.create<MulFOp>(loc, b, op.rhs());
}
Expand Down Expand Up @@ -1732,7 +1732,7 @@ void ContractionOpToOuterProductOpLowering::rewrite(
/// ..
/// %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a fma/reduction op.
/// which is replaced by a dot-product/reduction pair.
///
/// TODO(ajcbik): break down into transpose/reshape/cast ops
/// when they become available to avoid code dup
Expand Down Expand Up @@ -1882,11 +1882,9 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
Value zero = rewriter.create<ConstantOp>(loc, lhsType,
rewriter.getZeroAttr(lhsType));
Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
Value m = rewriter.create<MulFOp>(loc, op.lhs(), op.rhs());
StringAttr kind = rewriter.getStringAttr("add");
return rewriter.create<vector::ReductionOp>(loc, resType, kind, fma,
return rewriter.create<vector::ReductionOp>(loc, resType, kind, m,
op.acc());
}
// Construct new iterator types and affine map array attribute.
Expand Down
30 changes: 13 additions & 17 deletions mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Expand Up @@ -16,8 +16,7 @@
// CHECK-SAME: %[[A:.*0]]: vector<4xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<4xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32>
// CHECK: %[[F:.*]] = mulf %[[A]], %[[B]] : vector<4xf32>
// CHECK: %[[R:.*]] = vector.reduction "add", %[[F]], %[[C]] : vector<4xf32> into f32
// CHECK: return %[[R]] : f32

Expand All @@ -42,15 +41,14 @@ func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32)
// CHECK-SAME: %[[B:.*1]]: vector<3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32>
// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
Expand Down Expand Up @@ -78,15 +76,14 @@ func @extract_contract2(%arg0: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32>
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32>
// CHECK: %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32>
// CHECK: %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32>
// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: return %[[T9]] : vector<2xf32>
Expand Down Expand Up @@ -124,7 +121,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32>
// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
Expand All @@ -134,7 +131,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32>
// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32>
// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32>
// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32>
// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32>
Expand All @@ -147,7 +144,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32>
// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32>
// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32>
// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32>
// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T30]] : vector<2xf32>
// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T31]] : vector<2xf32> into f32
// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32>
// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32>
Expand All @@ -157,7 +154,7 @@ func @extract_contract3(%arg0: vector<3xf32>,
// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32>
// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32>
// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32>
// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32>
// CHECK: %[[T42:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32>
// CHECK: %[[T43:.*]] = vector.reduction "add", %[[T42]], %[[T41]] : vector<2xf32> into f32
// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32>
// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32>
Expand Down Expand Up @@ -185,14 +182,13 @@ func @extract_contract4(%arg0: vector<2x2xf32>,
// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
// CHECK-SAME: %[[C:.*2]]: f32
// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[T1]] : vector<3xf32>
// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
// CHECK: %[[T6:.*]] = mulf %[[T4]], %[[T5]] : vector<3xf32>
// CHECK: %[[T7:.*]] = vector.reduction "add", %[[T6]], %[[T3]] : vector<3xf32> into f32
// CHECK: return %[[T7]] : f32

Expand Down Expand Up @@ -229,7 +225,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
// CHECK: %[[T10:.*]] = mulf %[[T0]], %[[T9]] : vector<3xf32>
// CHECK: %[[T11:.*]] = vector.reduction "add", %[[T10]], %[[C]] : vector<3xf32> into f32
// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
Expand All @@ -241,7 +237,7 @@ func @full_contract1(%arg0: vector<2x3xf32>,
// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
// CHECK: %[[T22:.*]] = mulf %[[T12]], %[[T21]] : vector<3xf32>
// CHECK: %[[T23:.*]] = vector.reduction "add", %[[T22]], %[[T11]] : vector<3xf32> into f32
// CHECK: return %[[T23]] : f32

Expand Down

0 comments on commit 63b3933

Please sign in to comment.