Skip to content

Commit

Permalink
[flang] Add PowerPC vec_sl, vec_sld, vec_sldw, vec_sll, vec_slo, vec_…
Browse files Browse the repository at this point in the history
…srl and vec_sro intrinsic

Co-authored-by: pscoro

Differential Revision: https://reviews.llvm.org/D154563
  • Loading branch information
kkwli committed Jul 12, 2023
1 parent 18ad80e commit 10124b3
Show file tree
Hide file tree
Showing 6 changed files with 4,918 additions and 8 deletions.
25 changes: 24 additions & 1 deletion flang/include/flang/Optimizer/Builder/PPCIntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,26 @@ namespace fir {

/// Enums used to templatize vector intrinsic function generators. Enum does
/// not contain every vector intrinsic, only intrinsics that share generators.
enum class VecOp { Add, And, Anyge, Cmpge, Cmpgt, Cmple, Cmplt, Mul, Sub, Xor };
enum class VecOp {
Add,
And,
Anyge,
Cmpge,
Cmpgt,
Cmple,
Cmplt,
Mul,
Sl,
Sld,
Sldw,
Sll,
Slo,
Sr,
Srl,
Sro,
Sub,
Xor
};

// Wrapper struct to encapsulate information for a vector type. Preserves
// sign of eleTy if eleTy is signed/unsigned integer. Helps with vector type
Expand Down Expand Up @@ -96,6 +115,10 @@ struct PPCIntrinsicLibrary : IntrinsicLibrary {
template <VecOp>
fir::ExtendedValue genVecAnyCompare(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args);

template <VecOp>
fir::ExtendedValue genVecShift(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
};

const IntrinsicHandler *findPPCIntrinsicHandler(llvm::StringRef name);
Expand Down
170 changes: 170 additions & 0 deletions flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,46 @@ static constexpr IntrinsicHandler ppcHandlers[]{
&PI::genVecAddAndMulSubXor<VecOp::Mul>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sl",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sl>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sld",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sld>),
{{{"arg1", asValue}, {"arg2", asValue}, {"arg3", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sldw",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sldw>),
{{{"arg1", asValue}, {"arg2", asValue}, {"arg3", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sll",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sll>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_slo",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Slo>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sr",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sr>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_srl",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Srl>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sro",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecShift<VecOp::Sro>),
{{{"arg1", asValue}, {"arg2", asValue}}},
/*isElemental=*/true},
{"__ppc_vec_sub",
static_cast<IntrinsicLibrary::ExtendedGenerator>(
&PI::genVecAddAndMulSubXor<VecOp::Sub>),
Expand Down Expand Up @@ -641,4 +681,134 @@ PPCIntrinsicLibrary::genVecCmp(mlir::Type resultType,
return res;
}

// VEC_SL, VEC_SLD, VEC_SLDW, VEC_SLL, VEC_SLO, VEC_SR, VEC_SRL, VEC_SRO
template <VecOp vop>
fir::ExtendedValue
PPCIntrinsicLibrary::genVecShift(mlir::Type resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
auto context{builder.getContext()};
auto argBases{getBasesForArgs(args)};
auto argTypes{getTypesForArgs(argBases)};

llvm::SmallVector<VecTypeInfo, 2> vecTyInfoArgs;
vecTyInfoArgs.push_back(getVecTypeFromFir(argBases[0]));
vecTyInfoArgs.push_back(getVecTypeFromFir(argBases[1]));

// Convert the first two arguments to MLIR vectors
llvm::SmallVector<mlir::Type, 2> mlirTyArgs;
mlirTyArgs.push_back(vecTyInfoArgs[0].toMlirVectorType(context));
mlirTyArgs.push_back(vecTyInfoArgs[1].toMlirVectorType(context));

llvm::SmallVector<mlir::Value, 2> mlirVecArgs;
mlirVecArgs.push_back(builder.createConvert(loc, mlirTyArgs[0], argBases[0]));
mlirVecArgs.push_back(builder.createConvert(loc, mlirTyArgs[1], argBases[1]));

mlir::Value shftRes{nullptr};

if (vop == VecOp::Sl || vop == VecOp::Sr) {
assert(args.size() == 2);
// Construct the mask
auto width{
mlir::dyn_cast<mlir::IntegerType>(vecTyInfoArgs[1].eleTy).getWidth()};
auto vecVal{builder.createIntegerConstant(
loc, getConvertedElementType(context, vecTyInfoArgs[0].eleTy), width)};
auto mask{
builder.create<mlir::vector::BroadcastOp>(loc, mlirTyArgs[1], vecVal)};
auto shft{builder.create<mlir::arith::RemUIOp>(loc, mlirVecArgs[1], mask)};

mlir::Value res{nullptr};
if (vop == VecOp::Sr)
res = builder.create<mlir::arith::ShRUIOp>(loc, mlirVecArgs[0], shft);
else if (vop == VecOp::Sl)
res = builder.create<mlir::arith::ShLIOp>(loc, mlirVecArgs[0], shft);

shftRes = builder.createConvert(loc, argTypes[0], res);
} else if (vop == VecOp::Sll || vop == VecOp::Slo || vop == VecOp::Srl ||
vop == VecOp::Sro) {
assert(args.size() == 2);

// Bitcast to vector<4xi32>
auto bcVecTy{mlir::VectorType::get(4, builder.getIntegerType(32))};
if (mlirTyArgs[0] != bcVecTy)
mlirVecArgs[0] =
builder.create<mlir::vector::BitCastOp>(loc, bcVecTy, mlirVecArgs[0]);
if (mlirTyArgs[1] != bcVecTy)
mlirVecArgs[1] =
builder.create<mlir::vector::BitCastOp>(loc, bcVecTy, mlirVecArgs[1]);

llvm::StringRef funcName;
switch (vop) {
case VecOp::Srl:
funcName = "llvm.ppc.altivec.vsr";
break;
case VecOp::Sro:
funcName = "llvm.ppc.altivec.vsro";
break;
case VecOp::Sll:
funcName = "llvm.ppc.altivec.vsl";
break;
case VecOp::Slo:
funcName = "llvm.ppc.altivec.vslo";
break;
default:
llvm_unreachable("unknown vector shift operation");
}
auto funcTy{genFuncType<Ty::IntegerVector<4>, Ty::IntegerVector<4>,
Ty::IntegerVector<4>>(context, builder)};
mlir::func::FuncOp funcOp{builder.addNamedFunction(loc, funcName, funcTy)};
auto callOp{builder.create<fir::CallOp>(loc, funcOp, mlirVecArgs)};

// If the result vector type is different from the original type, need
// to convert to mlir vector, bitcast and then convert back to fir vector.
if (callOp.getResult(0).getType() != argTypes[0]) {
auto res = builder.createConvert(loc, bcVecTy, callOp.getResult(0));
res = builder.create<mlir::vector::BitCastOp>(loc, mlirTyArgs[0], res);
shftRes = builder.createConvert(loc, argTypes[0], res);
} else {
shftRes = callOp.getResult(0);
}
} else if (vop == VecOp::Sld || vop == VecOp::Sldw) {
assert(args.size() == 3);
auto constIntOp =
mlir::dyn_cast<mlir::arith::ConstantOp>(argBases[2].getDefiningOp())
.getValue()
.dyn_cast_or_null<mlir::IntegerAttr>();
assert(constIntOp && "expected integer constant argument");

// Bitcast to vector<16xi8>
auto vi8Ty{mlir::VectorType::get(16, builder.getIntegerType(8))};
if (mlirTyArgs[0] != vi8Ty) {
mlirVecArgs[0] =
builder.create<mlir::LLVM::BitcastOp>(loc, vi8Ty, mlirVecArgs[0])
.getResult();
mlirVecArgs[1] =
builder.create<mlir::LLVM::BitcastOp>(loc, vi8Ty, mlirVecArgs[1])
.getResult();
}

// Construct the mask for shuffling
auto shiftVal{constIntOp.getInt()};
if (vop == VecOp::Sldw)
shiftVal = shiftVal << 2;
shiftVal &= 0xF;
llvm::SmallVector<int64_t, 16> mask;
for (int i = 16; i < 32; ++i)
mask.push_back(i - shiftVal);

// Shuffle with mask
shftRes = builder.create<mlir::vector::ShuffleOp>(loc, mlirVecArgs[1],
mlirVecArgs[0], mask);

// Bitcast to the original type
if (shftRes.getType() != mlirTyArgs[0])
shftRes =
builder.create<mlir::LLVM::BitcastOp>(loc, mlirTyArgs[0], shftRes);

return builder.createConvert(loc, resultType, shftRes);
} else
llvm_unreachable("Invalid vector operation for generator");

return shftRes;
}

} // namespace fir
6 changes: 6 additions & 0 deletions flang/lib/Semantics/check-call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,12 @@ bool CheckPPCIntrinsic(const Symbol &generic, const Symbol &specific,
return CheckArgumentIsConstantExprInRange(actuals, 0, 0, 7, messages) &&
CheckArgumentIsConstantExprInRange(actuals, 1, 0, 15, messages);
}
if (specific.name().ToString().compare(0, 14, "__ppc_vec_sld_") == 0) {
return CheckArgumentIsConstantExprInRange(actuals, 2, 0, 15, messages);
}
if (specific.name().ToString().compare(0, 15, "__ppc_vec_sldw_") == 0) {
return CheckArgumentIsConstantExprInRange(actuals, 2, 0, 3, messages);
}
return false;
}

Expand Down

0 comments on commit 10124b3

Please sign in to comment.