217 changes: 157 additions & 60 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ static bool isStaticallyPresent(const fir::ExtendedValue &exv) {
return !isStaticallyAbsent(exv);
}

/// Names of IEEE module procedures that are not implemented yet.
/// Each array is passed as the `const char *` template argument of
/// genModuleProcTODO in the handlers table, so that calling one of these
/// procedures reports the procedure name through crashOnMissingIntrinsic.
static constexpr char ieee_int[] = "ieee_int";
static constexpr char ieee_get_underflow_mode[] = "ieee_get_underflow_mode";
static constexpr char ieee_next_after[] = "ieee_next_after";
static constexpr char ieee_next_down[] = "ieee_next_down";
static constexpr char ieee_next_up[] = "ieee_next_up";
static constexpr char ieee_real[] = "ieee_real";
static constexpr char ieee_rem[] = "ieee_rem";
static constexpr char ieee_rint[] = "ieee_rint";
static constexpr char ieee_set_underflow_mode[] = "ieee_set_underflow_mode";

using I = IntrinsicLibrary;

/// Flag to indicate that an intrinsic argument has to be handled as
Expand Down Expand Up @@ -321,6 +332,8 @@ static constexpr IntrinsicHandler handlers[]{
{"radix", asValue, handleDynamicOptional}}},
/*isElemental=*/false},
{"ieee_get_status", &I::genIeeeGetOrSetStatus</*isGet=*/true>},
{"ieee_get_underflow_mode", &I::genModuleProcTODO<ieee_get_underflow_mode>},
{"ieee_int", &I::genModuleProcTODO<ieee_int>},
{"ieee_is_finite", &I::genIeeeIsFinite},
{"ieee_is_nan", &I::genIeeeIsNan},
{"ieee_is_negative", &I::genIeeeIsNegative},
Expand All @@ -342,12 +355,18 @@ static constexpr IntrinsicHandler handlers[]{
&I::genIeeeMaxMin</*isMax=*/false, /*isNum=*/true, /*isMag=*/false>},
{"ieee_min_num_mag",
&I::genIeeeMaxMin</*isMax=*/false, /*isNum=*/true, /*isMag=*/true>},
{"ieee_next_after", &I::genModuleProcTODO<ieee_next_after>},
{"ieee_next_down", &I::genModuleProcTODO<ieee_next_down>},
{"ieee_next_up", &I::genModuleProcTODO<ieee_next_up>},
{"ieee_quiet_eq", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::OEQ>},
{"ieee_quiet_ge", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::OGE>},
{"ieee_quiet_gt", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::OGT>},
{"ieee_quiet_le", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::OLE>},
{"ieee_quiet_lt", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::OLT>},
{"ieee_quiet_ne", &I::genIeeeQuietCompare<mlir::arith::CmpFPredicate::UNE>},
{"ieee_real", &I::genModuleProcTODO<ieee_real>},
{"ieee_rem", &I::genModuleProcTODO<ieee_rem>},
{"ieee_rint", &I::genModuleProcTODO<ieee_rint>},
{"ieee_round_eq", &I::genIeeeTypeCompare<mlir::arith::CmpIPredicate::eq>},
{"ieee_round_ne", &I::genIeeeTypeCompare<mlir::arith::CmpIPredicate::ne>},
{"ieee_set_flag", &I::genIeeeSetFlagOrHaltingMode</*isFlag=*/true>},
Expand All @@ -360,6 +379,7 @@ static constexpr IntrinsicHandler handlers[]{
{"radix", asValue, handleDynamicOptional}}},
/*isElemental=*/false},
{"ieee_set_status", &I::genIeeeGetOrSetStatus</*isGet=*/false>},
{"ieee_set_underflow_mode", &I::genModuleProcTODO<ieee_set_underflow_mode>},
{"ieee_signaling_eq",
&I::genIeeeSignalingCompare<mlir::arith::CmpFPredicate::OEQ>},
{"ieee_signaling_ge",
Expand Down Expand Up @@ -1493,17 +1513,11 @@ static_assert(mathOps.Verify() && "map must be sorted");
/// \p bestMatchDistance specifies the FunctionDistance between
/// the requested operation and the non-exact match.
static const MathOperation *
searchMathOperation(fir::FirOpBuilder &builder, llvm::StringRef name,
searchMathOperation(fir::FirOpBuilder &builder,
const IntrinsicHandlerEntry::RuntimeGeneratorRange &range,
mlir::FunctionType funcType,
const MathOperation **bestNearMatch,
FunctionDistance &bestMatchDistance) {
auto range = mathOps.equal_range(name);
auto mod = builder.getModule();

// Search ppcMathOps only if targetting PowerPC arch
if (fir::getTargetTriple(mod).isPPC() && range.first == range.second) {
range = checkPPCMathOperationsRange(name);
}
for (auto iter = range.first; iter != range.second && iter; ++iter) {
const auto &impl = *iter;
auto implType = impl.typeGenerator(builder.getContext(), builder);
Expand Down Expand Up @@ -1649,8 +1663,46 @@ llvm::StringRef genericName(llvm::StringRef specificName) {
return name.drop_back(name.size() - size);
}

/// Find the range of runtime math operations registered under \p name.
///
/// The common math operation table is searched first. The PowerPC table is
/// searched only when \p isPPCTarget is true and the common table has no
/// entry for \p name. Returns std::nullopt when no non-empty range is found.
std::optional<IntrinsicHandlerEntry::RuntimeGeneratorRange>
lookupRuntimeGenerator(llvm::StringRef name, bool isPPCTarget) {
  using Range = IntrinsicHandlerEntry::RuntimeGeneratorRange;

  auto commonRange = mathOps.equal_range(name);
  if (commonRange.first != commonRange.second)
    return std::make_optional<Range>(commonRange);

  // The PowerPC math operations are only relevant when targeting PowerPC.
  if (!isPPCTarget)
    return std::nullopt;

  auto ppcRange = checkPPCMathOperationsRange(name);
  if (ppcRange.first == ppcRange.second)
    return std::nullopt;
  return std::make_optional<Range>(ppcRange);
}

/// Resolve \p intrinsicName to the entry describing how to lower it.
///
/// The name is first reduced to its generic form. Lookup order is:
///   1. the common intrinsic handler table;
///   2. the PowerPC handler table, when the module targets PowerPC;
///   3. the math runtime, only when \p resultType is provided (subroutines
///      are expected to have a dedicated handler).
/// Returns std::nullopt when none of these provides an implementation.
std::optional<IntrinsicHandlerEntry>
lookupIntrinsicHandler(fir::FirOpBuilder &builder,
                       llvm::StringRef intrinsicName,
                       std::optional<mlir::Type> resultType) {
  const llvm::StringRef lookupName = genericName(intrinsicName);

  if (const IntrinsicHandler *commonHandler = findIntrinsicHandler(lookupName))
    return std::make_optional<IntrinsicHandlerEntry>(commonHandler);

  const bool targetsPPC = fir::getTargetTriple(builder.getModule()).isPPC();
  if (targetsPPC) {
    if (const IntrinsicHandler *ppcOnlyHandler =
            findPPCIntrinsicHandler(lookupName))
      return std::make_optional<IntrinsicHandlerEntry>(ppcOnlyHandler);
  }

  // Fall back to the math runtime, which only provides numerical elemental
  // functions; without a result type this is a subroutine call and there is
  // nothing left to try.
  if (resultType.has_value()) {
    if (auto runtimeRange = lookupRuntimeGenerator(lookupName, targetsPPC))
      return std::make_optional<IntrinsicHandlerEntry>(*runtimeRange);
  }
  return std::nullopt;
}

/// Generate a TODO error message for an as yet unimplemented intrinsic.
void crashOnMissingIntrinsic(mlir::Location loc, llvm::StringRef name) {
void crashOnMissingIntrinsic(mlir::Location loc,
llvm::StringRef intrinsicName) {
llvm::StringRef name = genericName(intrinsicName);
if (isIntrinsicModuleProcedure(name))
TODO(loc, "intrinsic module procedure: " + llvm::Twine(name));
else if (isCoarrayIntrinsic(name))
Expand Down Expand Up @@ -1782,46 +1834,33 @@ invokeHandler(IntrinsicLibrary::DualGenerator generator,
return std::invoke(generator, lib, resultType, args);
}

std::pair<fir::ExtendedValue, bool>
IntrinsicLibrary::genIntrinsicCall(llvm::StringRef specificName,
std::optional<mlir::Type> resultType,
llvm::ArrayRef<fir::ExtendedValue> args) {
llvm::StringRef name = genericName(specificName);
if (const IntrinsicHandler *handler = findIntrinsicHandler(name)) {
bool outline = handler->outline || outlineAllIntrinsics;
return {Fortran::common::visit(
[&](auto &generator) -> fir::ExtendedValue {
return invokeHandler(generator, *handler, resultType, args,
outline, *this);
},
handler->generator),
this->resultMustBeFreed};
}

// If targeting PowerPC, check PPC intrinsic handlers.
auto mod = builder.getModule();
if (fir::getTargetTriple(mod).isPPC()) {
if (const IntrinsicHandler *ppcHandler = findPPCIntrinsicHandler(name)) {
bool outline = ppcHandler->outline || outlineAllIntrinsics;
return {Fortran::common::visit(
[&](auto &generator) -> fir::ExtendedValue {
return invokeHandler(generator, *ppcHandler, resultType,
args, outline, *this);
},
ppcHandler->generator),
this->resultMustBeFreed};
}
}

// Try the runtime if no special handler was defined for the
// intrinsic being called. Maths runtime only has numerical elemental.
// No optional arguments are expected at this point, the code will
// crash if it gets absent optional.
/// Lower an intrinsic call through a dedicated IntrinsicHandler.
///
/// \p handler must be non-null. Its generator is dispatched through
/// invokeHandler, outlined when either the handler requests it or
/// outlineAllIntrinsics is set.
/// \returns the lowered value paired with lib.resultMustBeFreed.
static std::pair<fir::ExtendedValue, bool> genIntrinsicCallHelper(
const IntrinsicHandler *handler, std::optional<mlir::Type> resultType,
llvm::ArrayRef<fir::ExtendedValue> args, IntrinsicLibrary &lib) {
assert(handler && "must be set");
bool outline = handler->outline || outlineAllIntrinsics;
// The elements of a braced-init-list are evaluated left to right, so
// lib.resultMustBeFreed is read only after the generator has run.
// NOTE(review): presumably generators update that flag; keep this order.
return {Fortran::common::visit(
[&](auto &generator) -> fir::ExtendedValue {
return invokeHandler(generator, *handler, resultType, args,
outline, lib);
},
handler->generator),
lib.resultMustBeFreed};
}

if (!resultType)
// Subroutine should have a handler, they are likely missing for now.
crashOnMissingIntrinsic(loc, name);
static IntrinsicLibrary::RuntimeCallGenerator getRuntimeCallGeneratorHelper(
const IntrinsicHandlerEntry::RuntimeGeneratorRange &, mlir::FunctionType,
fir::FirOpBuilder &, mlir::Location);

static std::pair<fir::ExtendedValue, bool> genIntrinsicCallHelper(
const IntrinsicHandlerEntry::RuntimeGeneratorRange &range,
std::optional<mlir::Type> resultType,
llvm::ArrayRef<fir::ExtendedValue> args, IntrinsicLibrary &lib) {
assert(resultType.has_value() && "RuntimeGenerator are for functions only");
assert(range.first != nullptr && "range should not be empty");
fir::FirOpBuilder &builder = lib.builder;
mlir::Location loc = lib.loc;
llvm::StringRef name = range.first->key;
// FIXME: using toValue to get the type won't work with array arguments.
llvm::SmallVector<mlir::Value> mlirArgs;
for (const fir::ExtendedValue &extendedVal : args) {
Expand All @@ -1836,10 +1875,39 @@ IntrinsicLibrary::genIntrinsicCall(llvm::StringRef specificName,
getFunctionType(*resultType, mlirArgs, builder);

IntrinsicLibrary::RuntimeCallGenerator runtimeCallGenerator =
getRuntimeCallGenerator(name, soughtFuncType);
return {genElementalCall(runtimeCallGenerator, name, *resultType, args,
/*outline=*/outlineAllIntrinsics),
resultMustBeFreed};
getRuntimeCallGeneratorHelper(range, soughtFuncType, builder, loc);
return {lib.genElementalCall(runtimeCallGenerator, name, *resultType, args,
/*outline=*/outlineAllIntrinsics),
lib.resultMustBeFreed};
}

/// Lower a call to an intrinsic that was already resolved to \p intrinsic.
///
/// A temporary IntrinsicLibrary is built from \p builder, \p loc and
/// \p converter, and the call is forwarded to the genIntrinsicCallHelper
/// overload matching the alternative held by the entry.
/// \returns the lowered value and the flag reported by the helper.
std::pair<fir::ExtendedValue, bool>
genIntrinsicCall(fir::FirOpBuilder &builder, mlir::Location loc,
                 const IntrinsicHandlerEntry &intrinsic,
                 std::optional<mlir::Type> resultType,
                 llvm::ArrayRef<fir::ExtendedValue> args,
                 Fortran::lower::AbstractConverter *converter) {
  IntrinsicLibrary library{builder, loc, converter};
  const auto dispatch = [&](auto handler) {
    return genIntrinsicCallHelper(handler, resultType, args, library);
  };
  return std::visit(dispatch, intrinsic.entry);
}

/// Lower a call to the intrinsic named \p specificName with this library.
///
/// The intrinsic is resolved with lookupIntrinsicHandler; when nothing is
/// found, crashOnMissingIntrinsic is called to report it. Otherwise the call
/// is forwarded to the genIntrinsicCallHelper overload matching the
/// alternative held by the resolved entry.
std::pair<fir::ExtendedValue, bool>
IntrinsicLibrary::genIntrinsicCall(llvm::StringRef specificName,
                                   std::optional<mlir::Type> resultType,
                                   llvm::ArrayRef<fir::ExtendedValue> args) {
  auto resolved = lookupIntrinsicHandler(builder, specificName, resultType);
  if (!resolved)
    crashOnMissingIntrinsic(loc, specificName);

  const auto dispatch = [&](auto handler) {
    return genIntrinsicCallHelper(handler, resultType, args, *this);
  };
  return std::visit(dispatch, resolved->entry);
}

mlir::Value
Expand Down Expand Up @@ -2082,19 +2150,19 @@ fir::ExtendedValue IntrinsicLibrary::outlineInExtendedWrapper(
return mlir::Value{};
}

IntrinsicLibrary::RuntimeCallGenerator
IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
mlir::FunctionType soughtFuncType) {
mlir::FunctionType actualFuncType;
const MathOperation *mathOp = nullptr;

static IntrinsicLibrary::RuntimeCallGenerator getRuntimeCallGeneratorHelper(
const IntrinsicHandlerEntry::RuntimeGeneratorRange &range,
mlir::FunctionType soughtFuncType, fir::FirOpBuilder &builder,
mlir::Location loc) {
assert(range.first != nullptr && "range should not be empty");
llvm::StringRef name = range.first->key;
// Look for a dedicated math operation generator, which
// normally produces a single MLIR operation implementing
// the math operation.
const MathOperation *bestNearMatch = nullptr;
FunctionDistance bestMatchDistance;
mathOp = searchMathOperation(builder, name, soughtFuncType, &bestNearMatch,
bestMatchDistance);
const MathOperation *mathOp = searchMathOperation(
builder, range, soughtFuncType, &bestNearMatch, bestMatchDistance);
if (!mathOp && bestNearMatch) {
// Use the best near match, optionally issuing an error,
// if types conversions cause precision loss.
Expand All @@ -2109,7 +2177,8 @@ IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
crashOnMissingIntrinsic(loc, nameAndType);
}

actualFuncType = mathOp->typeGenerator(builder.getContext(), builder);
mlir::FunctionType actualFuncType =
mathOp->typeGenerator(builder.getContext(), builder);

assert(actualFuncType.getNumResults() == soughtFuncType.getNumResults() &&
actualFuncType.getNumInputs() == soughtFuncType.getNumInputs() &&
Expand All @@ -2128,6 +2197,17 @@ IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
};
}

/// Build the runtime call generator for the math operation \p name with the
/// signature \p soughtFuncType.
///
/// The candidate operations come from lookupRuntimeGenerator, which also
/// considers the PowerPC operations when the module targets PowerPC. When no
/// candidate exists, crashOnMissingIntrinsic is called to report it.
IntrinsicLibrary::RuntimeCallGenerator
IntrinsicLibrary::getRuntimeCallGenerator(llvm::StringRef name,
                                          mlir::FunctionType soughtFuncType) {
  const bool targetsPPC = fir::getTargetTriple(builder.getModule()).isPPC();
  auto candidates = lookupRuntimeGenerator(name, targetsPPC);
  if (!candidates)
    crashOnMissingIntrinsic(loc, name);
  return getRuntimeCallGeneratorHelper(*candidates, soughtFuncType, builder,
                                       loc);
}

mlir::SymbolRefAttr IntrinsicLibrary::getUnrestrictedIntrinsicSymbolRefAttr(
llvm::StringRef name, mlir::FunctionType signature) {
// Unrestricted intrinsics signature follows implicit rules: argument
Expand Down Expand Up @@ -2214,6 +2294,12 @@ mlir::Value IntrinsicLibrary::genConversion(mlir::Type resultType,
return builder.convertWithSemantics(loc, resultType, args[0]);
}

/// Placeholder handler for intrinsic module procedures that are not
/// implemented yet.
///
/// \tparam intrinsicName name of the procedure, forwarded to
/// crashOnMissingIntrinsic together with the current location.
/// \param args actual arguments of the call; not used.
template <const char *intrinsicName>
void IntrinsicLibrary::genModuleProcTODO(
llvm::ArrayRef<fir::ExtendedValue> args) {
crashOnMissingIntrinsic(loc, intrinsicName);
}

// ABORT
void IntrinsicLibrary::genAbort(llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 0);
Expand Down Expand Up @@ -7076,6 +7162,17 @@ getIntrinsicArgumentLowering(llvm::StringRef specificName) {
return nullptr;
}

/// Return the argument lowering rules attached to this entry, if any.
///
/// Only entries holding an IntrinsicHandler can carry rules. nullptr is
/// returned when the entry holds another alternative, or when the handler
/// only has default rules.
const IntrinsicArgumentLoweringRules *
IntrinsicHandlerEntry::getArgumentLoweringRules() const {
  const auto *handler = std::get_if<const IntrinsicHandler *>(&entry);
  if (!handler)
    return nullptr;
  assert(*handler);
  const auto &rules = (*handler)->argLoweringRules;
  return rules.hasDefaultRules() ? nullptr : &rules;
}

/// Return how argument \p argName should be lowered given the rules for the
/// intrinsic function.
fir::ArgLoweringRule
Expand Down
1 change: 1 addition & 0 deletions libc/config/darwin/arm/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atan2f
libc.src.math.atanf
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.copysign
libc.src.math.copysignf
libc.src.math.copysignl
Expand Down
1 change: 1 addition & 0 deletions libc/config/gpu/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atanf
libc.src.math.atanh
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.ceil
libc.src.math.ceilf
libc.src.math.copysign
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/aarch64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atan2f
libc.src.math.atanf
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.ceil
libc.src.math.ceilf
libc.src.math.ceill
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/arm/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atan2f
libc.src.math.atanf
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.ceil
libc.src.math.ceilf
libc.src.math.ceill
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/riscv/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atan2f
libc.src.math.atanf
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.ceil
libc.src.math.ceilf
libc.src.math.ceill
Expand Down
1 change: 1 addition & 0 deletions libc/config/linux/x86_64/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.canonicalize
libc.src.math.canonicalizef
libc.src.math.canonicalizel
libc.src.math.cbrtf
libc.src.math.ceil
libc.src.math.ceilf
libc.src.math.ceill
Expand Down
1 change: 1 addition & 0 deletions libc/config/windows/entrypoints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ set(TARGET_LIBM_ENTRYPOINTS
libc.src.math.atan2f
libc.src.math.atanf
libc.src.math.atanhf
libc.src.math.cbrtf
libc.src.math.copysign
libc.src.math.copysignf
libc.src.math.copysignl
Expand Down
2 changes: 1 addition & 1 deletion libc/docs/math/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ Higher Math Functions
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| atanpi | | | | | | 7.12.4.10 | F.10.1.10 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| cbrt | | | | | | 7.12.7.1 | F.10.4.1 |
| cbrt | |check| | | | | | 7.12.7.1 | F.10.4.1 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
| compoundn | | | | | | 7.12.7.2 | F.10.4.2 |
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
Expand Down
4 changes: 1 addition & 3 deletions libc/include/llvm-libc-macros/math-macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@
#define FP_LLOGBNAN LONG_MAX
#endif

#ifdef __FAST_MATH__
#if defined(__NVPTX__) || defined(__AMDGPU__) || defined(__FAST_MATH__)
#define math_errhandling 0
#elif defined(__NO_MATH_ERRNO__)
#define math_errhandling (MATH_ERREXCEPT)
#elif defined(__NVPTX__) || defined(__AMDGPU__)
#define math_errhandling (MATH_ERRNO)
#else
#define math_errhandling (MATH_ERRNO | MATH_ERREXCEPT)
#endif
Expand Down
2 changes: 2 additions & 0 deletions libc/spec/stdc.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def StdC : StandardSpec<"stdc"> {
],
[], // Enumerations
[
FunctionSpec<"cbrtf", RetValSpec<FloatType>, [ArgSpec<FloatType>]>,

FunctionSpec<"copysign", RetValSpec<DoubleType>, [ArgSpec<DoubleType>, ArgSpec<DoubleType>]>,
FunctionSpec<"copysignf", RetValSpec<FloatType>, [ArgSpec<FloatType>, ArgSpec<FloatType>]>,
FunctionSpec<"copysignl", RetValSpec<LongDoubleType>, [ArgSpec<LongDoubleType>, ArgSpec<LongDoubleType>]>,
Expand Down
2 changes: 2 additions & 0 deletions libc/src/__support/FPUtil/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ add_header_library(
multiply_add.h
DEPENDS
libc.src.__support.common
FLAGS
FMA_OPT
)

add_header_library(
Expand Down
6 changes: 6 additions & 0 deletions libc/src/__support/FPUtil/FEnvImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ LIBC_INLINE int set_env(const fenv_t *) { return 0; }

namespace LIBC_NAMESPACE::fputil {

// Clear the floating-point exceptions in `excepts`, but only when
// math_errhandling includes MATH_ERREXCEPT. Otherwise nothing is done and 0
// is returned.
LIBC_INLINE int clear_except_if_required(int excepts) {
  if (!(math_errhandling & MATH_ERREXCEPT))
    return 0;
  return clear_except(excepts);
}

LIBC_INLINE int set_except_if_required(int excepts) {
if (math_errhandling & MATH_ERREXCEPT)
return set_except(excepts);
Expand Down
2 changes: 2 additions & 0 deletions libc/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ add_math_entrypoint_object(canonicalizel)
add_math_entrypoint_object(canonicalizef16)
add_math_entrypoint_object(canonicalizef128)

add_math_entrypoint_object(cbrtf)

add_math_entrypoint_object(ceil)
add_math_entrypoint_object(ceilf)
add_math_entrypoint_object(ceill)
Expand Down
635 changes: 41 additions & 594 deletions libc/src/math/amdgpu/CMakeLists.txt

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions libc/src/math/cbrtf.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===-- Implementation header for cbrtf -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_CBRTF_H
#define LLVM_LIBC_SRC_MATH_CBRTF_H

namespace LIBC_NAMESPACE {

// Single-precision cube root of x.
float cbrtf(float x);

} // namespace LIBC_NAMESPACE

#endif // LLVM_LIBC_SRC_MATH_CBRTF_H
16 changes: 16 additions & 0 deletions libc/src/math/generic/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4092,3 +4092,19 @@ add_entrypoint_object(
COMPILE_OPTIONS
-O3
)

add_entrypoint_object(
cbrtf
SRCS
cbrtf.cpp
HDRS
../cbrtf.h
COMPILE_OPTIONS
-O3
DEPENDS
libc.hdr.fenv_macros
libc.src.__support.FPUtil.fenv_impl
libc.src.__support.FPUtil.fp_bits
libc.src.__support.FPUtil.multiply_add
libc.src.__support.macros.optimization
)
157 changes: 157 additions & 0 deletions libc/src/math/generic/cbrtf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
//===-- Implementation of cbrtf function ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/cbrtf.h"
#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/common.h"
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY

namespace LIBC_NAMESPACE {

namespace {

// Look up table for 2^(i/3) for i = 0, 1, 2.
constexpr double CBRT2[3] = {1.0, 0x1.428a2f98d728bp0, 0x1.965fea53d6e3dp0};

// Degree-6 polynomial approximations of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1,
// one per subinterval [i/16, (i + 1)/16], generated by Sollya with:
// > for i from 0 to 15 do {
// P = fpminimax(((1 + x)^(1/3) - 1)/x, 6, [|D...|], [i/16, (i + 1)/16]);
// print("{", coeff(P, 0), ",", coeff(P, 1), ",", coeff(P, 2), ",",
// coeff(P, 3), ",", coeff(P, 4), ",", coeff(P, 5), ",",
// coeff(P, 6), "},");
// };
// Then (1 + x)^(1/3) ~ 1 + x * P(x).
constexpr double COEFFS[16][7] = {
{0x1.55555555554ebp-2, -0x1.c71c71c678c0cp-4, 0x1.f9add2776de81p-5,
-0x1.511e10aa964a7p-5, 0x1.ee44165937fa2p-6, -0x1.7c5c9e059345dp-6,
0x1.047f75e0aff14p-6},
{0x1.5555554d1149ap-2, -0x1.c71c676fcb5bp-4, 0x1.f9ab127dc57ebp-5,
-0x1.50ea8fd1d4c15p-5, 0x1.e9d68f28ced43p-6, -0x1.60e0e1e661311p-6,
0x1.716eca1d6e3bcp-7},
{0x1.5555546377d45p-2, -0x1.c71bc1c6d49d2p-4, 0x1.f9924cc0ed24dp-5,
-0x1.4fea3beb53b3bp-5, 0x1.de028a9a07b1bp-6, -0x1.3b090d2233524p-6,
0x1.0aeca34893785p-7},
{0x1.55554dce9f649p-2, -0x1.c7188b34b98f8p-4, 0x1.f93e1af34af49p-5,
-0x1.4d9a06be75c63p-5, 0x1.cb943f4f68992p-6, -0x1.139a685a5e3c4p-6,
0x1.88410674c6a5dp-8},
{0x1.5555347d211c3p-2, -0x1.c70f2a4b1a5fap-4, 0x1.f88420e8602c3p-5,
-0x1.49becfa4ed3ep-5, 0x1.b475cd9013162p-6, -0x1.dcfee1dd2f8efp-7,
0x1.249bb51a1c498p-8},
{0x1.5554f01b33dbap-2, -0x1.c6facb929dbf1p-4, 0x1.f73fb7861252ep-5,
-0x1.4459a4a0071fap-5, 0x1.9a8df2b504fc2p-6, -0x1.9a7ce3006d06ep-7,
0x1.ba9230918fa2ep-9},
{0x1.55545c695db5fp-2, -0x1.c6d6089f20275p-4, 0x1.f556e0ea80efp-5,
-0x1.3d91372d083f4p-5, 0x1.7f66cff331f4p-6, -0x1.606a562491737p-7,
0x1.52e3e17c71069p-9},
{0x1.55534a879232ap-2, -0x1.c69b836998b84p-4, 0x1.f2bb26dac0e4cp-5,
-0x1.359eed43716d7p-5, 0x1.64218cd824fbcp-6, -0x1.2e703e2e091e8p-7,
0x1.0677d9af6aad4p-9},
{0x1.5551836bb5494p-2, -0x1.c64658c15353bp-4, 0x1.ef68517451a6ep-5,
-0x1.2cc20a980dceep-5, 0x1.49843e0fad93ap-6, -0x1.03c59ccb68e54p-7,
0x1.9ad325dc7adcbp-10},
{0x1.554ecacb0d035p-2, -0x1.c5d2664026ffcp-4, 0x1.eb624796ba809p-5,
-0x1.233803d19a535p-5, 0x1.300decb1c3c28p-6, -0x1.befe18031ec3dp-8,
0x1.449f5ee175c69p-10},
{0x1.554ae1f5ae815p-2, -0x1.c53c6b14ff6b2p-4, 0x1.e6b2d5127bb5bp-5,
-0x1.19387336788a3p-5, 0x1.180955a6ab255p-6, -0x1.81696703ba369p-8,
0x1.02cb36389bd79p-10},
{0x1.55458a59f356ep-2, -0x1.c4820dd631ae9p-4, 0x1.e167af818bd15p-5,
-0x1.0ef35f6f72e52p-5, 0x1.019c33b65e4ebp-6, -0x1.4d25bdd52d3a5p-8,
0x1.a008ae91f5936p-11},
{0x1.553e878eafee1p-2, -0x1.c3a1d0b2a3db2p-4, 0x1.db90d8ed9f89bp-5,
-0x1.0490e20f1ae91p-5, 0x1.d9a5d1fc42fe3p-7, -0x1.20bf8227c2abfp-8,
0x1.50f8174cdb6e9p-11},
{0x1.5535a0dedf1b1p-2, -0x1.c29afb8bd01a1p-4, 0x1.d53f6371c1e27p-5,
-0x1.f463209b433e2p-6, 0x1.b35222a17e44p-7, -0x1.f5efbf505e133p-9,
0x1.12e0e94e8586dp-11},
{0x1.552aa25e57bfdp-2, -0x1.c16d811e4acadp-4, 0x1.ce8489b47aa51p-5,
-0x1.dfde7ff758ea8p-6, 0x1.901f43aac38c8p-7, -0x1.b581d07df5ad5p-9,
0x1.c3726535f1fc6p-12},
{0x1.551d5d9b204d3p-2, -0x1.c019e328f8db1p-4, 0x1.c7710f44fc3cep-5,
-0x1.cbbbe25ea8ba4p-6, 0x1.6fe270088623dp-7, -0x1.7e6fc79733761p-9,
0x1.75077abf18d84p-12},
};

} // anonymous namespace

// Cube root of a float.
//
// Outline:
//   1. Return x unchanged when it is 0, Inf or NaN.
//   2. Widen x to double and split it into a biased exponent x_e and a
//      significand 1.mantissa in [1, 2).
//   3. Evaluate (1.mantissa)^(1/3) with the polynomial from COEFFS selected
//      by the top 4 mantissa bits, then multiply by 2^((x_e % 3)/3) from
//      CBRT2.
//   4. Snap results that are within a few units of an exact value, rebuild
//      the double from the new exponent, sign and mantissa, and convert it
//      to float.
LLVM_LIBC_FUNCTION(float, cbrtf, (float x)) {
using FloatBits = typename fputil::FPBits<float>;
using DoubleBits = typename fputil::FPBits<double>;

FloatBits x_bits(x);

// Bit pattern of |x|.
uint32_t x_abs = x_bits.uintval() & 0x7fff'ffff;
// Sign of x moved to bit EXP_LEN, i.e. just above a double's biased
// exponent, so that it can be OR-ed into out_e below and shifted into place
// together with the exponent.
uint32_t sign_bit = (x_bits.uintval() >> 31) << DoubleBits::EXP_LEN;

if (LIBC_UNLIKELY(x_abs == 0 || x_abs >= 0x7f80'0000)) {
// x is 0, Inf, or NaN.
return x;
}

double xd = static_cast<double>(x);
DoubleBits xd_bits(xd);

// When using biased exponent of x in double precision,
// x_e = real_exponent_of_x + 1023
// Then:
// x_e / 3 = real_exponent_of_x / 3 + 1023/3
// = real_exponent_of_x / 3 + 341
// So to make it the correct biased exponent of x^(1/3), we add
// 1023 - 341 = 682
// to the quotient x_e / 3.
unsigned x_e = static_cast<unsigned>(xd_bits.get_biased_exponent());
// Biased exponent of the result, with the sign of x carried one bit above.
unsigned out_e = (x_e / 3 + 682) | sign_bit;
// Remainder of the exponent division; selects the factor 2^(shift_e/3) in
// CBRT2.
unsigned shift_e = x_e % 3;

// x_m starts as the mantissa bits of xd. Below, the exponent field is set to
// EXP_BIAS so that x_m encodes the double 1.mantissa, in [1, 2). The
// remaining factor 2^(shift_e/3) is applied at the end through CBRT2.
uint64_t x_m = xd_bits.get_mantissa();
// Use the leading 4 bits for look up table
unsigned idx = static_cast<unsigned>(x_m >> (DoubleBits::FRACTION_LEN - 4));

x_m |= static_cast<uint64_t>(DoubleBits::EXP_BIAS)
<< DoubleBits::FRACTION_LEN;

double x_reduced = DoubleBits(x_m).get_val();
// dx = x_reduced - 1 is the argument of the polynomial.
double dx = x_reduced - 1.0;

// Evaluate 1 + dx * P(dx), with P given by COEFFS[idx]: adjacent
// coefficients are paired first (c0..c3), the pairs are merged with dx^2
// (p0, p1), and the two halves are merged with dx^4.
double dx_sq = dx * dx;
double c0 = fputil::multiply_add(dx, COEFFS[idx][0], 1.0);
double c1 = fputil::multiply_add(dx, COEFFS[idx][2], COEFFS[idx][1]);
double c2 = fputil::multiply_add(dx, COEFFS[idx][4], COEFFS[idx][3]);
double c3 = fputil::multiply_add(dx, COEFFS[idx][6], COEFFS[idx][5]);

double dx_4 = dx_sq * dx_sq;
double p0 = fputil::multiply_add(dx_sq, c1, c0);
double p1 = fputil::multiply_add(dx_sq, c3, c2);

// r ~ (1.mantissa)^(1/3) * 2^(shift_e/3).
double r = fputil::multiply_add(dx_4, p1, p0) * CBRT2[shift_e];

// Only the mantissa of r is kept; its exponent is replaced by out_e below.
// NOTE(review): this relies on r lying in [1, 2) -- confirm.
uint64_t r_m = DoubleBits(r).get_mantissa();
// Check if the output is exact. To be exact, the smallest 1-bit of the
// output has to be at least 2^-7 or higher. So we check the lowest 44 bits
// to see if they are within 2^(-52 + 3) errors from all zeros, then the
// result cube root is exact.
if (LIBC_UNLIKELY(((r_m + 8) & 0xfffffffffff) <= 16)) {
// Snap to the nearby value whose low bits are zero: clear the low 5 bits
// when the low 44 bits are slightly above zero, otherwise clear them and add
// 0x20 so that the carry propagates upward.
if ((r_m & 0xfffffffffff) <= 8)
r_m &= 0xffff'ffff'ffff'ffe0;
else
r_m = (r_m & 0xffff'ffff'ffff'ffe0) + 0x20;
// The result is treated as exact, so drop any inexact flag raised while
// computing it (only when exceptions are in use).
fputil::clear_except_if_required(FE_INEXACT);
}
// Adjust exponent and sign.
uint64_t r_bits =
r_m | (static_cast<uint64_t>(out_e) << DoubleBits::FRACTION_LEN);

return static_cast<float>(DoubleBits(r_bits).get_val());
}

} // namespace LIBC_NAMESPACE
4 changes: 2 additions & 2 deletions libc/src/math/generic/tan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ LIBC_INLINE DoubleDouble tan_eval(const DoubleDouble &u) {
}

// Accurate evaluation of tan for small u.
Float128 tan_eval(const Float128 &u) {
[[maybe_unused]] Float128 tan_eval(const Float128 &u) {
Float128 u_sq = fputil::quick_mul(u, u);

// tan(x) ~ x + x^3/3 + x^5 * 2/15 + x^7 * 17/315 + x^9 * 62/2835 +
Expand Down Expand Up @@ -127,7 +127,7 @@ Float128 tan_eval(const Float128 &u) {
// Calculation a / b = a * (1/b) for Float128.
// Using the initial approximation of q ~ (1/b), then apply 2 Newton-Raphson
// iterations, before multiplying by a.
Float128 newton_raphson_div(const Float128 &a, Float128 b, double q) {
[[maybe_unused]] Float128 newton_raphson_div(const Float128 &a, Float128 b, double q) {
Float128 q0(q);
constexpr Float128 TWO(2.0);
b.sign = (b.sign == Sign::POS) ? Sign::NEG : Sign::POS;
Expand Down
601 changes: 0 additions & 601 deletions libc/src/math/nvptx/CMakeLists.txt

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion libc/src/math/nvptx/llrint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(long long, llrint, (double x)) { return __nv_llrint(x); }
LLVM_LIBC_FUNCTION(long long, llrint, (double x)) {
return static_cast<long long>(__builtin_rint(x));
}

} // namespace LIBC_NAMESPACE
4 changes: 3 additions & 1 deletion libc/src/math/nvptx/llrintf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(long long, llrintf, (float x)) { return __nv_llrintf(x); }
LLVM_LIBC_FUNCTION(long long, llrintf, (float x)) {
return static_cast<long long>(__builtin_rintf(x));
}

} // namespace LIBC_NAMESPACE
4 changes: 3 additions & 1 deletion libc/src/math/nvptx/lrint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

namespace LIBC_NAMESPACE {

LLVM_LIBC_FUNCTION(long, lrint, (double x)) { return __nv_lrint(x); }
LLVM_LIBC_FUNCTION(long, lrint, (double x)) {
return static_cast<long>(__builtin_rint(x));
}

} // namespace LIBC_NAMESPACE
12 changes: 12 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2213,6 +2213,18 @@ add_fp_unittest(
libc.src.math.f16sqrtl
)

add_fp_unittest(
cbrtf_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
cbrtf_test.cpp
DEPENDS
libc.src.math.cbrtf
libc.src.__support.FPUtil.fp_bits
)

add_subdirectory(generic)
add_subdirectory(smoke)

Expand Down
42 changes: 42 additions & 0 deletions libc/test/src/math/cbrtf_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//===-- Unittests for cbrtf -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "hdr/math_macros.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/math/cbrtf.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

using LlvmLibcCbrtfTest = LIBC_NAMESPACE::testing::FPTest<float>;

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

// Samples COUNT + 1 evenly spaced bit patterns between 0 and the bit pattern
// of +inf, and compares cbrtf against MPFR for both x and -x in all rounding
// modes.
TEST_F(LlvmLibcCbrtfTest, InFloatRange) {
constexpr uint32_t COUNT = 100'000;
// Distance between consecutive sampled bit patterns.
const uint32_t STEP = FPBits(inf).uintval() / COUNT;
for (uint32_t i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
float x = FPBits(v).get_val();
// 0.5 is the tolerance passed to the matcher (presumably in ULPs).
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
LIBC_NAMESPACE::cbrtf(x), 0.5);
EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, -x,
LIBC_NAMESPACE::cbrtf(-x), 0.5);
}
}

// Checks a fixed list of inputs against MPFR with the rounding mode forced to
// upward.
TEST_F(LlvmLibcCbrtfTest, SpecialValues) {
constexpr float INPUTS[] = {
0x1.60451p2f, 0x1.31304cp1f, 0x1.d17cp2f, 0x1.bp-143f, 0x1.338cp2f,
};
for (float v : INPUTS) {
float x = FPBits(v).get_val();
// Keeps the upward rounding mode active for the rest of this iteration.
mpfr::ForceRoundingMode r(mpfr::RoundingMode::Upward);
EXPECT_MPFR_MATCH(mpfr::Operation::Cbrt, x, LIBC_NAMESPACE::cbrtf(x), 0.5,
mpfr::RoundingMode::Upward);
}
}
16 changes: 16 additions & 0 deletions libc/test/src/math/exhaustive/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,19 @@ add_fp_unittest(
LINK_LIBRARIES
-lpthread
)

# Exhaustive test target for cbrtf. It builds on the shared .exhaustive_test
# helper, compares against MPFR (NEED_MPFR) and links pthread.
# NOTE(review): NO_RUN_POSTBUILD presumably keeps this long-running test out of
# the default post-build run - confirm against add_fp_unittest.
add_fp_unittest(
cbrtf_test
NO_RUN_POSTBUILD
NEED_MPFR
SUITE
libc_math_exhaustive_tests
SRCS
cbrtf_test.cpp
DEPENDS
.exhaustive_test
libc.src.math.cbrtf
libc.src.__support.FPUtil.fp_bits
LINK_LIBRARIES
-lpthread
)
33 changes: 33 additions & 0 deletions libc/test/src/math/exhaustive/cbrtf_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===-- Exhaustive test for cbrtf -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "exhaustive_test.h"
#include "src/math/cbrtf.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

using LlvmLibcCbrtfExhaustiveTest =
LlvmLibcUnaryOpExhaustiveMathTest<float, mpfr::Operation::Cbrt,
LIBC_NAMESPACE::cbrtf>;

// Range: [0, Inf];
static constexpr uint32_t POS_START = 0x0000'0000U;
static constexpr uint32_t POS_STOP = 0x7f80'0000U;

// Sweeps the bit patterns from POS_START to POS_STOP (the [0, Inf] range
// declared above) through test_full_range_all_roundings.
TEST_F(LlvmLibcCbrtfExhaustiveTest, PositiveRange) {
  test_full_range_all_roundings(POS_START, POS_STOP);
}

// Range: [-Inf, 0];
static constexpr uint32_t NEG_START = 0x8000'0000U;
static constexpr uint32_t NEG_STOP = 0xff80'0000U;

// Sweeps the bit patterns from NEG_START to NEG_STOP (the [-Inf, 0] range
// declared above) through test_full_range_all_roundings.
TEST_F(LlvmLibcCbrtfExhaustiveTest, NegativeRange) {
test_full_range_all_roundings(NEG_START, NEG_STOP);
}
15 changes: 8 additions & 7 deletions libc/test/src/math/exhaustive/exhaustive_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ struct UnaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
TEST_MPFR_MATCH_ROUNDING_SILENTLY(Op, x, Func(x), 0.5, rounding);
failed += (!correct);
// Uncomment to print out failed values.
// if (!correct) {
// EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
// }
if (!correct) {
EXPECT_MPFR_MATCH_ROUNDING(Op, x, Func(x), 0.5, rounding);
}
} while (bits++ < stop);
return failed;
}
Expand Down Expand Up @@ -97,9 +97,9 @@ struct BinaryOpChecker : public virtual LIBC_NAMESPACE::testing::Test {
0.5, rounding);
failed += (!correct);
// Uncomment to print out failed values.
// if (!correct) {
// EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
// }
if (!correct) {
EXPECT_MPFR_MATCH_ROUNDING(Op, input, Func(x, y), 0.5, rounding);
}
} while (ybits++ < y_stop);
} while (xbits++ < x_stop);
return failed;
Expand Down Expand Up @@ -187,7 +187,8 @@ struct LlvmLibcExhaustiveMathTest
std::stringstream msg;
msg << "Test failed for " << std::dec << failed_in_range
<< " inputs in range: ";
explain_failed_range(msg, start, stop, extra_range_bounds...);
explain_failed_range(msg, range_begin, range_end,
extra_range_bounds...);
msg << "\n";
std::cerr << msg.str() << std::flush;

Expand Down
10 changes: 10 additions & 0 deletions libc/test/src/math/smoke/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3961,3 +3961,13 @@ add_fp_unittest(
DEPENDS
libc.src.math.tan
)

# Smoke test target for cbrtf. Unlike the MPFR-based tests it only depends on
# the cbrtf entrypoint itself.
add_fp_unittest(
cbrtf_test
SUITE
libc-math-smoke-tests
SRCS
cbrtf_test.cpp
DEPENDS
libc.src.math.cbrtf
)
33 changes: 33 additions & 0 deletions libc/test/src/math/smoke/cbrtf_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===-- Unittests for cbrtf -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/cbrtf.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"

using LlvmLibcCbrtfTest = LIBC_NAMESPACE::testing::FPTest<float>;

using LIBC_NAMESPACE::testing::tlog;

// Smoke checks for cbrtf: special values must map to themselves, and perfect
// cubes must give their exact root, for both signs and in every rounding mode.
TEST_F(LlvmLibcCbrtfTest, SpecialNumbers) {
  using LIBC_NAMESPACE::cbrtf;

  // NaN, infinities and signed zeros.
  EXPECT_FP_EQ_ALL_ROUNDING(aNaN, cbrtf(aNaN));
  EXPECT_FP_EQ_ALL_ROUNDING(inf, cbrtf(inf));
  EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, cbrtf(neg_inf));
  EXPECT_FP_EQ_ALL_ROUNDING(zero, cbrtf(zero));
  EXPECT_FP_EQ_ALL_ROUNDING(neg_zero, cbrtf(neg_zero));

  // Small perfect cubes.
  EXPECT_FP_EQ_ALL_ROUNDING(1.0f, cbrtf(1.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(-1.0f, cbrtf(-1.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(2.0f, cbrtf(8.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(-2.0f, cbrtf(-8.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(3.0f, cbrtf(27.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(-3.0f, cbrtf(-27.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(5.0f, cbrtf(125.0f));
  EXPECT_FP_EQ_ALL_ROUNDING(-5.0f, cbrtf(-125.0f));

  // A large power of two whose exponent is a multiple of three.
  EXPECT_FP_EQ_ALL_ROUNDING(0x1.0p42f, cbrtf(0x1.0p126f));
  EXPECT_FP_EQ_ALL_ROUNDING(-0x1.0p42f, cbrtf(-0x1.0p126f));
}
8 changes: 8 additions & 0 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ class MPFRNumber {
return result;
}

// Returns a new MPFRNumber holding the cube root of this value. The result
// object starts as a copy of *this and is then overwritten by mpfr_cbrt using
// this object's rounding mode.
MPFRNumber cbrt() const {
  MPFRNumber root(*this);
  mpfr_cbrt(root.value, value, mpfr_rounding);
  return root;
}

MPFRNumber ceil() const {
MPFRNumber result(*this);
mpfr_ceil(result.value, value);
Expand Down Expand Up @@ -702,6 +708,8 @@ unary_operation(Operation op, InputType input, unsigned int precision,
return mpfrInput.atan();
case Operation::Atanh:
return mpfrInput.atanh();
case Operation::Cbrt:
return mpfrInput.cbrt();
case Operation::Ceil:
return mpfrInput.ceil();
case Operation::Cos:
Expand Down
1 change: 1 addition & 0 deletions libc/utils/MPFRWrapper/MPFRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ enum class Operation : int {
Asinh,
Atan,
Atanh,
Cbrt,
Ceil,
Cos,
Cosh,
Expand Down
174 changes: 56 additions & 118 deletions lldb/source/Plugins/Language/CPlusPlus/LibCxxMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "lldb/Utility/Endian.h"
#include "lldb/Utility/Status.h"
#include "lldb/Utility/Stream.h"
#include "lldb/lldb-enumerations.h"
#include "lldb/lldb-forward.h"

using namespace lldb;
Expand Down Expand Up @@ -223,11 +224,10 @@ class LibCxxMapIteratorSyntheticFrontEnd : public SyntheticChildrenFrontEnd {

size_t GetIndexOfChildWithName(ConstString name) override;

~LibCxxMapIteratorSyntheticFrontEnd() override;
~LibCxxMapIteratorSyntheticFrontEnd() override = default;

private:
ValueObject *m_pair_ptr;
lldb::ValueObjectSP m_pair_sp;
ValueObjectSP m_pair_sp = nullptr;
};
} // namespace formatters
} // namespace lldb_private
Expand Down Expand Up @@ -464,125 +464,71 @@ lldb_private::formatters::LibcxxStdMapSyntheticFrontEndCreator(

lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::
LibCxxMapIteratorSyntheticFrontEnd(lldb::ValueObjectSP valobj_sp)
: SyntheticChildrenFrontEnd(*valobj_sp), m_pair_ptr(), m_pair_sp() {
: SyntheticChildrenFrontEnd(*valobj_sp) {
if (valobj_sp)
Update();
}

lldb::ChildCacheState
lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::Update() {
m_pair_sp.reset();
m_pair_ptr = nullptr;

ValueObjectSP valobj_sp = m_backend.GetSP();
if (!valobj_sp)
return lldb::ChildCacheState::eRefetch;

TargetSP target_sp(valobj_sp->GetTargetSP());

if (!target_sp)
return lldb::ChildCacheState::eRefetch;

// this must be a ValueObject* because it is a child of the ValueObject we
// are producing children for it if were a ValueObjectSP, we would end up
// with a loop (iterator -> synthetic -> child -> parent == iterator) and
// that would in turn leak memory by never allowing the ValueObjects to die
// and free their memory
m_pair_ptr = valobj_sp
->GetValueForExpressionPath(
".__i_.__ptr_->__value_", nullptr, nullptr,
ValueObject::GetValueForExpressionPathOptions()
.DontCheckDotVsArrowSyntax()
.SetSyntheticChildrenTraversal(
ValueObject::GetValueForExpressionPathOptions::
SyntheticChildrenTraversal::None),
nullptr)
.get();

if (!m_pair_ptr) {
m_pair_ptr = valobj_sp
->GetValueForExpressionPath(
".__i_.__ptr_", nullptr, nullptr,
ValueObject::GetValueForExpressionPathOptions()
.DontCheckDotVsArrowSyntax()
.SetSyntheticChildrenTraversal(
ValueObject::GetValueForExpressionPathOptions::
SyntheticChildrenTraversal::None),
nullptr)
.get();
if (m_pair_ptr) {
auto __i_(valobj_sp->GetChildMemberWithName("__i_"));
if (!__i_) {
m_pair_ptr = nullptr;
return lldb::ChildCacheState::eRefetch;
}
CompilerType pair_type(
__i_->GetCompilerType().GetTypeTemplateArgument(0));
std::string name;
uint64_t bit_offset_ptr;
uint32_t bitfield_bit_size_ptr;
bool is_bitfield_ptr;
pair_type = pair_type.GetFieldAtIndex(
0, name, &bit_offset_ptr, &bitfield_bit_size_ptr, &is_bitfield_ptr);
if (!pair_type) {
m_pair_ptr = nullptr;
return lldb::ChildCacheState::eRefetch;
}
// m_backend is a std::map::iterator
// ...which is a __map_iterator<__tree_iterator<..., __node_pointer, ...>>
//
// Then, __map_iterator::__i_ is a __tree_iterator
auto tree_iter_sp = valobj_sp->GetChildMemberWithName("__i_");
if (!tree_iter_sp)
return lldb::ChildCacheState::eRefetch;

auto addr(m_pair_ptr->GetValueAsUnsigned(LLDB_INVALID_ADDRESS));
m_pair_ptr = nullptr;
if (addr && addr != LLDB_INVALID_ADDRESS) {
auto ts = pair_type.GetTypeSystem();
auto ast_ctx = ts.dyn_cast_or_null<TypeSystemClang>();
if (!ast_ctx)
return lldb::ChildCacheState::eRefetch;

// Mimick layout of std::__tree_iterator::__ptr_ and read it in
// from process memory.
//
// The following shows the contiguous block of memory:
//
// +-----------------------------+ class __tree_end_node
// __ptr_ | pointer __left_; |
// +-----------------------------+ class __tree_node_base
// | pointer __right_; |
// | __parent_pointer __parent_; |
// | bool __is_black_; |
// +-----------------------------+ class __tree_node
// | __node_value_type __value_; | <<< our key/value pair
// +-----------------------------+
//
CompilerType tree_node_type = ast_ctx->CreateStructForIdentifier(
llvm::StringRef(),
{{"ptr0",
ast_ctx->GetBasicType(lldb::eBasicTypeVoid).GetPointerType()},
{"ptr1",
ast_ctx->GetBasicType(lldb::eBasicTypeVoid).GetPointerType()},
{"ptr2",
ast_ctx->GetBasicType(lldb::eBasicTypeVoid).GetPointerType()},
{"cw", ast_ctx->GetBasicType(lldb::eBasicTypeBool)},
{"payload", pair_type}});
std::optional<uint64_t> size = tree_node_type.GetByteSize(nullptr);
if (!size)
return lldb::ChildCacheState::eRefetch;
WritableDataBufferSP buffer_sp(new DataBufferHeap(*size, 0));
ProcessSP process_sp(target_sp->GetProcessSP());
Status error;
process_sp->ReadMemory(addr, buffer_sp->GetBytes(),
buffer_sp->GetByteSize(), error);
if (error.Fail())
return lldb::ChildCacheState::eRefetch;
DataExtractor extractor(buffer_sp, process_sp->GetByteOrder(),
process_sp->GetAddressByteSize());
auto pair_sp = CreateValueObjectFromData(
"pair", extractor, valobj_sp->GetExecutionContextRef(),
tree_node_type);
if (pair_sp)
m_pair_sp = pair_sp->GetChildAtIndex(4);
}
}
// Type is __tree_iterator::__node_pointer
// (We could alternatively also get this from the template argument)
auto node_pointer_type =
tree_iter_sp->GetCompilerType().GetDirectNestedTypeWithName(
"__node_pointer");
if (!node_pointer_type.IsValid())
return lldb::ChildCacheState::eRefetch;

// __ptr_ is a __tree_iterator::__iter_pointer
auto iter_pointer_sp = tree_iter_sp->GetChildMemberWithName("__ptr_");
if (!iter_pointer_sp)
return lldb::ChildCacheState::eRefetch;

// Cast the __iter_pointer to a __node_pointer (which stores our key/value
// pair)
auto node_pointer_sp = iter_pointer_sp->Cast(node_pointer_type);
if (!node_pointer_sp)
return lldb::ChildCacheState::eRefetch;

auto key_value_sp = node_pointer_sp->GetChildMemberWithName("__value_");
if (!key_value_sp)
return lldb::ChildCacheState::eRefetch;

// Create the synthetic child, which is a pair where the key and value can be
// retrieved by querying the synthetic frontend for
// GetIndexOfChildWithName("first") and GetIndexOfChildWithName("second")
// respectively.
//
// std::map stores the actual key/value pair in value_type::__cc_ (or
// previously __cc).
key_value_sp = key_value_sp->Clone(ConstString("pair"));
if (key_value_sp->GetNumChildrenIgnoringErrors() == 1) {
auto child0_sp = key_value_sp->GetChildAtIndex(0);
if (child0_sp &&
(child0_sp->GetName() == "__cc_" || child0_sp->GetName() == "__cc"))
key_value_sp = child0_sp->Clone(ConstString("pair"));
}

m_pair_sp = key_value_sp;

return lldb::ChildCacheState::eRefetch;
}

Expand All @@ -594,11 +540,10 @@ llvm::Expected<uint32_t> lldb_private::formatters::
lldb::ValueObjectSP
lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::GetChildAtIndex(
uint32_t idx) {
if (m_pair_ptr)
return m_pair_ptr->GetChildAtIndex(idx);
if (m_pair_sp)
return m_pair_sp->GetChildAtIndex(idx);
return lldb::ValueObjectSP();
if (!m_pair_sp)
return nullptr;

return m_pair_sp->GetChildAtIndex(idx);
}

bool lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::
Expand All @@ -608,17 +553,10 @@ bool lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::

size_t lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::
GetIndexOfChildWithName(ConstString name) {
if (name == "first")
return 0;
if (name == "second")
return 1;
return UINT32_MAX;
}
if (!m_pair_sp)
return UINT32_MAX;

lldb_private::formatters::LibCxxMapIteratorSyntheticFrontEnd::
~LibCxxMapIteratorSyntheticFrontEnd() {
// this will be deleted when its parent dies (since it's a child object)
// delete m_pair_ptr;
return m_pair_sp->GetIndexOfChildWithName(name);
}

SyntheticChildrenFrontEnd *
Expand Down
156 changes: 50 additions & 106 deletions lldb/source/Plugins/Language/CPlusPlus/LibCxxUnorderedMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,6 @@ class LibcxxStdUnorderedMapSyntheticFrontEnd
std::vector<std::pair<ValueObject *, uint64_t>> m_elements_cache;
};

/// Formats libcxx's std::unordered_map iterators
///
/// In raw form a std::unordered_map::iterator is represented as follows:
///
/// (lldb) var it --raw --ptr-depth 1
/// (std::__1::__hash_map_iterator<
/// std::__1::__hash_iterator<
/// std::__1::__hash_node<
/// std::__1::__hash_value_type<
/// std::__1::basic_string<char, std::__1::char_traits<char>,
/// std::__1::allocator<char> >, std::__1::basic_string<char,
/// std::__1::char_traits<char>, std::__1::allocator<char> > >,
/// void *> *> >)
/// it = {
/// __i_ = {
/// __node_ = 0x0000600001700040 {
/// __next_ = 0x0000600001704000
/// }
/// }
/// }
class LibCxxUnorderedMapIteratorSyntheticFrontEnd
: public SyntheticChildrenFrontEnd {
public:
Expand All @@ -90,9 +70,6 @@ class LibCxxUnorderedMapIteratorSyntheticFrontEnd
size_t GetIndexOfChildWithName(ConstString name) override;

private:
ValueObject *m_iter_ptr = nullptr; ///< Held, not owned. Child of iterator
///< ValueObject supplied at construction.

lldb::ValueObjectSP m_pair_sp; ///< ValueObject for the key/value pair
///< that the iterator currently points
///< to.
Expand Down Expand Up @@ -304,7 +281,6 @@ lldb_private::formatters::LibCxxUnorderedMapIteratorSyntheticFrontEnd::
lldb::ChildCacheState lldb_private::formatters::
LibCxxUnorderedMapIteratorSyntheticFrontEnd::Update() {
m_pair_sp.reset();
m_iter_ptr = nullptr;

ValueObjectSP valobj_sp = m_backend.GetSP();
if (!valobj_sp)
Expand All @@ -315,98 +291,66 @@ lldb::ChildCacheState lldb_private::formatters::
if (!target_sp)
return lldb::ChildCacheState::eRefetch;

if (!valobj_sp)
// Get the unordered_map::iterator
// m_backend is an 'unordered_map::iterator', aka a
// '__hash_map_iterator<__hash_table::iterator>'
//
// __hash_map_iterator::__i_ is a __hash_table::iterator (aka
// __hash_iterator<__node_pointer>)
auto hash_iter_sp = valobj_sp->GetChildMemberWithName("__i_");
if (!hash_iter_sp)
return lldb::ChildCacheState::eRefetch;

auto exprPathOptions = ValueObject::GetValueForExpressionPathOptions()
.DontCheckDotVsArrowSyntax()
.SetSyntheticChildrenTraversal(
ValueObject::GetValueForExpressionPathOptions::
SyntheticChildrenTraversal::None);

// This must be a ValueObject* because it is a child of the ValueObject we
// are producing children for it if were a ValueObjectSP, we would end up
// with a loop (iterator -> synthetic -> child -> parent == iterator) and
// that would in turn leak memory by never allowing the ValueObjects to die
// and free their memory.
m_iter_ptr =
valobj_sp
->GetValueForExpressionPath(".__i_.__node_", nullptr, nullptr,
exprPathOptions, nullptr)
.get();

if (m_iter_ptr) {
auto iter_child(valobj_sp->GetChildMemberWithName("__i_"));
if (!iter_child) {
m_iter_ptr = nullptr;
return lldb::ChildCacheState::eRefetch;
}

CompilerType node_type(iter_child->GetCompilerType()
.GetTypeTemplateArgument(0)
.GetPointeeType());

CompilerType pair_type(node_type.GetTypeTemplateArgument(0));

std::string name;
uint64_t bit_offset_ptr;
uint32_t bitfield_bit_size_ptr;
bool is_bitfield_ptr;

pair_type = pair_type.GetFieldAtIndex(
0, name, &bit_offset_ptr, &bitfield_bit_size_ptr, &is_bitfield_ptr);
if (!pair_type) {
m_iter_ptr = nullptr;
return lldb::ChildCacheState::eRefetch;
}
// Type is '__hash_iterator<__node_pointer>'
auto hash_iter_type = hash_iter_sp->GetCompilerType();
if (!hash_iter_type.IsValid())
return lldb::ChildCacheState::eRefetch;

uint64_t addr = m_iter_ptr->GetValueAsUnsigned(LLDB_INVALID_ADDRESS);
m_iter_ptr = nullptr;
// Type is '__node_pointer'
auto node_pointer_type = hash_iter_type.GetTypeTemplateArgument(0);
if (!node_pointer_type.IsValid())
return lldb::ChildCacheState::eRefetch;

if (addr == 0 || addr == LLDB_INVALID_ADDRESS)
return lldb::ChildCacheState::eRefetch;
// Cast the __hash_iterator to a __node_pointer (which stores our key/value
// pair)
auto hash_node_sp = hash_iter_sp->Cast(node_pointer_type);
if (!hash_node_sp)
return lldb::ChildCacheState::eRefetch;

auto ts = pair_type.GetTypeSystem();
auto ast_ctx = ts.dyn_cast_or_null<TypeSystemClang>();
if (!ast_ctx)
auto key_value_sp = hash_node_sp->GetChildMemberWithName("__value_");
if (!key_value_sp) {
// clang-format off
// Since D101206 (ba79fb2e1f), libc++ wraps the `__value_` in an
// anonymous union.
// Child 0: __hash_node_base base class
// Child 1: __hash_
// Child 2: anonymous union
// clang-format on
auto anon_union_sp = hash_node_sp->GetChildAtIndex(2);
if (!anon_union_sp)
return lldb::ChildCacheState::eRefetch;

// Mimick layout of std::__hash_iterator::__node_ and read it in
// from process memory.
//
// The following shows the contiguous block of memory:
//
// +-----------------------------+ class __hash_node_base
// __node_ | __next_pointer __next_; |
// +-----------------------------+ class __hash_node
// | size_t __hash_; |
// | __node_value_type __value_; | <<< our key/value pair
// +-----------------------------+
//
CompilerType tree_node_type = ast_ctx->CreateStructForIdentifier(
llvm::StringRef(),
{{"__next_",
ast_ctx->GetBasicType(lldb::eBasicTypeVoid).GetPointerType()},
{"__hash_", ast_ctx->GetBasicType(lldb::eBasicTypeUnsignedLongLong)},
{"__value_", pair_type}});
std::optional<uint64_t> size = tree_node_type.GetByteSize(nullptr);
if (!size)
return lldb::ChildCacheState::eRefetch;
WritableDataBufferSP buffer_sp(new DataBufferHeap(*size, 0));
ProcessSP process_sp(target_sp->GetProcessSP());
Status error;
process_sp->ReadMemory(addr, buffer_sp->GetBytes(),
buffer_sp->GetByteSize(), error);
if (error.Fail())
key_value_sp = anon_union_sp->GetChildMemberWithName("__value_");
if (!key_value_sp)
return lldb::ChildCacheState::eRefetch;
DataExtractor extractor(buffer_sp, process_sp->GetByteOrder(),
process_sp->GetAddressByteSize());
auto pair_sp = CreateValueObjectFromData(
"pair", extractor, valobj_sp->GetExecutionContextRef(), tree_node_type);
if (pair_sp)
m_pair_sp = pair_sp->GetChildAtIndex(2);
}

// Create the synthetic child, which is a pair where the key and value can be
// retrieved by querying the synthetic frontend for
// GetIndexOfChildWithName("first") and GetIndexOfChildWithName("second")
// respectively.
//
// std::unordered_map stores the actual key/value pair in
// __hash_value_type::__cc_ (or previously __cc).
auto potential_child_sp = key_value_sp->Clone(ConstString("pair"));
if (potential_child_sp)
if (potential_child_sp->GetNumChildrenIgnoringErrors() == 1)
if (auto child0_sp = potential_child_sp->GetChildAtIndex(0);
child0_sp->GetName() == "__cc_" || child0_sp->GetName() == "__cc")
potential_child_sp = child0_sp->Clone(ConstString("pair"));

m_pair_sp = potential_child_sp;

return lldb::ChildCacheState::eRefetch;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,34 @@ def cleanup():
self.expect("frame variable iimI", substrs=["first = 43981", "second = 61681"])
self.expect("expr iimI", substrs=["first = 43981", "second = 61681"])

self.expect("frame variable iimI.first", substrs=["first = 43981"])
self.expect("frame variable iimI.first", substrs=["second"], matching=False)
self.expect("frame variable iimI.second", substrs=["second = 61681"])
self.expect("frame variable iimI.second", substrs=["first"], matching=False)

self.expect("frame variable simI", substrs=['first = "world"', "second = 42"])
self.expect("expr simI", substrs=['first = "world"', "second = 42"])

self.expect("frame variable simI.first", substrs=['first = "world"'])
self.expect("frame variable simI.first", substrs=["second"], matching=False)
self.expect("frame variable simI.second", substrs=["second = 42"])
self.expect("frame variable simI.second", substrs=["first"], matching=False)

self.expect("frame variable svI", substrs=['item = "hello"'])
self.expect("expr svI", substrs=['item = "hello"'])

self.expect("frame variable iiumI", substrs=["first = 61453", "second = 51966"])
self.expect("expr iiumI", substrs=["first = 61453", "second = 51966"])

self.expect("frame variable siumI", substrs=['first = "hello"', "second = 137"])
self.expect("expr siumI", substrs=['first = "hello"', "second = 137"])

self.expect("frame variable iiumI.first", substrs=["first = 61453"])
self.expect("frame variable iiumI.first", substrs=["second"], matching=False)
self.expect("frame variable iiumI.second", substrs=["second = 51966"])
self.expect("frame variable iiumI.second", substrs=["first"], matching=False)

self.expect("frame variable siumI.first", substrs=['first = "hello"'])
self.expect("frame variable siumI.first", substrs=["second"], matching=False)
self.expect("frame variable siumI.second", substrs=["second = 137"])
self.expect("frame variable siumI.second", substrs=["first"], matching=False)
Original file line number Diff line number Diff line change
@@ -1,38 +1,50 @@
#include <string>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

typedef std::map<int, int> intint_map;
typedef std::map<std::string, int> strint_map;

typedef std::unordered_map<int, int> intint_umap;
typedef std::unordered_map<std::string, int> strint_umap;

typedef std::vector<int> int_vector;
typedef std::vector<std::string> string_vector;

typedef intint_map::iterator iimter;
typedef strint_map::iterator simter;
typedef intint_map::iterator ii_map_iter;
typedef strint_map::iterator si_map_iter;
typedef intint_umap::iterator ii_umap_iter;
typedef strint_umap::iterator si_umap_iter;

typedef int_vector::iterator ivter;
typedef string_vector::iterator svter;

int main()
{
intint_map iim;
iim[0xABCD] = 0xF0F1;
int main() {
intint_map iim;
iim[0xABCD] = 0xF0F1;

strint_map sim;
sim["world"] = 42;

intint_umap iium;
iium[0xF00D] = 0xCAFE;

strint_map sim;
sim["world"] = 42;
strint_umap sium;
sium["hello"] = 137;

int_vector iv;
iv.push_back(3);
int_vector iv;
iv.push_back(3);

string_vector sv;
sv.push_back("hello");
string_vector sv;
sv.push_back("hello");

iimter iimI = iim.begin();
simter simI = sim.begin();
ii_map_iter iimI = iim.begin();
si_map_iter simI = sim.begin();
ii_umap_iter iiumI = iium.begin();
si_umap_iter siumI = sium.begin();

ivter ivI = iv.begin();
svter svI = sv.begin();
ivter ivI = iv.begin();
svter svI = sv.begin();

return 0; // Set break point at this line.
return 0; // Set break point at this line.
}
2 changes: 1 addition & 1 deletion llvm/include/llvm/Analysis/ValueLattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class ValueLatticeElement {
return std::nullopt;
}

ConstantRange asConstantRange(Type *Ty, bool UndefAllowed = false) {
ConstantRange asConstantRange(Type *Ty, bool UndefAllowed = false) const {
assert(Ty->isIntOrIntVectorTy() && "Must be integer type");
if (isConstantRange(UndefAllowed))
return getConstantRange();
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AMDGPU/MIMGInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// - MIMGEncGfx10Default: gfx10 default (non-NSA) encoding
// - MIMGEncGfx10NSA: gfx10 NSA encoding
// - MIMGEncGfx11Default: gfx11 default (non-NSA) encoding
// - MIMGEncGfx11NSA: gfx11 NSA encoding
// - MIMGEncGfx11NSA: gfx11 partial NSA encoding
// - MIMGEncGfx12: gfx12 encoding (partial NSA)
class MIMGEncoding;

def MIMGEncGfx6 : MIMGEncoding;
Expand Down
291 changes: 255 additions & 36 deletions llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,19 @@ class LoopIdiomRecognize {
bool recognizePopcount();
void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst,
PHINode *CntPhi, Value *Var);
bool isProfitableToInsertFFS(Intrinsic::ID IntrinID, Value *InitX,
bool ZeroCheck, size_t CanonicalSize);
bool insertFFSIfProfitable(Intrinsic::ID IntrinID, Value *InitX,
Instruction *DefX, PHINode *CntPhi,
Instruction *CntInst);
bool recognizeAndInsertFFS(); /// Find First Set: ctlz or cttz
bool recognizeShiftUntilLessThan();
void transformLoopToCountable(Intrinsic::ID IntrinID, BasicBlock *PreCondBB,
Instruction *CntInst, PHINode *CntPhi,
Value *Var, Instruction *DefX,
const DebugLoc &DL, bool ZeroCheck,
bool IsCntPhiUsedOutsideLoop);
bool IsCntPhiUsedOutsideLoop,
bool InsertSub = false);

bool recognizeShiftUntilBitTest();
bool recognizeShiftUntilZero();
Expand Down Expand Up @@ -1482,7 +1489,8 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() {
<< CurLoop->getHeader()->getName() << "\n");

return recognizePopcount() || recognizeAndInsertFFS() ||
recognizeShiftUntilBitTest() || recognizeShiftUntilZero();
recognizeShiftUntilBitTest() || recognizeShiftUntilZero() ||
recognizeShiftUntilLessThan();
}

/// Check if the given conditional branch is based on the comparison between
Expand Down Expand Up @@ -1517,6 +1525,34 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry,
return nullptr;
}

/// Check if the given conditional branch is based on an unsigned less-than
/// comparison between a variable and a constant, and if the comparison is false
/// the control yields to the loop entry. If the branch matches the behaviour,
/// the variable involved in the comparison is returned and \p Threshold is set
/// to the constant. Otherwise nullptr is returned and \p Threshold is left
/// untouched.
///
/// Constants that do not fit in 64 bits (possible for integer types wider than
/// i64) are treated as a non-match, because \p Threshold cannot represent them.
static Value *matchShiftULTCondition(BranchInst *BI, BasicBlock *LoopEntry,
                                     uint64_t &Threshold) {
  if (!BI || !BI->isConditional())
    return nullptr;

  auto *Cond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!Cond)
    return nullptr;

  auto *CmpConst = dyn_cast<ConstantInt>(Cond->getOperand(1));
  if (!CmpConst)
    return nullptr;

  // Only "icmp ult X, C" whose false edge goes back to the loop entry matches.
  if (Cond->getPredicate() != ICmpInst::ICMP_ULT ||
      BI->getSuccessor(1) != LoopEntry)
    return nullptr;

  // getZExtValue() asserts when the value needs more than 64 bits, so bail out
  // instead of crashing on wide integer comparisons.
  if (CmpConst->getValue().getActiveBits() > 64)
    return nullptr;

  Threshold = CmpConst->getZExtValue();
  return Cond->getOperand(0);
}

// Check if the recurrence variable `VarX` is in the right form to create
// the idiom. Returns the value coerced to a PHINode if so.
static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX,
Expand All @@ -1528,6 +1564,107 @@ static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX,
return nullptr;
}

/// Return true if the idiom is detected in the loop.
///
/// Additionally:
/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ)
/// or nullptr if there is no such.
/// 2) \p CntPhi is set to the corresponding phi node
/// or nullptr if there is no such.
/// 3) \p InitX is set to the value whose CTLZ could be used.
/// 4) \p DefX is set to the instruction calculating Loop exit condition.
/// 5) \p Threshold is set to the constant involved in the unsigned less-than
/// comparison.
///
/// The core idiom we are trying to detect is:
/// \code
/// if (x0 < 2)
/// goto loop-exit // the precondition of the loop
/// cnt0 = init-val
/// do {
/// x = phi (x0, x.next); //PhiX
/// cnt = phi (cnt0, cnt.next)
///
/// cnt.next = cnt + 1;
/// ...
/// x.next = x >> 1; // DefX
/// } while (x >= 4)
/// loop-exit:
/// \endcode
static bool detectShiftUntilLessThanIdiom(Loop *CurLoop, const DataLayout &DL,
                                          Intrinsic::ID &IntrinID,
                                          Value *&InitX, Instruction *&CntInst,
                                          PHINode *&CntPhi, Instruction *&DefX,
                                          uint64_t &Threshold) {
  // Clear the instruction out-parameters before matching.
  DefX = nullptr;
  CntInst = nullptr;
  CntPhi = nullptr;

  BasicBlock *EntryBB = *(CurLoop->block_begin());

  // Step 1: the terminator of the first loop block must be the
  // "x u< Threshold" branch whose false edge re-enters that block.
  Value *CmpLHS = matchShiftULTCondition(
      dyn_cast<BranchInst>(EntryBB->getTerminator()), EntryBB, Threshold);
  if (!CmpLHS)
    return false;
  DefX = dyn_cast<Instruction>(CmpLHS);

  // Step 2: the compared value must be a phi that is fed, along the edge from
  // the loop block itself, by an instruction whose first operand is that phi.
  if (!DefX)
    return false;
  auto *VarPhi = dyn_cast<PHINode>(DefX);
  if (!VarPhi)
    return false;

  const int SelfIdx = VarPhi->getBasicBlockIndex(EntryBB);
  if (SelfIdx == -1)
    return false;

  DefX = dyn_cast<Instruction>(VarPhi->getIncomingValue(SelfIdx));
  if (!DefX)
    return false;
  if (DefX->getNumOperands() == 0 || DefX->getOperand(0) != VarPhi)
    return false;

  // Step 3: that instruction must be "x.next = x >> 1" (logical shift by one).
  if (DefX->getOpcode() != Instruction::LShr)
    return false;

  IntrinID = Intrinsic::ctlz;
  auto *ShiftAmt = dyn_cast<ConstantInt>(DefX->getOperand(1));
  if (!(ShiftAmt && ShiftAmt->isOne()))
    return false;

  InitX = VarPhi->getIncomingValueForBlock(CurLoop->getLoopPreheader());

  // Step 4: find the counter update, i.e. the first non-phi instruction of the
  // form "cnt.next = cnt + 1" or "cnt.next = cnt + -1" whose first operand is
  // a recurrence phi.
  // TODO: We can skip the step. If loop trip count is known (CTLZ),
  //       then all uses of "cnt.next" could be optimized to the trip count
  //       plus "cnt0". Currently it is not optimized.
  //       This step could be used to detect POPCNT instruction:
  //       cnt.next = cnt + (x.next & 1)
  for (Instruction &Candidate : llvm::make_range(
           EntryBB->getFirstNonPHI()->getIterator(), EntryBB->end())) {
    if (Candidate.getOpcode() != Instruction::Add)
      continue;

    auto *Step = dyn_cast<ConstantInt>(Candidate.getOperand(1));
    const bool IsUnitStep = Step && (Step->isOne() || Step->isMinusOne());
    if (!IsUnitStep)
      continue;

    if (PHINode *Phi =
            getRecurrenceVar(Candidate.getOperand(0), &Candidate, EntryBB)) {
      CntInst = &Candidate;
      CntPhi = Phi;
      break;
    }
  }

  return CntInst != nullptr;
}

/// Return true iff the idiom is detected in the loop.
///
/// Additionally:
Expand Down Expand Up @@ -1756,27 +1893,35 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL,
return true;
}

/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop
/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new
/// trip count returns true; otherwise, returns false.
bool LoopIdiomRecognize::recognizeAndInsertFFS() {
// Give up if the loop has multiple blocks or multiple backedges.
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
return false;
// Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always
// profitable if we delete the loop.
bool LoopIdiomRecognize::isProfitableToInsertFFS(Intrinsic::ID IntrinID,
Value *InitX, bool ZeroCheck,
size_t CanonicalSize) {
const Value *Args[] = {InitX,
ConstantInt::getBool(InitX->getContext(), ZeroCheck)};

Intrinsic::ID IntrinID;
Value *InitX;
Instruction *DefX = nullptr;
PHINode *CntPhi = nullptr;
Instruction *CntInst = nullptr;
// Help decide if transformation is profitable. For ShiftUntilZero idiom,
// this is always 6.
size_t IdiomCanonicalSize = 6;
// @llvm.dbg doesn't count as they have no semantic effect.
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
uint32_t HeaderSize =
std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());

if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX,
CntInst, CntPhi, DefX))
IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args);
InstructionCost Cost = TTI->getIntrinsicInstrCost(
Attrs, TargetTransformInfo::TCK_SizeAndLatency);
if (HeaderSize != CanonicalSize && Cost > TargetTransformInfo::TCC_Basic)
return false;

return true;
}

/// Convert CTLZ / CTTZ idiom loop into countable loop.
/// If CTLZ / CTTZ inserted as a new trip count returns true; otherwise,
/// returns false.
bool LoopIdiomRecognize::insertFFSIfProfitable(Intrinsic::ID IntrinID,
Value *InitX, Instruction *DefX,
PHINode *CntPhi,
Instruction *CntInst) {
bool IsCntPhiUsedOutsideLoop = false;
for (User *U : CntPhi->users())
if (!CurLoop->contains(cast<Instruction>(U))) {
Expand Down Expand Up @@ -1818,35 +1963,107 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() {
ZeroCheck = true;
}

// Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always
// profitable if we delete the loop.

// the loop has only 6 instructions:
// FFS idiom loop has only 6 instructions:
// %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ]
// %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ]
// %shr = ashr %n.addr.0, 1
// %tobool = icmp eq %shr, 0
// %inc = add nsw %i.0, 1
// br i1 %tobool
size_t IdiomCanonicalSize = 6;
if (!isProfitableToInsertFFS(IntrinID, InitX, ZeroCheck, IdiomCanonicalSize))
return false;

const Value *Args[] = {InitX,
ConstantInt::getBool(InitX->getContext(), ZeroCheck)};
transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX,
DefX->getDebugLoc(), ZeroCheck,
IsCntPhiUsedOutsideLoop);
return true;
}

// @llvm.dbg doesn't count as they have no semantic effect.
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();
uint32_t HeaderSize =
std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end());
/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop
/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new
/// trip count returns true; otherwise, returns false.
bool LoopIdiomRecognize::recognizeAndInsertFFS() {
// Give up if the loop has multiple blocks or multiple backedges.
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
return false;

IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args);
InstructionCost Cost =
TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency);
if (HeaderSize != IdiomCanonicalSize &&
Cost > TargetTransformInfo::TCC_Basic)
Intrinsic::ID IntrinID;
Value *InitX;
Instruction *DefX = nullptr;
PHINode *CntPhi = nullptr;
Instruction *CntInst = nullptr;

if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX, CntInst, CntPhi,
DefX))
return false;

return insertFFSIfProfitable(IntrinID, InitX, DefX, CntPhi, CntInst);
}

bool LoopIdiomRecognize::recognizeShiftUntilLessThan() {
// Give up if the loop has multiple blocks or multiple backedges.
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
return false;

Intrinsic::ID IntrinID;
Value *InitX;
Instruction *DefX = nullptr;
PHINode *CntPhi = nullptr;
Instruction *CntInst = nullptr;

uint64_t LoopThreshold;
if (!detectShiftUntilLessThanIdiom(CurLoop, *DL, IntrinID, InitX, CntInst,
CntPhi, DefX, LoopThreshold))
return false;

if (LoopThreshold == 2) {
// Treat as regular FFS.
return insertFFSIfProfitable(IntrinID, InitX, DefX, CntPhi, CntInst);
}

// Look for Floor Log2 Idiom.
if (LoopThreshold != 4)
return false;

// Abort if CntPhi is used outside of the loop.
for (User *U : CntPhi->users())
if (!CurLoop->contains(cast<Instruction>(U)))
return false;

// It is safe to assume Preheader exists as it was checked in
// parent function RunOnLoop.
BasicBlock *PH = CurLoop->getLoopPreheader();
auto *PreCondBB = PH->getSinglePredecessor();
if (!PreCondBB)
return false;
auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
if (!PreCondBI)
return false;

uint64_t PreLoopThreshold;
if (matchShiftULTCondition(PreCondBI, PH, PreLoopThreshold) != InitX ||
PreLoopThreshold != 2)
return false;

bool ZeroCheck = true;

// the loop has only 6 instructions:
// %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ]
// %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ]
// %shr = ashr %n.addr.0, 1
// %tobool = icmp ult %n.addr.0, C
// %inc = add nsw %i.0, 1
// br i1 %tobool
size_t IdiomCanonicalSize = 6;
if (!isProfitableToInsertFFS(IntrinID, InitX, ZeroCheck, IdiomCanonicalSize))
return false;

// log2(x) = w − 1 − clz(x)
transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX,
DefX->getDebugLoc(), ZeroCheck,
IsCntPhiUsedOutsideLoop);
/*IsCntPhiUsedOutsideLoop=*/false,
/*InsertSub=*/true);
return true;
}

Expand Down Expand Up @@ -1961,7 +2178,7 @@ static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
void LoopIdiomRecognize::transformLoopToCountable(
Intrinsic::ID IntrinID, BasicBlock *Preheader, Instruction *CntInst,
PHINode *CntPhi, Value *InitX, Instruction *DefX, const DebugLoc &DL,
bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) {
bool ZeroCheck, bool IsCntPhiUsedOutsideLoop, bool InsertSub) {
BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());

// Step 1: Insert the CTLZ/CTTZ instruction at the end of the preheader block
Expand Down Expand Up @@ -1991,6 +2208,8 @@ void LoopIdiomRecognize::transformLoopToCountable(
Type *CountTy = Count->getType();
Count = Builder.CreateSub(
ConstantInt::get(CountTy, CountTy->getIntegerBitWidth()), Count);
if (InsertSub)
Count = Builder.CreateSub(Count, ConstantInt::get(CountTy, 1));
Value *NewCount = Count;
if (IsCntPhiUsedOutsideLoop)
Count = Builder.CreateAdd(Count, ConstantInt::get(CountTy, 1));
Expand Down
16 changes: 4 additions & 12 deletions llvm/lib/Transforms/Utils/SCCPSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1295,24 +1295,16 @@ void SCCPInstVisitor::visitCastInst(CastInst &I) {
return (void)markConstant(&I, C);
}

if (I.getDestTy()->isIntegerTy() && I.getSrcTy()->isIntOrIntVectorTy()) {
// Ignore bitcasts, as they may change the number of vector elements.
if (I.getDestTy()->isIntegerTy() && I.getSrcTy()->isIntOrIntVectorTy() &&
I.getOpcode() != Instruction::BitCast) {
auto &LV = getValueState(&I);
ConstantRange OpRange =
getConstantRange(OpSt, I.getSrcTy(), /*UndefAllowed=*/false);

Type *DestTy = I.getDestTy();
// Vectors where all elements have the same known constant range are treated
// as a single constant range in the lattice. When bitcasting such vectors,
// there is a mis-match between the width of the lattice value (single
// constant range) and the original operands (vector). Go to overdefined in
// that case.
if (I.getOpcode() == Instruction::BitCast &&
I.getOperand(0)->getType()->isVectorTy() &&
OpRange.getBitWidth() < DL.getTypeSizeInBits(DestTy))
return (void)markOverdefined(&I);

ConstantRange Res =
OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy));
OpRange.castOp(I.getOpcode(), DestTy->getScalarSizeInBits());
mergeInValue(LV, &I, ValueLatticeElement::getRange(Res));
} else
markOverdefined(&I);
Expand Down
225 changes: 185 additions & 40 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,12 @@ class BoUpSLP {
return VectorizableTree.front()->Scalars;
}

/// Returns the signedness flag recorded in MinBWs for the root node of the
/// vectorizable tree, i.e. the second member of the MinBWs entry keyed by
/// VectorizableTree.front().
/// NOTE(review): the lookup uses MinBWs.at(), so callers presumably must
/// ensure the root node has a MinBWs entry before calling - confirm.
bool isSignedMinBitwidthRootNode() const {
  const auto *RootEntry = VectorizableTree.front().get();
  const auto &RootMinBW = MinBWs.at(RootEntry);
  return RootMinBW.second;
}

/// Builds external uses of the vectorized scalars, i.e. the list of
/// vectorized scalars to be extracted, their lanes and their scalar users. \p
/// ExternallyUsedValues contains additional list of external uses to handle
Expand Down Expand Up @@ -2453,6 +2459,90 @@ class BoUpSLP {
DeletedInstructions.insert(I);
}

/// Remove instructions from the parent function and clear the operands of \p
/// DeadVals instructions, marking for deletion trivially dead operands.
///
/// Every value in \p DeadVals is cast to Instruction, recorded in
/// DeletedInstructions, stripped of its operand references and unlinked from
/// its parent with removeFromParent(). Operands that become trivially dead as
/// a result are processed through a worklist and handled the same way.
/// Instructions are unlinked and recorded here, not erased.
///
/// \tparam T element type of \p DeadVals; each element is passed to
/// cast<Instruction>.
/// \param DeadVals the instructions to remove.
template <typename T>
void removeInstructionsAndOperands(ArrayRef<T *> DeadVals) {
// Worklist of operands that are expected to become dead once their single
// user is removed.
SmallVector<WeakTrackingVH> DeadInsts;
// Phase 1: record all of DeadVals as deleted up front, so the
// DeletedInstructions.contains() checks below skip operands that are
// themselves part of DeadVals.
for (T *V : DeadVals) {
auto *I = cast<Instruction>(V);
DeletedInstructions.insert(I);
}
// Phase 2: salvage debug info, queue operands that will become dead, then
// drop all operand references of each dead instruction.
for (T *V : DeadVals) {
// NOTE(review): this null check is not mirrored by the unconditional
// cast<Instruction>(V) in the loops above and below - confirm whether
// DeadVals may actually contain null entries.
if (!V)
continue;
auto *I = cast<Instruction>(V);
salvageDebugInfo(*I);
// Gather the tree entry for I plus any additional entries registered for it
// in MultiNodeScalars.
SmallVector<const TreeEntry *> Entries;
if (const TreeEntry *Entry = getTreeEntry(I)) {
Entries.push_back(Entry);
auto It = MultiNodeScalars.find(I);
if (It != MultiNodeScalars.end())
Entries.append(It->second.begin(), It->second.end());
}
// Queue an operand only if it is an instruction that is not already marked
// deleted, has exactly one user, would be trivially dead, and is not the
// VectorizedValue of any of the entries gathered above.
for (Use &U : I->operands()) {
if (auto *OpI = dyn_cast_if_present<Instruction>(U.get());
OpI && !DeletedInstructions.contains(OpI) && OpI->hasOneUser() &&
wouldInstructionBeTriviallyDead(OpI, TLI) &&
(Entries.empty() || none_of(Entries, [&](const TreeEntry *Entry) {
return Entry->VectorizedValue == OpI;
})))
DeadInsts.push_back(OpI);
}
I->dropAllReferences();
}
// Phase 3: unlink the dead instructions from their parent. Instructions that
// no longer have a parent are skipped.
for (T *V : DeadVals) {
auto *I = cast<Instruction>(V);
if (!I->getParent())
continue;
// Any remaining users must themselves be marked as deleted.
assert((I->use_empty() || all_of(I->uses(),
[&](Use &U) {
return isDeleted(
cast<Instruction>(U.getUser()));
})) &&
"trying to erase instruction with users.");
I->removeFromParent();
SE->forgetValue(I);
}
// Process the dead instruction list until empty.
while (!DeadInsts.empty()) {
Value *V = DeadInsts.pop_back_val();
// Skip worklist entries that are null or already unlinked from a parent.
Instruction *VI = cast_or_null<Instruction>(V);
if (!VI || !VI->getParent())
continue;
assert(isInstructionTriviallyDead(VI, TLI) &&
"Live instruction found in dead worklist!");
assert(VI->use_empty() && "Instructions with uses are not dead.");

// Don't lose the debug info while deleting the instructions.
salvageDebugInfo(*VI);

// Null out all of the instruction's operands to see if any operand
// becomes dead as we go.
for (Use &OpU : VI->operands()) {
Value *OpV = OpU.get();
if (!OpV)
continue;
OpU.set(nullptr);

if (!OpV->use_empty())
continue;

// If the operand is an instruction that became dead as we nulled out
// the operand, and if it is 'trivially' dead, delete it in a future
// loop iteration.
if (auto *OpI = dyn_cast<Instruction>(OpV))
if (!DeletedInstructions.contains(OpI) &&
isInstructionTriviallyDead(OpI, TLI))
DeadInsts.push_back(OpI);
}

// Unlink the instruction and record it as deleted, mirroring what was done
// for DeadVals above.
VI->removeFromParent();
DeletedInstructions.insert(VI);
SE->forgetValue(VI);
}
}

/// Checks if the instruction was already analyzed for being possible
/// reduction root.
bool isAnalyzedReductionRoot(Instruction *I) const {
Expand Down Expand Up @@ -3987,6 +4077,10 @@ template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits {
BoUpSLP::~BoUpSLP() {
SmallVector<WeakTrackingVH> DeadInsts;
for (auto *I : DeletedInstructions) {
if (!I->getParent()) {
I->insertBefore(F->getEntryBlock().getTerminator());
continue;
}
for (Use &U : I->operands()) {
auto *Op = dyn_cast<Instruction>(U.get());
if (Op && !DeletedInstructions.count(Op) && Op->hasOneUser() &&
Expand Down Expand Up @@ -14075,11 +14169,8 @@ Value *BoUpSLP::vectorizeTree(
}
#endif
LLVM_DEBUG(dbgs() << "SLP: \tErasing scalar:" << *Scalar << ".\n");
eraseInstruction(cast<Instruction>(Scalar));
// Retain to-be-deleted instructions for some debug-info
// bookkeeping. NOTE: eraseInstruction only marks the instruction for
// deletion - instructions are not deleted until later.
RemovedInsts.push_back(cast<Instruction>(Scalar));
auto *I = cast<Instruction>(Scalar);
RemovedInsts.push_back(I);
}
}

Expand All @@ -14088,6 +14179,22 @@ Value *BoUpSLP::vectorizeTree(
if (auto *V = dyn_cast<Instruction>(VectorizableTree[0]->VectorizedValue))
V->mergeDIAssignID(RemovedInsts);

// Clear up reduction references, if any.
if (UserIgnoreList) {
for (Instruction *I : RemovedInsts) {
if (getTreeEntry(I)->Idx != 0)
continue;
I->replaceUsesWithIf(PoisonValue::get(I->getType()), [&](Use &U) {
return UserIgnoreList->contains(U.getUser());
});
}
}
// Retain to-be-deleted instructions for some debug-info bookkeeping and alias
// cache correctness.
// NOTE: removeInstructionsAndOperands only marks the instructions for
// deletion - instructions are not deleted until later.
removeInstructionsAndOperands(ArrayRef(RemovedInsts));

Builder.ClearInsertionPoint();
InstrElementSize.clear();

Expand Down Expand Up @@ -16137,15 +16244,18 @@ bool SLPVectorizerPass::vectorizeStores(
Res.first = Idx;
Res.second.emplace(Idx, 0);
};
StoreInst *PrevStore = Stores.front();
Type *PrevValTy = nullptr;
for (auto [I, SI] : enumerate(Stores)) {
if (R.isDeleted(SI))
continue;
if (!PrevValTy)
PrevValTy = SI->getValueOperand()->getType();
// Check that we do not try to vectorize stores of different types.
if (PrevStore->getValueOperand()->getType() !=
SI->getValueOperand()->getType()) {
if (PrevValTy != SI->getValueOperand()->getType()) {
for (auto &Set : SortedStores)
TryToVectorize(Set.second);
SortedStores.clear();
PrevStore = SI;
PrevValTy = SI->getValueOperand()->getType();
}
FillStoresSet(I, SI);
}
Expand Down Expand Up @@ -17019,9 +17129,12 @@ class HorizontalReduction {
Value *VectorizedTree = nullptr;
bool CheckForReusedReductionOps = false;
// Try to vectorize elements based on their type.
SmallVector<InstructionsState> States;
for (ArrayRef<Value *> RV : ReducedVals)
States.push_back(getSameOpcode(RV, TLI));
for (unsigned I = 0, E = ReducedVals.size(); I < E; ++I) {
ArrayRef<Value *> OrigReducedVals = ReducedVals[I];
InstructionsState S = getSameOpcode(OrigReducedVals, TLI);
InstructionsState S = States[I];
SmallVector<Value *> Candidates;
Candidates.reserve(2 * OrigReducedVals.size());
DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size());
Expand Down Expand Up @@ -17346,14 +17459,11 @@ class HorizontalReduction {
Value *ReducedSubTree =
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
if (ReducedSubTree->getType() != VL.front()->getType()) {
ReducedSubTree = Builder.CreateIntCast(
ReducedSubTree, VL.front()->getType(), any_of(VL, [&](Value *R) {
KnownBits Known = computeKnownBits(
R, cast<Instruction>(ReductionOps.front().front())
->getModule()
->getDataLayout());
return !Known.isNonNegative();
}));
assert(ReducedSubTree->getType() != VL.front()->getType() &&
"Expected different reduction type.");
ReducedSubTree =
Builder.CreateIntCast(ReducedSubTree, VL.front()->getType(),
V.isSignedMinBitwidthRootNode());
}

// Improved analysis for add/fadd/xor reductions with same scale factor
Expand Down Expand Up @@ -17518,16 +17628,8 @@ class HorizontalReduction {
Value *P = PoisonValue::get(Ignore->getType());
Ignore->replaceAllUsesWith(P);
}
auto *I = dyn_cast<Instruction>(Ignore);
// Clear the operands with non single use. Allows better
// vectorization.
for (unsigned Idx : seq<unsigned>(I->getNumOperands())) {
Value *Op = I->getOperand(Idx);
if (!Op->hasOneUse())
I->setOperand(Idx, PoisonValue::get(Op->getType()));
}
V.eraseInstruction(I);
}
V.removeInstructionsAndOperands(RdxOps);
}
} else if (!CheckForReusedReductionOps) {
for (ReductionOpsType &RdxOps : ReductionOps)
Expand Down Expand Up @@ -18075,6 +18177,8 @@ bool SLPVectorizerPass::vectorizeHorReduction(
Stack.emplace(I, Level);
continue;
}
if (R.isDeleted(Inst))
continue;
} else {
// We could not vectorize `Inst` so try to use it as a future seed.
if (!TryAppendToPostponedInsts(Inst)) {
Expand Down Expand Up @@ -18160,15 +18264,28 @@ static bool tryToVectorizeSequence(

// Try to vectorize elements based on their type.
SmallVector<T *> Candidates;
for (auto *IncIt = Incoming.begin(), *E = Incoming.end(); IncIt != E;) {
SmallVector<T *> VL;
for (auto *IncIt = Incoming.begin(), *E = Incoming.end(); IncIt != E;
VL.clear()) {
// Look for the next elements with the same type, parent and operand
// kinds.
auto *I = dyn_cast<Instruction>(*IncIt);
if (!I || R.isDeleted(I)) {
++IncIt;
continue;
}
auto *SameTypeIt = IncIt;
while (SameTypeIt != E && AreCompatible(*SameTypeIt, *IncIt))
while (SameTypeIt != E && (!isa<Instruction>(*SameTypeIt) ||
R.isDeleted(cast<Instruction>(*SameTypeIt)) ||
AreCompatible(*SameTypeIt, *IncIt))) {
auto *I = dyn_cast<Instruction>(*SameTypeIt);
++SameTypeIt;
if (I && !R.isDeleted(I))
VL.push_back(cast<T>(I));
}

// Try to vectorize them.
unsigned NumElts = (SameTypeIt - IncIt);
unsigned NumElts = VL.size();
LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at nodes ("
<< NumElts << ")\n");
// The vectorization is a 3-state attempt:
Expand All @@ -18180,10 +18297,15 @@ static bool tryToVectorizeSequence(
// 3. Final attempt to try to vectorize all instructions with the
// same/alternate ops only, this may result in some extra final
// vectorization.
if (NumElts > 1 &&
TryToVectorizeHelper(ArrayRef(IncIt, NumElts), MaxVFOnly)) {
if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(VL), MaxVFOnly)) {
// Success start over because instructions might have been changed.
Changed = true;
VL.swap(Candidates);
Candidates.clear();
for (T *V : VL) {
if (auto *I = dyn_cast<Instruction>(V); I && !R.isDeleted(I))
Candidates.push_back(V);
}
} else {
/// \Returns the minimum number of elements that we will attempt to
/// vectorize.
Expand All @@ -18194,7 +18316,10 @@ static bool tryToVectorizeSequence(
if (NumElts < GetMinNumElements(*IncIt) &&
(Candidates.empty() ||
Candidates.front()->getType() == (*IncIt)->getType())) {
Candidates.append(IncIt, std::next(IncIt, NumElts));
for (T *V : VL) {
if (auto *I = dyn_cast<Instruction>(V); I && !R.isDeleted(I))
Candidates.push_back(V);
}
}
}
// Final attempt to vectorize instructions with the same types.
Expand All @@ -18205,13 +18330,26 @@ static bool tryToVectorizeSequence(
Changed = true;
} else if (MaxVFOnly) {
// Try to vectorize using small vectors.
for (auto *It = Candidates.begin(), *End = Candidates.end();
It != End;) {
SmallVector<T *> VL;
for (auto *It = Candidates.begin(), *End = Candidates.end(); It != End;
VL.clear()) {
auto *I = dyn_cast<Instruction>(*It);
if (!I || R.isDeleted(I)) {
++It;
continue;
}
auto *SameTypeIt = It;
while (SameTypeIt != End && AreCompatible(*SameTypeIt, *It))
while (SameTypeIt != End &&
(!isa<Instruction>(*SameTypeIt) ||
R.isDeleted(cast<Instruction>(*SameTypeIt)) ||
AreCompatible(*SameTypeIt, *It))) {
auto *I = dyn_cast<Instruction>(*SameTypeIt);
++SameTypeIt;
unsigned NumElts = (SameTypeIt - It);
if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(It, NumElts),
if (I && !R.isDeleted(I))
VL.push_back(cast<T>(I));
}
unsigned NumElts = VL.size();
if (NumElts > 1 && TryToVectorizeHelper(ArrayRef(VL),
/*MaxVFOnly=*/false))
Changed = true;
It = SameTypeIt;
Expand Down Expand Up @@ -18485,7 +18623,7 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
}
return false;
};
auto AreCompatiblePHIs = [&PHIToOpcodes, this](Value *V1, Value *V2) {
auto AreCompatiblePHIs = [&PHIToOpcodes, this, &R](Value *V1, Value *V2) {
if (V1 == V2)
return true;
if (V1->getType() != V2->getType())
Expand All @@ -18500,6 +18638,8 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
continue;
if (auto *I1 = dyn_cast<Instruction>(Opcodes1[I]))
if (auto *I2 = dyn_cast<Instruction>(Opcodes2[I])) {
if (R.isDeleted(I1) || R.isDeleted(I2))
return false;
if (I1->getParent() != I2->getParent())
return false;
InstructionsState S = getSameOpcode({I1, I2}, *TLI);
Expand Down Expand Up @@ -18720,8 +18860,13 @@ bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) {
// are trying to vectorize the index computations, so the maximum number of
// elements is based on the size of the index expression, rather than the
// size of the GEP itself (the target's pointer size).
auto *It = find_if(Entry.second, [&](GetElementPtrInst *GEP) {
return !R.isDeleted(GEP);
});
if (It == Entry.second.end())
continue;
unsigned MaxVecRegSize = R.getMaxVecRegSize();
unsigned EltSize = R.getVectorElementSize(*Entry.second[0]->idx_begin());
unsigned EltSize = R.getVectorElementSize(*(*It)->idx_begin());
if (MaxVecRegSize < EltSize)
continue;

Expand Down
13 changes: 9 additions & 4 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1491,8 +1491,7 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
Value *V0, *V1;
ArrayRef<int> OldMask;
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
m_Mask(OldMask))))
if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
return false;

auto *C0 = dyn_cast<CastInst>(V0);
Expand Down Expand Up @@ -1551,11 +1550,13 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
// Try to replace a castop with a shuffle if the shuffle is not costly.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

InstructionCost OldCost =
InstructionCost CostC0 =
TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind) +
TTI::CastContextHint::None, CostKind);
InstructionCost CostC1 =
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind);
InstructionCost OldCost = CostC0 + CostC1;
OldCost +=
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
OldMask, CostKind, 0, nullptr, std::nullopt, &I);
Expand All @@ -1564,6 +1565,10 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
TTI::CastContextHint::None, CostKind);
if (!C0->hasOneUse())
NewCost += CostC0;
if (!C1->hasOneUse())
NewCost += CostC1;

LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
;; Test that dbg.assigns linked to the scalar stores to quad get linked to
;; the vector store that replaces them.

; CHECK: #dbg_assign(float undef, ![[VAR:[0-9]+]], !DIExpression(DW_OP_LLVM_fragment, 0, 32), ![[ID:[0-9]+]], ptr %arrayidx, !DIExpression(),
; CHECK: #dbg_assign(float undef, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 32, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 4),
; CHECK: #dbg_assign(float undef, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 64, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 8),
; CHECK: #dbg_assign(float poison, ![[VAR:[0-9]+]], !DIExpression(DW_OP_LLVM_fragment, 0, 32), ![[ID:[0-9]+]], ptr %arrayidx, !DIExpression(),
; CHECK: #dbg_assign(float poison, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 32, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 4),
; CHECK: #dbg_assign(float poison, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 64, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 8),
; CHECK: store <4 x float> {{.*}} !DIAssignID ![[ID]]
; CHECK: #dbg_assign(float undef, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 96, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 12),
; CHECK: #dbg_assign(float poison, ![[VAR]], !DIExpression(DW_OP_LLVM_fragment, 96, 32), ![[ID]], ptr %quad, !DIExpression(DW_OP_plus_uconst, 12),

target triple = "x86_64-unknown-unknown"

Expand Down
Loading