Skip to content

Commit

Permalink
[mlir][vector] Fix crash in vector.reduction canonicalization
Browse files Browse the repository at this point in the history
since vector.reduce support accumulator in all the cases remove the
assert assuming old definition.

Differential Revision: https://reviews.llvm.org/D129602
  • Loading branch information
ThomasRaoux committed Jul 12, 2022
1 parent cc7d966 commit 5f8cefe
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 68 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Expand Up @@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
/// memory.
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);

/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
} // namespace vector
} // namespace mlir

Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Expand Up @@ -34,11 +34,6 @@ namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
Value v1, Value v2);
} // namespace vector

/// Return the number of elements of basis, `0` if empty.
Expand Down
66 changes: 53 additions & 13 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Expand Up @@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
reductionOp.getVector(),
rewriter.getI64ArrayAttr(0));

if (Value acc = reductionOp.getAcc()) {
assert(reductionOp.getType().isa<FloatType>());
switch (reductionOp.getKind()) {
case CombiningKind::ADD:
result = rewriter.create<arith::AddFOp>(loc, result, acc);
break;
case CombiningKind::MUL:
result = rewriter.create<arith::MulFOp>(loc, result, acc);
break;
default:
assert(false && "invalid op!");
}
}
if (Value acc = reductionOp.getAcc())
result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
result, acc);

rewriter.replaceOp(reductionOp, result);
return success();
Expand Down Expand Up @@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}

Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value v2) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type t2 = getElementTypeOrSelf(v2.getType());
switch (kind) {
case CombiningKind::ADD:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for ADD reduction");
case CombiningKind::AND:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
case CombiningKind::MAXF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
case CombiningKind::MINF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
case CombiningKind::MINSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
case CombiningKind::MAXUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
case CombiningKind::MINUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
case CombiningKind::MUL:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for MUL reduction");
case CombiningKind::OR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
case CombiningKind::XOR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
};
llvm_unreachable("unknown CombiningKind");
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
50 changes: 0 additions & 50 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Expand Up @@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}

Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value v2) {
Type t1 = getElementTypeOrSelf(v1.getType());
Type t2 = getElementTypeOrSelf(v2.getType());
switch (kind) {
case CombiningKind::ADD:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::AddIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::AddFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for ADD reduction");
case CombiningKind::AND:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::AndIOp>(loc, v1, v2);
case CombiningKind::MAXF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
case CombiningKind::MINF:
assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
"expected float values");
return b.createOrFold<arith::MinFOp>(loc, v1, v2);
case CombiningKind::MAXSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
case CombiningKind::MINSI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
case CombiningKind::MAXUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
case CombiningKind::MINUI:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
case CombiningKind::MUL:
if (t1.isIntOrIndex() && t2.isIntOrIndex())
return b.createOrFold<arith::MulIOp>(loc, v1, v2);
else if (t1.isa<FloatType>() && t2.isa<FloatType>())
return b.createOrFold<arith::MulFOp>(loc, v1, v2);
llvm_unreachable("invalid value types for MUL reduction");
case CombiningKind::OR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::OrIOp>(loc, v1, v2);
case CombiningKind::XOR:
assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
};
llvm_unreachable("unknown CombiningKind");
}

/// Return the number of elements of basis, `0` if empty.
int64_t mlir::computeMaxLinearIndex(ArrayRef<int64_t> basis) {
if (basis.empty())
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Expand Up @@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 {

// -----

// CHECK-LABEL: func @reduce_one_element_vector_maxf
// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32)
// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32>
// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32
// CHECK: return %[[S]]
func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 {
%s = vector.reduction <maxf>, %a, %b : vector<1xf32> into f32
return %s : f32
}

// -----

// CHECK-LABEL: func @bitcast(
// CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> {
// CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>
Expand Down

0 comments on commit 5f8cefe

Please sign in to comment.