Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 16 additions & 28 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
Attribute propertiesAttr,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);

/// Return "true" if the given type is an unsupported floating point type. In
/// case of a vector type, return "true" if the element type is an unsupported
/// floating point type.
bool isUnsupportedFloatingPointType(const TypeConverter &typeConverter,
Type type);
} // namespace detail
} // namespace LLVM

Expand Down Expand Up @@ -97,43 +103,25 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;

/// Return the given type if it's a floating point type. If the given type is
/// a vector type, return its element type if it's a floating point type.
static FloatType getFloatingPointType(Type type) {
if (auto floatType = dyn_cast<FloatType>(type))
return floatType;
if (auto vecType = dyn_cast<VectorType>(type))
return dyn_cast<FloatType>(vecType.getElementType());
return nullptr;
}

LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
static_assert(
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
"expected single result op");

// The pattern should not apply if a floating-point operand is converted to
// a non-floating-point type. This indicates that the floating point type
// is not supported by the LLVM lowering. (Such types are converted to
// integers.)
auto checkType = [&](Value v) -> LogicalResult {
FloatType floatType = getFloatingPointType(v.getType());
if (!floatType)
return success();
Type convertedType = this->getTypeConverter()->convertType(floatType);
if (!isa_and_nonnull<FloatType>(convertedType))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
return success();
};
// Bail on unsupported floating point types. (These are type-converted to
// integer types.)
if (FailOnUnsupportedFP) {
for (Value operand : op->getOperands())
if (failed(checkType(operand)))
return failure();
if (failed(checkType(op->getResult(0))))
return failure();
if (LLVM::detail::isUnsupportedFloatingPointType(
*this->getTypeConverter(), operand.getType()))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
if (LLVM::detail::isUnsupportedFloatingPointType(
*this->getTypeConverter(), op->getResult(0).getType()))
return rewriter.notifyMatchFailure(op,
"unsupported floating point type");
}

// Determine attributes for the target op
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,10 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
LogicalResult
CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (LLVM::detail::isUnsupportedFloatingPointType(*this->getTypeConverter(),
op.getLhs().getType()))
return rewriter.notifyMatchFailure(op, "unsupported floating point type");

Type operandType = adaptor.getLhs().getType();
Type resultType = op.getResult().getType();
LLVM::FastmathFlags fmf =
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,24 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
return handleMultidimensionalVectors(op, operands, typeConverter, callback,
rewriter);
}

/// Return the given type if it's a floating point type. If the given type is
/// a vector type, return its element type if it's a floating point type.
static FloatType getFloatingPointType(Type type) {
if (auto floatType = dyn_cast<FloatType>(type))
return floatType;
if (auto vecType = dyn_cast<VectorType>(type))
return dyn_cast<FloatType>(vecType.getElementType());
return nullptr;
}

bool LLVM::detail::isUnsupportedFloatingPointType(
const TypeConverter &typeConverter, Type type) {
FloatType floatType = getFloatingPointType(type);
if (!floatType)
return false;
Type convertedType = typeConverter.convertType(floatType);
if (!convertedType)
return true;
return !isa<FloatType>(convertedType);
}
10 changes: 7 additions & 3 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -770,12 +770,14 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
// CHECK: arith.cmpf {{.*}} : f4E2M1FN
// CHECK: llvm.select {{.*}} : i1, i4
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
%3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
%3 = arith.cmpf oeq, %arg0, %arg3 : f4E2M1FN
%4 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
return
}

Expand All @@ -785,9 +787,11 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2
// CHECK: llvm.fadd {{.*}} : f32
// CHECK: llvm.fadd {{.*}} : vector<4xf32>
// CHECK-COUNT-4: llvm.fadd {{.*}} : vector<8xf32>
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>) -> (f32, vector<4xf32>, vector<4x8xf32>) {
// CHECK: llvm.fcmp {{.*}} : f32
func.func @supported_fp_type(%arg0: f32, %arg1: vector<4xf32>, %arg2: vector<4x8xf32>, %arg3: f32) {
%0 = arith.addf %arg0, %arg0 : f32
%1 = arith.addf %arg1, %arg1 : vector<4xf32>
%2 = arith.addf %arg2, %arg2 : vector<4x8xf32>
return %0, %1, %2 : f32, vector<4xf32>, vector<4x8xf32>
%3 = arith.cmpf oeq, %arg0, %arg3 : f32
return
}