Skip to content

Commit

Permalink
lowerParallel is also called on unit-size, one-sided reduction dims
Browse files Browse the repository at this point in the history
  • Loading branch information
bjacob committed Jul 13, 2022
1 parent 3968936 commit 6870a50
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -527,11 +527,12 @@ class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
vector::VectorTransformsOptions vectorTransformOptions;
FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex,
PatternRewriter &rewriter) const;
// Lower one reduction dimension.
Value lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const;
FailureOr<Value> lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const;
};

} // namespace vector
Expand Down
96 changes: 70 additions & 26 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
if (!batchDimMap.empty()) {
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter));
auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
if (failed(newOp))
return failure();
rewriter.replaceOp(op, newOp.value());
return success();
}

Expand All @@ -1812,8 +1815,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
VectorType lhsType = op.getLhsType();
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter));
auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
if (failed(newOp))
return failure();
rewriter.replaceOp(op, newOp.value());
return success();
}
}
Expand All @@ -1822,26 +1827,33 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
VectorType rhsType = op.getRhsType();
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
rewriter.replaceOp(
op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter));
auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
if (failed(newOp))
return failure();
rewriter.replaceOp(op, newOp.value());
return success();
}
}

// Lower the first remaining reduction dimension.
if (!contractingDimMap.empty()) {
rewriter.replaceOp(op, lowerReduction(op, rewriter));
auto newOp = lowerReduction(op, rewriter);
if (failed(newOp))
return failure();
rewriter.replaceOp(op, newOp.value());
return success();
}

return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
int64_t lhsIndex, int64_t rhsIndex,
PatternRewriter &rewriter) const {
FailureOr<Value>
ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex,
PatternRewriter &rewriter) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = op.getResultType().cast<VectorType>();
Expand All @@ -1851,18 +1863,34 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
int64_t dimSize = -1;
if (lhsIndex >= 0) {
iterIndex = iMap[0].getDimPosition(lhsIndex);
assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
"parallel index should be free in LHS or batch in LHS/RHS");
if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
<< " to map to the same dimension";
});
dimSize = lhsType.getDimSize(lhsIndex);
} else {
assert(rhsIndex >= 0 && "missing parallel index");
} else if (rhsIndex >= 0) {
iterIndex = iMap[1].getDimPosition(rhsIndex);
dimSize = rhsType.getDimSize(rhsIndex);
}
assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
assert(lookup.has_value() && "parallel index not listed in reduction");
int64_t resIndex = lookup.getValue();
if (iterIndex < 0)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected either lhsIndex=" << lhsIndex
<< " or rhsIndex=" << rhsIndex << " to be nonnegative";
});
// getValueOr(-1) means that we tolerate a dimension not appearing
// in the result map. That can't happen for actual parallel iterators, but
// the caller ContractionOpLowering::matchAndRewrite is currently calling
// lowerParallel also for the case of unit-size reduction dims appearing only
// on one of LHS or RHS, not both. At the moment, such cases are created by
// CastAwayContractionLeadingOneDim, so we need to either support that or
// modify that pattern.
int64_t resIndex = getResultIndex(iMap[2], iterIndex).getValueOr(-1);
if (resIndex == -1 && dimSize != 1)
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected the dimension for iterIndex=" << iterIndex
<< " to either appear in the result map, or to be a unit dimension";
});
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
Expand All @@ -1888,33 +1916,49 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
}

// Lower one reduction dimension.
Value ContractionOpLowering::lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const {
FailureOr<Value>
ContractionOpLowering::lowerReduction(vector::ContractionOp op,
PatternRewriter &rewriter) const {
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
Type resType = op.getResultType();
assert(!resType.isa<VectorType>());
if (resType.isa<VectorType>())
return rewriter.notifyMatchFailure(op,
"did not expect a VectorType result");
bool isInt = resType.isa<IntegerType>();
// Use iterator index 0.
int64_t iterIndex = 0;
SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
assert(lookupLhs.has_value() && "missing LHS parallel index");
assert(lookupRhs.has_value() && "missing RHS parallel index");
if (!lookupLhs.hasValue())
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
});
if (!lookupRhs.hasValue())
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
});
int64_t lhsIndex = lookupLhs.getValue();
int64_t rhsIndex = lookupRhs.getValue();
int64_t dimSize = lhsType.getDimSize(lhsIndex);
assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
if (dimSize != rhsType.getDimSize(rhsIndex))
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "expect LHS dimension " << lhsIndex
<< " to have the same size as RHS dimension " << rhsIndex;
});
// Base case.
if (lhsType.getRank() == 1) {
assert(rhsType.getRank() == 1 && "corrupt contraction");
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
if (auto acc = op.getAcc())
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc);
return rewriter.create<vector::ReductionOp>(loc, kind, m);
return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
.getResult();
return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
}
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Vector/vector-contract-transforms.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,34 @@ func.func @genbool_var_3d(%arg0: index, %arg1: index, %arg2: index) -> vector<2x
return %0 : vector<2x1x7xi1>
}

// CHECK-LABEL: @contract_one_sided_unit_reduction_dim
// CHECK-SAME: (%[[A0:.+]]: vector<1x2xi32>, %[[A1:.+]]: vector<2x2xi32>, %[[A2:.+]]: vector<2xi32>)
// CHECK-DAG: %[[C:.+]] = arith.constant dense<0> : vector<2xi32>
// CHECK-DAG: %[[E00:.+]] = vector.extract %[[A0]][0] : vector<1x2xi32>
// CHECK-DAG: %[[E10:.+]] = vector.extract %[[A1]][0] : vector<2x2xi32>
// CHECK: %[[M0:.+]] = arith.muli %[[E10]], %[[E00]] : vector<2xi32>
// CHECK: %[[R0:.+]] = vector.reduction <add>, %[[M0]] : vector<2xi32> into i32
// CHECK: %[[I0:.+]] = vector.insert %[[R0]], %[[C]] [0] : i32 into vector<2xi32>
// CHECK: %[[E11:.+]] = vector.extract %[[A1]][1] : vector<2x2xi32>
// CHECK: %[[M1:.+]] = arith.muli %[[E11]], %[[E00]] : vector<2xi32>
// CHECK: %[[R1:.+]] = vector.reduction <add>, %[[M1]] : vector<2xi32> into i32
// CHECK: %[[I1:.+]] = vector.insert %[[R1]], %[[I0]] [1] : i32 into vector<2xi32>
// CHECK: %[[S:.+]] = arith.addi %[[I1]], %[[A2]] : vector<2xi32>
// CHECK: return %[[S]] : vector<2xi32>

// Regression test: a contraction with a unit-size reduction dimension (d0)
// that appears in the LHS map only — the "one-sided" case that lowerParallel
// must now tolerate instead of asserting. d2 is the ordinary reduction
// dimension shared by LHS and RHS; d1 is the parallel result dimension.
func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1 : vector<2x2xi32>, %arg2 : vector<2xi32>) -> vector<2xi32> {
%res = vector.contract {
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d1)>
],
iterator_types = ["reduction", "parallel", "reduction"],
kind = #vector.kind<add>
} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2x2xi32>, vector<2xi32> into vector<2xi32>
return %res : vector<2xi32>
}

#matmat_accesses_0 = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (k, n)>,
Expand Down

0 comments on commit 6870a50

Please sign in to comment.