diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 36eb29d53766f..ba7cd6ae91833 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -255,6 +255,8 @@ class DXILOpMapping; def IsInf : DXILOpMapping<9, isSpecialFloat, int_dx_isinf, "Determines if the specified value is infinite.", [llvm_i1_ty, llvm_halforfloat_ty]>;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index 0db42bc0a0fb6..b46564702c7aa 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -33,6 +33,7 @@ using namespace llvm;
 
 static bool isIntrinsicExpansion(Function &F) {
   switch (F.getIntrinsicID()) {
+  case Intrinsic::abs:
   case Intrinsic::exp:
   case Intrinsic::dx_any:
   case Intrinsic::dx_clamp:
@@ -46,6 +47,24 @@ static bool isIntrinsicExpansion(Function &F) {
   return false;
 }
 
+// Expands a call to llvm.abs(X, ...) into smax(X, 0 - X) and erases the call.
+// Only operand 0 is read; the second (immarg i1) operand is not consulted.
+// Works for scalar and vector integer types alike. Always returns true.
+static bool expandAbs(CallInst *Orig) {
+  Value *X = Orig->getOperand(0);
+  IRBuilder<> Builder(Orig->getParent());
+  Builder.SetInsertPoint(Orig);
+  Type *Ty = X->getType();
+  // ConstantInt::get yields a splat of the value when Ty is a vector type.
+  Constant *Zero = ConstantInt::get(Ty, 0);
+  auto *V = Builder.CreateSub(Zero, X);
+  auto *MaxCall =
+      Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
+  Orig->replaceAllUsesWith(MaxCall);
+  Orig->eraseFromParent();
+  return true;
+}
+
 static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
   assert(DotIntrinsic == Intrinsic::dx_sdot ||
          DotIntrinsic == Intrinsic::dx_udot);
@@ -213,6 +232,8 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
 
 static bool expandIntrinsic(Function &F, CallInst *Orig) {
   switch (F.getIntrinsicID()) {
+  case Intrinsic::abs:
+    return expandAbs(Orig);
   case Intrinsic::exp:
     return expandExpIntrinsic(Orig);
   case Intrinsic::dx_any:
diff --git a/llvm/test/CodeGen/DirectX/abs-vec.ll b/llvm/test/CodeGen/DirectX/abs-vec.ll
new file mode 100644
index 0000000000000..1c40555eb390c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/abs-vec.ll
@@ -0,0 +1,34 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s
+
+; Make sure abs on int vectors is expanded to a sub from zero followed by an smax call.
+
+ ; CHECK-LABEL: abs_i16Vec2
+define noundef <2 x i16> @abs_i16Vec2(<2 x i16> noundef %a) #0 {
+entry:
+; CHECK: sub <2 x i16> zeroinitializer, %a
+; CHECK: call <2 x i16> @llvm.smax.v2i16(<2 x i16> %a, <2 x i16> %{{.*}})
+  %elt.abs = call <2 x i16> @llvm.abs.v2i16(<2 x i16> %a, i1 false)
+  ret <2 x i16> %elt.abs
+}
+
+; CHECK-LABEL: abs_i32Vec3
+define noundef <3 x i32> @abs_i32Vec3(<3 x i32> noundef %a) #0 {
+entry:
+; CHECK: sub <3 x i32> zeroinitializer, %a
+; CHECK: call <3 x i32> @llvm.smax.v3i32(<3 x i32> %a, <3 x i32> %{{.*}})
+  %elt.abs = call <3 x i32> @llvm.abs.v3i32(<3 x i32> %a, i1 false)
+  ret <3 x i32> %elt.abs
+}
+
+; CHECK-LABEL: abs_i64Vec4
+define noundef <4 x i64> @abs_i64Vec4(<4 x i64> noundef %a) #0 {
+entry:
+; CHECK: sub <4 x i64> zeroinitializer, %a
+; CHECK: call <4 x i64> @llvm.smax.v4i64(<4 x i64> %a, <4 x i64> %{{.*}})
+  %elt.abs = call <4 x i64> @llvm.abs.v4i64(<4 x i64> %a, i1 false)
+  ret <4 x i64> %elt.abs
+}
+
+declare <2 x i16> @llvm.abs.v2i16(<2 x i16>, i1 immarg)
+declare <3 x i32> @llvm.abs.v3i32(<3 x i32>, i1 immarg)
+declare <4 x i64> @llvm.abs.v4i64(<4 x i64>, i1 immarg)
diff --git a/llvm/test/CodeGen/DirectX/abs.ll b/llvm/test/CodeGen/DirectX/abs.ll
new file mode 100644
index 0000000000000..822580e8c089a
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/abs.ll
@@ -0,0 +1,38 @@
+; RUN: opt -S -dxil-intrinsic-expansion < %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s --check-prefixes=CHECK,DOPCHECK
+
+; Make sure abs on i16/i32/i64 becomes a sub from zero plus an smax call after intrinsic expansion, and a sub from zero plus a dx.op.binary call (opcode 37) after op lowering.
+
+; CHECK-LABEL: abs_i16
+define noundef i16 @abs_i16(i16 noundef %a) {
+entry:
+; CHECK: sub i16 0, %a
+; EXPCHECK: call i16 @llvm.smax.i16(i16 %a, i16 %{{.*}})
+; DOPCHECK: call i16 @dx.op.binary.i16(i32 37, i16 %a, i16 %{{.*}})
+  %elt.abs = call i16 @llvm.abs.i16(i16 %a, i1 false)
+  ret i16 %elt.abs
+}
+
+; CHECK-LABEL: abs_i32
+define noundef i32 @abs_i32(i32 noundef %a) {
+entry:
+; CHECK: sub i32 0, %a
+; EXPCHECK: call i32 @llvm.smax.i32(i32 %a, i32 %{{.*}})
+; DOPCHECK: call i32 @dx.op.binary.i32(i32 37, i32 %a, i32 %{{.*}})
+  %elt.abs = call i32 @llvm.abs.i32(i32 %a, i1 false)
+  ret i32 %elt.abs
+}
+
+; CHECK-LABEL: abs_i64
+define noundef i64 @abs_i64(i64 noundef %a) {
+entry:
+; CHECK: sub i64 0, %a
+; EXPCHECK: call i64 @llvm.smax.i64(i64 %a, i64 %{{.*}})
+; DOPCHECK: call i64 @dx.op.binary.i64(i32 37, i64 %a, i64 %{{.*}})
+  %elt.abs = call i64 @llvm.abs.i64(i64 %a, i1 false)
+  ret i64 %elt.abs
+}
+
+declare i16 @llvm.abs.i16(i16, i1 immarg)
+declare i32 @llvm.abs.i32(i32, i1 immarg)
+declare i64 @llvm.abs.i64(i64, i1 immarg)
diff --git a/llvm/test/CodeGen/DirectX/fabs.ll b/llvm/test/CodeGen/DirectX/fabs.ll
new file mode 100644
index 0000000000000..3b3f8aa9a4a92
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/fabs.ll
@@ -0,0 +1,32 @@
+; RUN: opt -S -dxil-op-lower < %s | FileCheck %s
+
+; Make sure dxil operation function calls for fabs are generated for half, float, and double.
+
+
+; CHECK-LABEL: fabs_half
+define noundef half @fabs_half(half noundef %a) {
+entry:
+  ; CHECK: call half @dx.op.unary.f16(i32 6, half %{{.*}})
+  %elt.abs = call half @llvm.fabs.f16(half %a)
+  ret half %elt.abs
+}
+
+; CHECK-LABEL: fabs_float
+define noundef float @fabs_float(float noundef %a) {
+entry:
+; CHECK: call float @dx.op.unary.f32(i32 6, float %{{.*}})
+  %elt.abs = call float @llvm.fabs.f32(float %a)
+  ret float %elt.abs
+}
+
+; CHECK-LABEL: fabs_double
+define noundef double @fabs_double(double noundef %a) {
+entry:
+; CHECK: call double @dx.op.unary.f64(i32 6, double %{{.*}})
+  %elt.abs = call double @llvm.fabs.f64(double %a)
+  ret double %elt.abs
+}
+
+declare half @llvm.fabs.f16(half)
+declare float @llvm.fabs.f32(float)
+declare double @llvm.fabs.f64(double)