Skip to content

Commit

Permalink
[flang] Add PowerPC vec_splat, vec_splats and vec_splat_s32 intrinsic
Browse files Browse the repository at this point in the history
Co-authored-by: Paul Scoropan <1paulscoropan@gmail.com>

Differential Revision: https://reviews.llvm.org/D157728
  • Loading branch information
kkwli committed Aug 14, 2023
1 parent bfc965c commit f50eaea
Show file tree
Hide file tree
Showing 7 changed files with 1,794 additions and 9 deletions.
16 changes: 16 additions & 0 deletions flang/include/flang/Optimizer/Builder/PPCIntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ enum class VecOp {
Sldw,
Sll,
Slo,
Splat,
Splat_s32,
Splats,
Sr,
Srl,
Sro,
Expand Down Expand Up @@ -113,6 +116,15 @@ static inline VecTypeInfo getVecTypeFromFir(mlir::Value firVec) {
return getVecTypeFromFirType(firVec.getType());
}

// Calculates the vector length and returns a VecTypeInfo with element type and
// length.
static inline VecTypeInfo getVecTypeFromEle(mlir::Value ele) {
VecTypeInfo vecTyInfo;
vecTyInfo.eleTy = ele.getType();
vecTyInfo.len = 16 / (vecTyInfo.eleTy.getIntOrFloatBitWidth() / 8);
return vecTyInfo;
}

// Converts array of fir vectors to mlir vectors.
static inline llvm::SmallVector<mlir::Value, 4>
convertVecArgs(fir::FirOpBuilder &builder, mlir::Location loc,
Expand Down Expand Up @@ -209,6 +221,10 @@ struct PPCIntrinsicLibrary : IntrinsicLibrary {

template <VecOp>
void genVecXStore(llvm::ArrayRef<fir::ExtendedValue>);

template <VecOp vop>
fir::ExtendedValue genVecSplat(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args);
};

const IntrinsicHandler *findPPCIntrinsicHandler(llvm::StringRef name);
Expand Down
62 changes: 62 additions & 0 deletions flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,21 @@ static constexpr IntrinsicHandler ppcHandlers[]{
&PI::genVecShift<VecOp::Slo>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_splat",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecSplat<VecOp::Splat>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_splat_s32_",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecSplat<VecOp::Splat_s32>),
{{{"arg1", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_splats",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecSplat<VecOp::Splats>),
{{{"arg1", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sr",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sr>),
Expand Down Expand Up @@ -1608,6 +1623,53 @@ PPCIntrinsicLibrary::genVecShift(mlir::Type resultType,
return shftRes;
}

// VEC_SPLAT, VEC_SPLATS, VEC_SPLAT_S32
template <VecOp vop>
fir::ExtendedValue
PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
auto context{builder.getContext()};
auto argBases{getBasesForArgs(args)};

mlir::vector::SplatOp splatOp{nullptr};
mlir::Type retTy{nullptr};
switch (vop) {
case VecOp::Splat: {
assert(args.size() == 2);
auto vecTyInfo{getVecTypeFromFir(argBases[0])};

auto extractOp{genVecExtract(resultType, args)};
splatOp = builder.create<mlir::vector::SplatOp>(
loc, *(extractOp.getUnboxed()), vecTyInfo.toMlirVectorType(context));
retTy = vecTyInfo.toFirVectorType();
break;
}
case VecOp::Splats: {
assert(args.size() == 1);
auto vecTyInfo{getVecTypeFromEle(argBases[0])};

splatOp = builder.create<mlir::vector::SplatOp>(
loc, argBases[0], vecTyInfo.toMlirVectorType(context));
retTy = vecTyInfo.toFirVectorType();
break;
}
case VecOp::Splat_s32: {
assert(args.size() == 1);
auto eleTy{builder.getIntegerType(32)};
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};

// the intrinsic always returns vector(integer(4))
splatOp = builder.create<mlir::vector::SplatOp>(
loc, intOp, mlir::VectorType::get(4, eleTy));
retTy = fir::VectorType::get(4, eleTy);
break;
}
default:
llvm_unreachable("invalid vector operation for generator");
}
return builder.createConvert(loc, retTy, splatOp);
}

const char *getMmaIrIntrName(MMAOp mmaOp) {
switch (mmaOp) {
case MMAOp::AssembleAcc:
Expand Down
25 changes: 24 additions & 1 deletion flang/lib/Semantics/check-call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ bool CheckArgumentIsConstantExprInRange(

if (*scalarValue < lowerBound || *scalarValue > upperBound) {
messages.Say(
"Argument #%d must be a constant expression in range %d-%d"_err_en_US,
"Argument #%d must be a constant expression in range %d to %d"_err_en_US,
index + 1, lowerBound, upperBound);
return false;
}
Expand Down Expand Up @@ -1560,6 +1560,29 @@ bool CheckPPCIntrinsic(const Symbol &generic, const Symbol &specific,
if (specific.name().ToString().compare(0, 16, "__ppc_vec_permi_") == 0) {
return CheckArgumentIsConstantExprInRange(actuals, 2, 0, 3, messages);
}
if (specific.name().ToString().compare(0, 21, "__ppc_vec_splat_s32__") == 0) {
return CheckArgumentIsConstantExprInRange(actuals, 0, -16, 15, messages);
}
if (specific.name().ToString().compare(0, 16, "__ppc_vec_splat_") == 0) {
// The value of arg2 in vec_splat must be a constant expression that is
// greater than or equal to 0, and less than the number of elements in arg1.
auto *expr{actuals[0].value().UnwrapExpr()};
auto type{characteristics::TypeAndShape::Characterize(*expr, context)};
assert(type && "unknown type");
const auto *derived{evaluate::GetDerivedTypeSpec(type.value().type())};
if (derived && derived->IsVectorType()) {
for (const auto &pair : derived->parameters()) {
if (pair.first == "element_kind") {
auto vecElemKind{Fortran::evaluate::ToInt64(pair.second.GetExplicit())
.value_or(0)};
auto numElem{vecElemKind == 0 ? 0 : (16 / vecElemKind)};
return CheckArgumentIsConstantExprInRange(
actuals, 1, 0, numElem - 1, messages);
}
}
} else
assert(false && "vector type is expected");
}
return false;
}

Expand Down
113 changes: 109 additions & 4 deletions flang/module/__ppc_intrinsics.f90
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,30 @@ elemental vector(real(VKIND1)) function elem_func_vr##VKIND1##vr##VKIND2(arg1);
end function ;
#define ELEM_FUNC_VRVR(VKIND) ELEM_FUNC_VRVR_2(VKIND, VKIND)

! vector(i) function f(i)
#define ELEM_FUNC_VII_2(RKIND, VKIND) \
elemental vector(integer(RKIND)) function elem_func_vi##RKIND##i##VKIND(arg1); \
integer(VKIND), intent(in) :: arg1; \
end function ;
#define ELEM_FUNC_VII(VKIND) ELEM_FUNC_VII_2(VKIND, VKIND)

! vector(r) function f(r)
#define ELEM_FUNC_VRR(VKIND) \
elemental vector(real(VKIND)) function elem_func_vr##VKIND##r##VKIND(arg1); \
real(VKIND), intent(in) :: arg1; \
end function ;

ELEM_FUNC_VIVI(1) ELEM_FUNC_VIVI(2) ELEM_FUNC_VIVI(4) ELEM_FUNC_VIVI(8)
ELEM_FUNC_VUVU(1)
ELEM_FUNC_VRVR_2(4,8) ELEM_FUNC_VRVR_2(8,4)
ELEM_FUNC_VRVR(4) ELEM_FUNC_VRVR(8)
ELEM_FUNC_VII_2(4,1) ELEM_FUNC_VII_2(4,2) ELEM_FUNC_VII_2(4,8)
ELEM_FUNC_VII(1) ELEM_FUNC_VII(2) ELEM_FUNC_VII(4) ELEM_FUNC_VII(8)
ELEM_FUNC_VRR(4) ELEM_FUNC_VRR(8)

#undef ELEM_FUNC_VRR
#undef ELEM_FUNC_VII
#undef ELEM_FUNC_VII_2
#undef ELEM_FUNC_VRVR
#undef ELEM_FUNC_VRVR_2
#undef ELEM_FUNC_VUVU
Expand Down Expand Up @@ -150,6 +169,30 @@ elemental real(VKIND) function elem_func_r##VKIND##vr##VKIND##i(arg1, arg2); \
!dir$ ignore_tkr(k) arg2; \
end function ;

! vector(i) function f(vector(i), i)
#define ELEM_FUNC_VIVII0(VKIND) \
elemental vector(integer(VKIND)) function elem_func_vi##VKIND##vi##VKIND##i0(arg1, arg2); \
vector(integer(VKIND)), intent(in) :: arg1; \
integer(8), intent(in) :: arg2; \
!dir$ ignore_tkr(k) arg2; \
end function ;

! vector(u) function f(vector(u), i)
#define ELEM_FUNC_VUVUI0(VKIND) \
elemental vector(unsigned(VKIND)) function elem_func_vu##VKIND##vu##VKIND##i0(arg1, arg2); \
vector(unsigned(VKIND)), intent(in) :: arg1; \
integer(8), intent(in) :: arg2; \
!dir$ ignore_tkr(k) arg2; \
end function ;

! vector(r) function f(vector(r), i)
#define ELEM_FUNC_VRVRI0(VKIND) \
elemental vector(real(VKIND)) function elem_func_vr##VKIND##vr##VKIND##i0(arg1, arg2); \
vector(real(VKIND)), intent(in) :: arg1; \
integer(8), intent(in) :: arg2; \
!dir$ ignore_tkr(k) arg2; \
end function ;

! The following macros are specific for the vec_convert(v, mold) intrinsics as
! the argument keywords are different from the other vector intrinsics.
!
Expand Down Expand Up @@ -203,10 +246,16 @@ pure vector(real(VKIND)) function func_vec_convert_vr##VKIND##vi##vr##VKIND(v, m
ELEM_FUNC_IVRVR(4,4) ELEM_FUNC_IVRVR(4,8)
ELEM_FUNC_VRVII(4) ELEM_FUNC_VRVII(8)
ELEM_FUNC_VRVUI(4) ELEM_FUNC_VRVUI(8)

ELEM_FUNC_VIVII0(1) ELEM_FUNC_VIVII0(2) ELEM_FUNC_VIVII0(4) ELEM_FUNC_VIVII0(8)
ELEM_FUNC_VUVUI0(1) ELEM_FUNC_VUVUI0(2) ELEM_FUNC_VUVUI0(4) ELEM_FUNC_VUVUI0(8)
ELEM_FUNC_VRVRI0(4) ELEM_FUNC_VRVRI0(8)

#undef FUNC_VEC_CONVERT_VRVIVR
#undef FUNC_VEC_CONVERT_VUVIVU
#undef FUNC_VEC_CONVERT_VIVIVI
#undef ELEM_FUNC_VRVRI0
#undef ELEM_FUNC_VUVUI0
#undef ELEM_FUNC_VIVII0
#undef ELEM_FUNC_RVRI
#undef ELEM_FUNC_VRVUI
#undef ELEM_FUNC_IVII
Expand Down Expand Up @@ -618,13 +667,16 @@ end function func_r8r8i
end interface mtfsfi
public :: mtfsfi

!-------------------------
! vector function(vector)
!-------------------------
!-----------------------------
! vector function(vector/i/r)
!-----------------------------
#define VI_VI(NAME, VKIND) __ppc_##NAME##_vi##VKIND##vi##VKIND
#define VU_VU(NAME, VKIND) __ppc_##NAME##_vu##VKIND##vu##VKIND
#define VR_VR_2(NAME, VKIND1, VKIND2) __ppc_##NAME##_vr##VKIND1##vr##VKIND2
#define VR_VR(NAME, VKIND) VR_VR_2(NAME, VKIND, VKIND)
#define VI_I_2(NAME, RKIND, VKIND) __ppc_##NAME##_vi##RKIND##i##VKIND
#define VI_I(NAME, VKIND) VI_I_2(NAME, VKIND, VKIND)
#define VR_R(NAME, VKIND) __ppc_##NAME##_vr##VKIND##r##VKIND

#define VEC_VI_VI(NAME, VKIND) \
procedure(elem_func_vi##VKIND##vi##VKIND) :: VI_VI(NAME, VKIND);
Expand All @@ -633,6 +685,11 @@ end function func_r8r8i
#define VEC_VR_VR_2(NAME, VKIND1, VKIND2) \
procedure(elem_func_vr##VKIND1##vr##VKIND2) :: VR_VR_2(NAME, VKIND1, VKIND2);
#define VEC_VR_VR(NAME, VKIND) VEC_VR_VR_2(NAME, VKIND, VKIND)
#define VEC_VI_I_2(NAME, RKIND, VKIND) \
procedure(elem_func_vi##RKIND##i##VKIND) :: VI_I_2(NAME, RKIND, VKIND);
#define VEC_VI_I(NAME, VKIND) VEC_VI_I_2(NAME, VKIND, VKIND)
#define VEC_VR_R(NAME, VKIND) \
procedure(elem_func_vr##VKIND##r##VKIND) :: VR_R(NAME, VKIND);

! vec_abs
VEC_VI_VI(vec_abs,1) VEC_VI_VI(vec_abs,2) VEC_VI_VI(vec_abs,4) VEC_VI_VI(vec_abs,8)
Expand Down Expand Up @@ -664,10 +721,32 @@ end function func_r8r8i
end interface
public vec_cvspbf16

! vec_splats
VEC_VI_I(vec_splats,1) VEC_VI_I(vec_splats,2) VEC_VI_I(vec_splats,4) VEC_VI_I(vec_splats,8)
VEC_VR_R(vec_splats,4) VEC_VR_R(vec_splats,8)
interface vec_splats
procedure :: VI_I(vec_splats,1), VI_I(vec_splats,2), VI_I(vec_splats,4), VI_I(vec_splats,8)
procedure :: VR_R(vec_splats,4), VR_R(vec_splats,8)
end interface vec_splats
public :: vec_splats

! vec_splat_32
VEC_VI_I_2(vec_splat_s32_,4,1) VEC_VI_I_2(vec_splat_s32_,4,2) VEC_VI_I_2(vec_splat_s32_,4,4) VEC_VI_I_2(vec_splat_s32_,4,8)
interface vec_splat_s32
procedure :: VI_I_2(vec_splat_s32_,4,1), VI_I_2(vec_splat_s32_,4,2), VI_I_2(vec_splat_s32_,4,4), VI_I_2(vec_splat_s32_,4,8)
end interface vec_splat_s32
public :: vec_splat_s32

#undef VEC_VR_R
#undef VEC_VI_I
#undef VEC_VI_I_2
#undef VEC_VR_VR
#undef VEC_VR_VR_2
#undef VEC_VU_VU
#undef VEC_VI_VI
#undef VR_R
#undef VI_I
#undef VI_I_2
#undef VR_VR
#undef VR_VR_2
#undef VU_VU
Expand Down Expand Up @@ -1220,11 +1299,20 @@ end function func_r8r8i
! the `ignore_tkr' directive.
#define VR_VI_I(NAME, VKIND) __ppc_##NAME##_vr##VKIND##vi##VKIND##i0
#define VR_VU_I(NAME, VKIND) __ppc_##NAME##_vr##VKIND##vu##VKIND##i0
#define VI_VI_I0(NAME, VKIND) __ppc_##NAME##_vi##VKIND##vi##VKIND##i0
#define VU_VU_I0(NAME, VKIND) __ppc_##NAME##_vu##VKIND##vu##VKIND##i0
#define VR_VR_I0(NAME, VKIND) __ppc_##NAME##_vr##VKIND##vr##VKIND##i0

#define VEC_VR_VI_I(NAME, VKIND) \
procedure(elem_func_vr##VKIND##vi##VKIND##i) :: VR_VI_I(NAME, VKIND);
#define VEC_VR_VU_I(NAME, VKIND) \
procedure(elem_func_vr##VKIND##vu##VKIND##i) :: VR_VU_I(NAME, VKIND);
#define VEC_VI_VI_I0(NAME, VKIND) \
procedure(elem_func_vi##VKIND##vi##VKIND##i0) :: VI_VI_I0(NAME, VKIND);
#define VEC_VU_VU_I0(NAME, VKIND) \
procedure(elem_func_vu##VKIND##vu##VKIND##i0) :: VU_VU_I0(NAME, VKIND);
#define VEC_VR_VR_I0(NAME, VKIND) \
procedure(elem_func_vr##VKIND##vr##VKIND##i0) :: VR_VR_I0(NAME, VKIND);

! vec_ctf
VEC_VR_VI_I(vec_ctf,4) VEC_VR_VI_I(vec_ctf,8)
Expand All @@ -1235,8 +1323,25 @@ end function func_r8r8i
end interface vec_ctf
public :: vec_ctf

! vec_splat
VEC_VI_VI_I0(vec_splat,1) VEC_VI_VI_I0(vec_splat,2) VEC_VI_VI_I0(vec_splat,4) VEC_VI_VI_I0(vec_splat,8)
VEC_VU_VU_I0(vec_splat,1) VEC_VU_VU_I0(vec_splat,2) VEC_VU_VU_I0(vec_splat,4) VEC_VU_VU_I0(vec_splat,8)
VEC_VR_VR_I0(vec_splat,4) VEC_VR_VR_I0(vec_splat,8)
interface vec_splat
procedure :: VI_VI_I0(vec_splat,1), VI_VI_I0(vec_splat,2), VI_VI_I0(vec_splat,4), VI_VI_I0(vec_splat,8)
procedure :: VU_VU_I0(vec_splat,1), VU_VU_I0(vec_splat,2), VU_VU_I0(vec_splat,4), VU_VU_I0(vec_splat,8)
procedure :: VR_VR_I0(vec_splat,4), VR_VR_I0(vec_splat,8)
end interface vec_splat
public :: vec_splat

#undef VEC_VR_VR_I0
#undef VEC_VU_VU_I0
#undef VEC_VI_VI_I0
#undef VEC_VR_VU_I
#undef VEC_VR_VI_I
#undef VR_VR_I0
#undef VU_VU_I0
#undef VI_VI_I0
#undef VR_VU_I
#undef VR_VI_I

Expand Down
49 changes: 49 additions & 0 deletions flang/test/Lower/PowerPC/ppc-vec-splat-elem-order.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
! RUN: bbc -emit-fir %s -fno-ppc-native-vector-element-order -o - | FileCheck --check-prefixes="FIR" %s
! RUN: %flang_fc1 -emit-llvm %s -fno-ppc-native-vector-element-order -o - | FileCheck --check-prefixes="LLVMIR" %s
! REQUIRES: target=powerpc{{.*}}

! CHECK-LABEL: vec_splat_testf32i64
subroutine vec_splat_testf32i64(x)
vector(real(4)) :: x, y
y = vec_splat(x, 0_8)
! FIR: %[[x:.*]] = fir.load %arg0 : !fir.ref<!fir.vector<4:f32>>
! FIR: %[[idx:.*]] = arith.constant 0 : i64
! FIR: %[[vx:.*]] = fir.convert %[[x]] : (!fir.vector<4:f32>) -> vector<4xf32>
! FIR: %[[c:.*]] = arith.constant 4 : i64
! FIR: %[[u:.*]] = llvm.urem %[[idx]], %[[c]] : i64
! FIR: %[[c2:.*]] = arith.constant 3 : i64
! FIR: %[[sub:.*]] = llvm.sub %[[c2]], %[[u]] : i64
! FIR: %[[ele:.*]] = vector.extractelement %[[vx]][%[[sub]] : i64] : vector<4xf32>
! FIR: %[[vy:.*]] = vector.splat %[[ele]] : vector<4xf32>
! FIR: %[[y:.*]] = fir.convert %[[vy]] : (vector<4xf32>) -> !fir.vector<4:f32>
! FIR: fir.store %[[y]] to %{{[0-9]}} : !fir.ref<!fir.vector<4:f32>>

! LLVMIR: %[[x:.*]] = load <4 x float>, ptr %{{[0-9]}}, align 16
! LLVMIR: %[[ele:.*]] = extractelement <4 x float> %[[x]], i64 3
! LLVMIR: %[[ins:.*]] = insertelement <4 x float> undef, float %[[ele]], i32 0
! LLVMIR: %[[y:.*]] = shufflevector <4 x float> %[[ins]], <4 x float> undef, <4 x i32> zeroinitializer
! LLVMIR: store <4 x float> %[[y]], ptr %{{[0-9]}}, align 16
end subroutine vec_splat_testf32i64

! CHECK-LABEL: vec_splat_testu8i16
subroutine vec_splat_testu8i16(x)
vector(unsigned(1)) :: x, y
y = vec_splat(x, 0_2)
! FIR: %[[x:.*]] = fir.load %arg0 : !fir.ref<!fir.vector<16:ui8>>
! FIR: %[[idx:.*]] = arith.constant 0 : i16
! FIR: %[[vx:.*]] = fir.convert %[[x]] : (!fir.vector<16:ui8>) -> vector<16xi8>
! FIR: %[[c:.*]] = arith.constant 16 : i16
! FIR: %[[u:.*]] = llvm.urem %[[idx]], %[[c]] : i16
! FIR: %[[c2:.*]] = arith.constant 15 : i16
! FIR: %[[sub:.*]] = llvm.sub %[[c2]], %[[u]] : i16
! FIR: %[[ele:.*]] = vector.extractelement %[[vx]][%[[sub]] : i16] : vector<16xi8>
! FIR: %[[vy:.*]] = vector.splat %[[ele]] : vector<16xi8>
! FIR: %[[y:.*]] = fir.convert %[[vy]] : (vector<16xi8>) -> !fir.vector<16:ui8>
! FIR: fir.store %[[y]] to %{{[0-9]}} : !fir.ref<!fir.vector<16:ui8>>

! LLVMIR: %[[x:.*]] = load <16 x i8>, ptr %{{[0-9]}}, align 16
! LLVMIR: %[[ele:.*]] = extractelement <16 x i8> %[[x]], i16 15
! LLVMIR: %[[ins:.*]] = insertelement <16 x i8> undef, i8 %[[ele]], i32 0
! LLVMIR: %[[y:.*]] = shufflevector <16 x i8> %[[ins]], <16 x i8> undef, <16 x i32> zeroinitializer
! LLVMIR: store <16 x i8> %[[y]], ptr %{{[0-9]}}, align 16
end subroutine vec_splat_testu8i16

0 comments on commit f50eaea

Please sign in to comment.