Skip to content

Commit 2cb130f

Browse files
committed
[mlir][StandardToSPIRV] Add support for lowering uitofp to SPIR-V
- Extend spirv::ConstantOp::getZero/One to handle float, vector of int, and vector of float. - Refactor ZeroExtendI1Pattern to use getZero/One methods. - Add one more test for lowering std.zexti which extends vector<4xi1> to vector<4xi64>. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D95120
1 parent 16d4bbe commit 2cb130f

File tree

3 files changed

+130
-11
lines changed

3 files changed

+130
-11
lines changed

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -481,16 +481,32 @@ class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
481481
auto dstType =
482482
this->getTypeConverter()->convertType(op.getResult().getType());
483483
Location loc = op.getLoc();
484-
Attribute zeroAttr, oneAttr;
485-
if (auto vectorType = dstType.dyn_cast<VectorType>()) {
486-
zeroAttr = DenseElementsAttr::get(vectorType, 0);
487-
oneAttr = DenseElementsAttr::get(vectorType, 1);
488-
} else {
489-
zeroAttr = IntegerAttr::get(dstType, 0);
490-
oneAttr = IntegerAttr::get(dstType, 1);
491-
}
492-
Value zero = rewriter.create<ConstantOp>(loc, zeroAttr);
493-
Value one = rewriter.create<ConstantOp>(loc, oneAttr);
484+
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
485+
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
486+
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
487+
op, dstType, operands.front(), one, zero);
488+
return success();
489+
}
490+
};
491+
492+
/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
493+
/// i1.
494+
class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
495+
public:
496+
using OpConversionPattern<UIToFPOp>::OpConversionPattern;
497+
498+
LogicalResult
499+
matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
500+
ConversionPatternRewriter &rewriter) const override {
501+
auto srcType = operands.front().getType();
502+
if (!isBoolScalarOrVector(srcType))
503+
return failure();
504+
505+
auto dstType =
506+
this->getTypeConverter()->convertType(op.getResult().getType());
507+
Location loc = op.getLoc();
508+
Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
509+
Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
494510
rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
495511
op, dstType, operands.front(), one, zero);
496512
return success();
@@ -1098,8 +1114,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
10981114
ReturnOpPattern, SelectOpPattern,
10991115

11001116
// Type cast patterns
1101-
ZeroExtendI1Pattern, TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
1117+
UIToFPI1Pattern, ZeroExtendI1Pattern,
1118+
TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
11021119
TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
1120+
TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
11031121
TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
11041122
TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
11051123
TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "mlir/IR/OpDefinition.h"
2626
#include "mlir/IR/OpImplementation.h"
2727
#include "mlir/Interfaces/CallInterfaces.h"
28+
#include "llvm/ADT/APFloat.h"
29+
#include "llvm/ADT/APInt.h"
2830
#include "llvm/ADT/StringExtras.h"
2931
#include "llvm/ADT/bit.h"
3032

@@ -1581,6 +1583,25 @@ spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
15811583
return builder.create<spirv::ConstantOp>(
15821584
loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
15831585
}
1586+
if (auto floatType = type.dyn_cast<FloatType>()) {
1587+
return builder.create<spirv::ConstantOp>(
1588+
loc, type, builder.getFloatAttr(floatType, 0.0));
1589+
}
1590+
if (auto vectorType = type.dyn_cast<VectorType>()) {
1591+
Type elemType = vectorType.getElementType();
1592+
if (elemType.isa<IntegerType>()) {
1593+
return builder.create<spirv::ConstantOp>(
1594+
loc, type,
1595+
DenseElementsAttr::get(vectorType,
1596+
IntegerAttr::get(elemType, 0.0).getValue()));
1597+
}
1598+
if (elemType.isa<FloatType>()) {
1599+
return builder.create<spirv::ConstantOp>(
1600+
loc, type,
1601+
DenseFPElementsAttr::get(vectorType,
1602+
FloatAttr::get(elemType, 0.0).getValue()));
1603+
}
1604+
}
15841605

15851606
llvm_unreachable("unimplemented types for ConstantOp::getZero()");
15861607
}
@@ -1595,6 +1616,25 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
15951616
return builder.create<spirv::ConstantOp>(
15961617
loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
15971618
}
1619+
if (auto floatType = type.dyn_cast<FloatType>()) {
1620+
return builder.create<spirv::ConstantOp>(
1621+
loc, type, builder.getFloatAttr(floatType, 1.0));
1622+
}
1623+
if (auto vectorType = type.dyn_cast<VectorType>()) {
1624+
Type elemType = vectorType.getElementType();
1625+
if (elemType.isa<IntegerType>()) {
1626+
return builder.create<spirv::ConstantOp>(
1627+
loc, type,
1628+
DenseElementsAttr::get(vectorType,
1629+
IntegerAttr::get(elemType, 1.0).getValue()));
1630+
}
1631+
if (elemType.isa<FloatType>()) {
1632+
return builder.create<spirv::ConstantOp>(
1633+
loc, type,
1634+
DenseFPElementsAttr::get(vectorType,
1635+
FloatAttr::get(elemType, 1.0).getValue()));
1636+
}
1637+
}
15981638

15991639
llvm_unreachable("unimplemented types for ConstantOp::getOne()");
16001640
}

mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,58 @@ func @sitofp2(%arg0 : i64) -> f64 {
568568
return %0 : f64
569569
}
570570

571+
// CHECK-LABEL: @uitofp_i16_f32
572+
func @uitofp_i16_f32(%arg0: i16) -> f32 {
573+
// CHECK: spv.ConvertUToF %{{.*}} : i16 to f32
574+
%0 = std.uitofp %arg0 : i16 to f32
575+
return %0 : f32
576+
}
577+
578+
// CHECK-LABEL: @uitofp_i32_f32
579+
func @uitofp_i32_f32(%arg0 : i32) -> f32 {
580+
// CHECK: spv.ConvertUToF %{{.*}} : i32 to f32
581+
%0 = std.uitofp %arg0 : i32 to f32
582+
return %0 : f32
583+
}
584+
585+
// CHECK-LABEL: @uitofp_i1_f32
586+
func @uitofp_i1_f32(%arg0 : i1) -> f32 {
587+
// CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32
588+
// CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32
589+
// CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32
590+
%0 = std.uitofp %arg0 : i1 to f32
591+
return %0 : f32
592+
}
593+
594+
// CHECK-LABEL: @uitofp_i1_f64
595+
func @uitofp_i1_f64(%arg0 : i1) -> f64 {
596+
// CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f64
597+
// CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f64
598+
// CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64
599+
%0 = std.uitofp %arg0 : i1 to f64
600+
return %0 : f64
601+
}
602+
603+
// CHECK-LABEL: @uitofp_vec_i1_f32
604+
func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> {
605+
// CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32>
606+
// CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32>
607+
// CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32>
608+
%0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32>
609+
return %0 : vector<4xf32>
610+
}
611+
612+
// CHECK-LABEL: @uitofp_vec_i1_f64
613+
spv.func @uitofp_vec_i1_f64(%arg0: vector<4xi1>) -> vector<4xf64> "None" {
614+
// CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf64>
615+
// CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf64>
616+
// CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64>
617+
%0 = spv.constant dense<0.000000e+00> : vector<4xf64>
618+
%1 = spv.constant dense<1.000000e+00> : vector<4xf64>
619+
%2 = spv.Select %arg0, %1, %0 : vector<4xi1>, vector<4xf64>
620+
spv.ReturnValue %2 : vector<4xf64>
621+
}
622+
571623
// CHECK-LABEL: @zexti1
572624
func @zexti1(%arg0: i16) -> i64 {
573625
// CHECK: spv.UConvert %{{.*}} : i16 to i64
@@ -600,6 +652,15 @@ func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> {
600652
return %0 : vector<4xi32>
601653
}
602654

655+
// CHECK-LABEL: @zexti5
656+
func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> {
657+
// CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi64>
658+
// CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi64>
659+
// CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64>
660+
%0 = std.zexti %arg0 : vector<4xi1> to vector<4xi64>
661+
return %0 : vector<4xi64>
662+
}
663+
603664
// CHECK-LABEL: @trunci1
604665
func @trunci1(%arg0 : i64) -> i16 {
605666
// CHECK: spv.SConvert %{{.*}} : i64 to i16

0 commit comments

Comments
 (0)