164 changes: 164 additions & 0 deletions libc/src/math/generic/log10f16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//===-- Half-precision log10(x) function ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/log10f16.h"
#include "expxf16.h"
#include "hdr/errno_macros.h"
#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/PolyEval.h"
#include "src/__support/FPUtil/cast.h"
#include "src/__support/FPUtil/except_value_utils.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h"
#include "src/__support/macros/properties/cpu_features.h"

namespace LIBC_NAMESPACE_DECL {

// Number of entries compiled into the LOG10F16_EXCEPTS table. The table has
// extra entries guarded by #ifndef LIBC_TARGET_CPU_HAS_FMA, so the count
// differs between FMA and non-FMA targets and must be kept in sync with it.
#ifdef LIBC_TARGET_CPU_HAS_FMA
static constexpr size_t N_LOG10F16_EXCEPTS = 11;
#else
static constexpr size_t N_LOG10F16_EXCEPTS = 17;
#endif

// Inputs whose log10f16 result is looked up directly instead of being
// computed by the generic path in log10f16(). Each entry is keyed by the raw
// binary16 bit pattern of x. Entries inside #ifndef LIBC_TARGET_CPU_HAS_FMA
// are compiled only for targets without FMA; the number of entries compiled
// in must equal N_LOG10F16_EXCEPTS (11 with FMA, 17 without).
// NOTE(review): the RU/RD/RN offsets are presumably added to the RZ output to
// obtain the result for the corresponding rounding mode - confirm against
// fputil::ExceptValues.
static constexpr fputil::ExceptValues<float16, N_LOG10F16_EXCEPTS>
LOG10F16_EXCEPTS = {{
// (input, RZ output, RU offset, RD offset, RN offset)
// x = 0x1.e3cp-3, log10f16(x) = -0x1.40cp-1 (RZ)
{0x338fU, 0xb903U, 0U, 1U, 0U},
// x = 0x1.fep-3, log10f16(x) = -0x1.35p-1 (RZ)
{0x33f8U, 0xb8d4U, 0U, 1U, 1U},
#ifndef LIBC_TARGET_CPU_HAS_FMA
// x = 0x1.394p-1, log10f16(x) = -0x1.b4cp-3 (RZ)
{0x38e5U, 0xb2d3U, 0U, 1U, 1U},
#endif
// x = 0x1.ea8p-1, log10f16(x) = -0x1.31p-6 (RZ)
{0x3baaU, 0xa4c4U, 0U, 1U, 1U},
// x = 0x1.ebp-1, log10f16(x) = -0x1.29cp-6 (RZ)
{0x3bacU, 0xa4a7U, 0U, 1U, 1U},
// x = 0x1.f3p-1, log10f16(x) = -0x1.6dcp-7 (RZ)
{0x3bccU, 0xa1b7U, 0U, 1U, 1U},
#ifndef LIBC_TARGET_CPU_HAS_FMA
// x = 0x1.f38p-1, log10f16(x) = -0x1.5f8p-7 (RZ)
{0x3bceU, 0xa17eU, 0U, 1U, 1U},
// x = 0x1.fd8p-1, log10f16(x) = -0x1.168p-9 (RZ)
{0x3bf6U, 0x985aU, 0U, 1U, 1U},
// x = 0x1.ff8p-1, log10f16(x) = -0x1.bccp-12 (RZ)
{0x3bfeU, 0x8ef3U, 0U, 1U, 1U},
// x = 0x1.374p+0, log10f16(x) = 0x1.5b8p-4 (RZ)
{0x3cddU, 0x2d6eU, 1U, 0U, 1U},
// x = 0x1.3ecp+1, log10f16(x) = 0x1.958p-2 (RZ)
{0x40fbU, 0x3656U, 1U, 0U, 1U},
#endif
// Exact powers of ten (10, 100, 1000, 10000 below) have exact results, so
// all of their rounding offsets are zero.
// x = 0x1.4p+3, log10f16(x) = 0x1p+0 (RZ)
{0x4900U, 0x3c00U, 0U, 0U, 0U},
// x = 0x1.9p+6, log10f16(x) = 0x1p+1 (RZ)
{0x5640U, 0x4000U, 0U, 0U, 0U},
// x = 0x1.f84p+6, log10f16(x) = 0x1.0ccp+1 (RZ)
{0x57e1U, 0x4033U, 1U, 0U, 0U},
// x = 0x1.f4p+9, log10f16(x) = 0x1.8p+1 (RZ)
{0x63d0U, 0x4200U, 0U, 0U, 0U},
// x = 0x1.388p+13, log10f16(x) = 0x1p+2 (RZ)
{0x70e2U, 0x4400U, 0U, 0U, 0U},
// x = 0x1.674p+13, log10f16(x) = 0x1.03cp+2 (RZ)
{0x719dU, 0x440fU, 1U, 0U, 0U},
}};

// Half-precision base-10 logarithm.
//
// Special inputs handled up front:
//   NaN        -> NaN (FE_INVALID is raised for a signaling NaN)
//   +0 / -0    -> -inf, raises FE_DIVBYZERO
//   1          -> +0
//   x < 0      -> NaN, sets errno to EDOM and raises FE_INVALID
//   +inf       -> +inf
// Inputs listed in LOG10F16_EXCEPTS are answered from that table. Everything
// else goes through a table-driven range reduction evaluated in float.
LLVM_LIBC_FUNCTION(float16, log10f16, (float16 x)) {
  using FPBits = fputil::FPBits<float16>;

  // Raw binary16 encodings used to classify the input.
  constexpr uint16_t ONE_BITS = 0x3c00U;      // 1.0
  constexpr uint16_t POS_INF_BITS = 0x7c00U;  // +inf
  constexpr uint16_t NEG_ZERO_BITS = 0x8000U; // -0.0
  constexpr uint16_t ABS_MASK = 0x7fffU;      // everything but the sign bit

  FPBits x_bits(x);
  uint16_t x_u = x_bits.uintval();

  // True for +0, 1, +inf, and every encoding with the sign bit set or an
  // all-ones exponent (negative values, -0, NaNs).
  const bool needs_special_handling =
      x_u == 0U || x_u == ONE_BITS || x_u >= POS_INF_BITS;

  if (LIBC_UNLIKELY(needs_special_handling)) {
    if (x_bits.is_nan()) {
      // A quiet NaN is passed through untouched.
      if (!x_bits.is_signaling_nan())
        return x;

      // A signaling NaN is quieted and reported through FE_INVALID.
      fputil::raise_except_if_required(FE_INVALID);
      return FPBits::quiet_nan().get_val();
    }

    // log10(+/-0) = -inf
    if ((x_u & ABS_MASK) == 0U) {
      fputil::raise_except_if_required(FE_DIVBYZERO);
      return FPBits::inf(Sign::NEG).get_val();
    }

    // log10(1) = +0
    if (x_u == ONE_BITS)
      return FPBits::zero().get_val();

    // Any remaining encoding above -0 is a negative number (or -inf).
    if (x_u > NEG_ZERO_BITS) {
      fputil::set_errno_if_required(EDOM);
      fputil::raise_except_if_required(FE_INVALID);
      return FPBits::quiet_nan().get_val();
    }

    // Only +inf is left: log10(+inf) = +inf
    return FPBits::inf().get_val();
  }

  // Hard-coded results for inputs present in the exception table.
  if (auto exceptional = LOG10F16_EXCEPTS.lookup(x_u);
      LIBC_UNLIKELY(exceptional.has_value()))
    return exceptional.value();

  // Range reduction:
  //   x = 2^m * 1.mant
  //   log10(x) = m * log10(2) + log10(1.mant)
  // Split 1.mant into f (hidden bit plus the top 5 mantissa bits) and the
  // remainder d = 1.mant - f (the low 5 mantissa bits), so that
  //   log10(1.mant) = log10(f) + log10(1 + d/f)
  // where d/f is small enough for a short polynomial. log10(f) and 1/f are
  // read from the LOG10F_F and ONE_OVER_F_F tables.

  int m = -FPBits::EXP_BIAS;

  // A zero exponent field means x is subnormal: scale it into the normal
  // range and compensate in the exponent.
  if ((x_u & FPBits::EXP_MASK) == 0U) {
    // fputil::cast cannot take an integer directly, so go through float.
    constexpr float SUBNORMAL_SCALE = 1U << FPBits::FRACTION_LEN;
    x_bits = FPBits(x_bits.get_val() * fputil::cast<float16>(SUBNORMAL_SCALE));
    x_u = x_bits.uintval();
    m -= FPBits::FRACTION_LEN;
  }

  // Add the biased exponent field to obtain the unbiased exponent.
  m += x_u >> FPBits::FRACTION_LEN;

  // Table index: the leading 10 - 5 = 5 bits of the mantissa.
  uint16_t mant = x_bits.get_mantissa();
  int f = mant >> 5;

  // Force the exponent to the bias so that x_bits now encodes 1.mant.
  x_bits.set_biased_exponent(FPBits::EXP_BIAS);
  float mant_f = x_bits.get_val();

  // v = 1.mant * (1/f) - 1 = d/f
  float v = fputil::multiply_add(mant_f, ONE_OVER_F_F[f], -1.0f);

  // Degree-3 minimax polynomial for log10(1 + v), generated with Sollya:
  // > display = hexadecimal;
  // > P = fpminimax(log10(1 + x)/x, 2, [|SG...|], [-2^-5, 2^-5]);
  // > x * P;
  float log10p1_d_over_f =
      v * fputil::polyeval(v, 0x1.bcb7bp-2f, -0x1.bce168p-3f, 0x1.28acb8p-3f);

  // log10(1.mant) = log10(f) + log10(1 + d/f)
  float log10_1_mant = LOG10F_F[f] + log10p1_d_over_f;

  // log10(x) = m * log10(2) + log10(1.mant), rounded to half precision.
  float scaled_exponent = static_cast<float>(m);
  return fputil::cast<float16>(
      fputil::multiply_add(scaled_exponent, LOG10F_2, log10_1_mant));
}

} // namespace LIBC_NAMESPACE_DECL
149 changes: 149 additions & 0 deletions libc/src/math/generic/log2f16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//===-- Half-precision log2(x) function -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/log2f16.h"
#include "expxf16.h"
#include "hdr/errno_macros.h"
#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/FEnvImpl.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/PolyEval.h"
#include "src/__support/FPUtil/cast.h"
#include "src/__support/FPUtil/except_value_utils.h"
#include "src/__support/FPUtil/multiply_add.h"
#include "src/__support/common.h"
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h"
#include "src/__support/macros/properties/cpu_features.h"

namespace LIBC_NAMESPACE_DECL {

// Number of entries compiled into the LOG2F16_EXCEPTS table. The table has
// extra entries guarded by #ifndef LIBC_TARGET_CPU_HAS_FMA, so the count
// differs between FMA and non-FMA targets and must be kept in sync with it.
#ifdef LIBC_TARGET_CPU_HAS_FMA
static constexpr size_t N_LOG2F16_EXCEPTS = 2;
#else
static constexpr size_t N_LOG2F16_EXCEPTS = 9;
#endif

// Inputs whose log2f16 result is looked up directly instead of being computed
// by the generic path in log2f16(). Each entry is keyed by the raw binary16
// bit pattern of x. Entries inside #ifndef LIBC_TARGET_CPU_HAS_FMA are
// compiled only for targets without FMA; the number of entries compiled in
// must equal N_LOG2F16_EXCEPTS (2 with FMA, 9 without).
// NOTE(review): the RU/RD/RN offsets are presumably added to the RZ output to
// obtain the result for the corresponding rounding mode - confirm against
// fputil::ExceptValues.
static constexpr fputil::ExceptValues<float16, N_LOG2F16_EXCEPTS>
LOG2F16_EXCEPTS = {{
// (input, RZ output, RU offset, RD offset, RN offset)
#ifndef LIBC_TARGET_CPU_HAS_FMA
// x = 0x1.224p-1, log2f16(x) = -0x1.a34p-1 (RZ)
{0x3889U, 0xba8dU, 0U, 1U, 0U},
// x = 0x1.e34p-1, log2f16(x) = -0x1.558p-4 (RZ)
{0x3b8dU, 0xad56U, 0U, 1U, 0U},
#endif
// x = 0x1.e8cp-1, log2f16(x) = -0x1.128p-4 (RZ)
{0x3ba3U, 0xac4aU, 0U, 1U, 0U},
#ifndef LIBC_TARGET_CPU_HAS_FMA
// x = 0x1.f98p-1, log2f16(x) = -0x1.2ep-6 (RZ)
{0x3be6U, 0xa4b8U, 0U, 1U, 0U},
// x = 0x1.facp-1, log2f16(x) = -0x1.e7p-7 (RZ)
{0x3bebU, 0xa39cU, 0U, 1U, 1U},
#endif
// x = 0x1.fb4p-1, log2f16(x) = -0x1.b88p-7 (RZ)
{0x3bedU, 0xa2e2U, 0U, 1U, 1U},
#ifndef LIBC_TARGET_CPU_HAS_FMA
// x = 0x1.fecp-1, log2f16(x) = -0x1.cep-9 (RZ)
{0x3bfbU, 0x9b38U, 0U, 1U, 1U},
// x = 0x1.ffcp-1, log2f16(x) = -0x1.714p-11 (RZ)
{0x3bffU, 0x91c5U, 0U, 1U, 1U},
// x = 0x1.224p+0, log2f16(x) = 0x1.72cp-3 (RZ)
{0x3c89U, 0x31cbU, 1U, 0U, 1U},
#endif
}};

// Half-precision base-2 logarithm.
//
// Special inputs handled up front:
//   NaN        -> NaN (FE_INVALID is raised for a signaling NaN)
//   +0 / -0    -> -inf, raises FE_DIVBYZERO
//   1          -> +0
//   x < 0      -> NaN, sets errno to EDOM and raises FE_INVALID
//   +inf       -> +inf
// Inputs listed in LOG2F16_EXCEPTS are answered from that table. Everything
// else goes through a table-driven range reduction evaluated in float.
LLVM_LIBC_FUNCTION(float16, log2f16, (float16 x)) {
  using FPBits = fputil::FPBits<float16>;

  // Raw binary16 encodings used to classify the input.
  constexpr uint16_t ONE_BITS = 0x3c00U;      // 1.0
  constexpr uint16_t POS_INF_BITS = 0x7c00U;  // +inf
  constexpr uint16_t NEG_ZERO_BITS = 0x8000U; // -0.0
  constexpr uint16_t ABS_MASK = 0x7fffU;      // everything but the sign bit

  FPBits x_bits(x);
  uint16_t x_u = x_bits.uintval();

  // True for +0, 1, +inf, and every encoding with the sign bit set or an
  // all-ones exponent (negative values, -0, NaNs).
  const bool needs_special_handling =
      x_u == 0U || x_u == ONE_BITS || x_u >= POS_INF_BITS;

  if (LIBC_UNLIKELY(needs_special_handling)) {
    if (x_bits.is_nan()) {
      // A quiet NaN is passed through untouched.
      if (!x_bits.is_signaling_nan())
        return x;

      // A signaling NaN is quieted and reported through FE_INVALID.
      fputil::raise_except_if_required(FE_INVALID);
      return FPBits::quiet_nan().get_val();
    }

    // log2(+/-0) = -inf
    if ((x_u & ABS_MASK) == 0U) {
      fputil::raise_except_if_required(FE_DIVBYZERO);
      return FPBits::inf(Sign::NEG).get_val();
    }

    // log2(1) = +0
    if (x_u == ONE_BITS)
      return FPBits::zero().get_val();

    // Any remaining encoding above -0 is a negative number (or -inf).
    if (x_u > NEG_ZERO_BITS) {
      fputil::set_errno_if_required(EDOM);
      fputil::raise_except_if_required(FE_INVALID);
      return FPBits::quiet_nan().get_val();
    }

    // Only +inf is left: log2(+inf) = +inf
    return FPBits::inf().get_val();
  }

  // Hard-coded results for inputs present in the exception table.
  if (auto exceptional = LOG2F16_EXCEPTS.lookup(x_u);
      LIBC_UNLIKELY(exceptional.has_value()))
    return exceptional.value();

  // Range reduction:
  //   x = 2^m * 1.mant
  //   log2(x) = m + log2(1.mant)
  // Split 1.mant into f (hidden bit plus the top 5 mantissa bits) and the
  // remainder d = 1.mant - f (the low 5 mantissa bits), so that
  //   log2(1.mant) = log2(f) + log2(1 + d/f)
  // where d/f is small enough for a short polynomial. log2(f) and 1/f are
  // read from the LOG2F_F and ONE_OVER_F_F tables.

  int m = -FPBits::EXP_BIAS;

  // A zero exponent field means x is subnormal: scale it into the normal
  // range and compensate in the exponent.
  if ((x_u & FPBits::EXP_MASK) == 0U) {
    // fputil::cast cannot take an integer directly, so go through float.
    constexpr float SUBNORMAL_SCALE = 1U << FPBits::FRACTION_LEN;
    x_bits = FPBits(x_bits.get_val() * fputil::cast<float16>(SUBNORMAL_SCALE));
    x_u = x_bits.uintval();
    m -= FPBits::FRACTION_LEN;
  }

  // Add the biased exponent field to obtain the unbiased exponent.
  m += x_u >> FPBits::FRACTION_LEN;

  // Table index: the leading 10 - 5 = 5 bits of the mantissa.
  uint16_t mant = x_bits.get_mantissa();
  int f = mant >> 5;

  // Force the exponent to the bias so that x_bits now encodes 1.mant.
  x_bits.set_biased_exponent(FPBits::EXP_BIAS);
  float mant_f = x_bits.get_val();

  // v = 1.mant * (1/f) - 1 = d/f
  float v = fputil::multiply_add(mant_f, ONE_OVER_F_F[f], -1.0f);

  // Degree-3 minimax polynomial for log2(1 + v), generated with Sollya:
  // > display = hexadecimal;
  // > P = fpminimax(log2(1 + x)/x, 2, [|SG...|], [-2^-5, 2^-5]);
  // > x * P;
  float log2p1_d_over_f =
      v * fputil::polyeval(v, 0x1.715476p+0f, -0x1.71771ap-1f, 0x1.ecb38ep-2f);

  // log2(1.mant) = log2(f) + log2(1 + d/f)
  float log2_1_mant = LOG2F_F[f] + log2p1_d_over_f;

  // log2(x) = m + log2(1.mant), rounded to half precision.
  return fputil::cast<float16>(static_cast<float>(m) + log2_1_mant);
}

} // namespace LIBC_NAMESPACE_DECL
2 changes: 1 addition & 1 deletion libc/src/math/generic/logf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ LLVM_LIBC_FUNCTION(float16, logf16, (float16 x)) {
// log(1.mant) = log(f) + log(1.mant / f)
// = log(f) + log(1 + d/f)
// since d/f is sufficiently small.
// We store log(f) and 1/f in the lookup tables LOGF_F and ONE_OVER_F
// We store log(f) and 1/f in the lookup tables LOGF_F and ONE_OVER_F_F
// respectively.

int m = -FPBits::EXP_BIAS;
Expand Down
21 changes: 21 additions & 0 deletions libc/src/math/log10f16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===-- Implementation header for log10f16 ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_LOG10F16_H
#define LLVM_LIBC_SRC_MATH_LOG10F16_H

#include "src/__support/macros/config.h"
#include "src/__support/macros/properties/types.h"

namespace LIBC_NAMESPACE_DECL {

// Returns the base-10 logarithm of the half-precision value x.
float16 log10f16(float16 x);

} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC_MATH_LOG10F16_H
21 changes: 21 additions & 0 deletions libc/src/math/log2f16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===-- Implementation header for log2f16 -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC_MATH_LOG2F16_H
#define LLVM_LIBC_SRC_MATH_LOG2F16_H

#include "src/__support/macros/config.h"
#include "src/__support/macros/properties/types.h"

namespace LIBC_NAMESPACE_DECL {

// Returns the base-2 logarithm of the half-precision value x.
float16 log2f16(float16 x);

} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC_MATH_LOG2F16_H
20 changes: 16 additions & 4 deletions libc/src/stdio/scanf_core/int_converter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,24 @@ int convert_int(Reader *reader, const FormatSection &to_conv) {

if (to_lower(cur_char) == 'x') {
// This is a valid hex prefix.

is_number = false;
// A valid hex prefix is not necessarily a valid number. For the
// conversion to be valid it needs to use all of the characters it
// consumes. From the standard:
// 7.23.6.2 paragraph 9: "An input item is defined as the longest
// sequence of input characters which does not exceed any specified
// field width and which is, or is a prefix of, a matching input
// sequence."
// 7.23.6.2 paragraph 10: "If the input item is not a matching sequence,
// the execution of the directive fails: this condition is a matching
// failure"
base = 16;
if (max_width > 1) {
--max_width;
cur_char = reader->getc();
} else {
write_int_with_length(0, to_conv);
return READ_OK;
return MATCHING_FAILURE;
}

} else {
Expand Down Expand Up @@ -198,6 +209,9 @@ int convert_int(Reader *reader, const FormatSection &to_conv) {
// last one back.
reader->ungetc(cur_char);

if (!is_number)
return MATCHING_FAILURE;

if (has_overflow) {
write_int_with_length(MAX, to_conv);
} else {
Expand All @@ -207,8 +221,6 @@ int convert_int(Reader *reader, const FormatSection &to_conv) {
write_int_with_length(result, to_conv);
}

if (!is_number)
return MATCHING_FAILURE;
return READ_OK;
}

Expand Down
22 changes: 22 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,17 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

# Exhaustive log2f16 test; results are compared against MPFR (NEED_MPFR).
add_fp_unittest(
log2f16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
log2f16_test.cpp
DEPENDS
libc.src.math.log2f16
)

add_fp_unittest(
log10_test
NEED_MPFR
Expand All @@ -1835,6 +1846,17 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

# Exhaustive log10f16 test; results are compared against MPFR (NEED_MPFR).
add_fp_unittest(
log10f16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
log10f16_test.cpp
DEPENDS
libc.src.math.log10f16
)

add_fp_unittest(
log1p_test
NEED_MPFR
Expand Down
40 changes: 40 additions & 0 deletions libc/test/src/math/log10f16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===-- Exhaustive test for log10f16 --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/log10f16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

using LlvmLibcLog10f16Test = LIBC_NAMESPACE::testing::FPTest<float16>;

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

// Raw binary16 bit patterns bounding the tested input ranges. Both ends are
// inclusive in the loops below.

// Range: [0, Inf];
static constexpr uint16_t POS_START = 0x0000U;
static constexpr uint16_t POS_STOP = 0x7c00U;

// Range: [-Inf, 0];
// In bit-pattern order this runs from -0 (0x8000) up to -Inf (0xfc00).
static constexpr uint16_t NEG_START = 0x8000U;
static constexpr uint16_t NEG_STOP = 0xfc00U;

// Walks every bit pattern from POS_START through POS_STOP (inclusive) and
// compares log10f16 with the MPFR reference within 0.5 ULP in all rounding
// modes.
TEST_F(LlvmLibcLog10f16Test, PositiveRange) {
  uint16_t bits = POS_START;
  while (bits <= POS_STOP) {
    float16 x = FPBits(bits).get_val();
    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Log10, x,
                                   LIBC_NAMESPACE::log10f16(x), 0.5);
    ++bits;
  }
}

// Walks every bit pattern from NEG_START through NEG_STOP (inclusive) and
// compares log10f16 with the MPFR reference within 0.5 ULP in all rounding
// modes.
TEST_F(LlvmLibcLog10f16Test, NegativeRange) {
  uint16_t bits = NEG_START;
  while (bits <= NEG_STOP) {
    float16 x = FPBits(bits).get_val();
    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Log10, x,
                                   LIBC_NAMESPACE::log10f16(x), 0.5);
    ++bits;
  }
}
40 changes: 40 additions & 0 deletions libc/test/src/math/log2f16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===-- Exhaustive test for log2f16 ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "src/math/log2f16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
#include "utils/MPFRWrapper/MPFRUtils.h"

using LlvmLibcLog2f16Test = LIBC_NAMESPACE::testing::FPTest<float16>;

namespace mpfr = LIBC_NAMESPACE::testing::mpfr;

// Raw binary16 bit patterns bounding the tested input ranges. Both ends are
// inclusive in the loops below.

// Range: [0, Inf];
static constexpr uint16_t POS_START = 0x0000U;
static constexpr uint16_t POS_STOP = 0x7c00U;

// Range: [-Inf, 0];
// In bit-pattern order this runs from -0 (0x8000) up to -Inf (0xfc00).
static constexpr uint16_t NEG_START = 0x8000U;
static constexpr uint16_t NEG_STOP = 0xfc00U;

// Walks every bit pattern from POS_START through POS_STOP (inclusive) and
// compares log2f16 with the MPFR reference within 0.5 ULP in all rounding
// modes.
TEST_F(LlvmLibcLog2f16Test, PositiveRange) {
  uint16_t bits = POS_START;
  while (bits <= POS_STOP) {
    float16 x = FPBits(bits).get_val();
    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Log2, x,
                                   LIBC_NAMESPACE::log2f16(x), 0.5);
    ++bits;
  }
}

// Walks every bit pattern from NEG_START through NEG_STOP (inclusive) and
// compares log2f16 with the MPFR reference within 0.5 ULP in all rounding
// modes.
TEST_F(LlvmLibcLog2f16Test, NegativeRange) {
  uint16_t bits = NEG_START;
  while (bits <= NEG_STOP) {
    float16 x = FPBits(bits).get_val();
    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Log2, x,
                                   LIBC_NAMESPACE::log2f16(x), 0.5);
    ++bits;
  }
}
26 changes: 26 additions & 0 deletions libc/test/src/math/smoke/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3595,6 +3595,19 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

# Smoke test for log2f16 special values; does not require MPFR.
add_fp_unittest(
log2f16_test
SUITE
libc-math-smoke-tests
SRCS
log2f16_test.cpp
DEPENDS
libc.hdr.fenv_macros
libc.src.errno.errno
libc.src.math.log2f16
libc.src.__support.FPUtil.cast
)

add_fp_unittest(
log10_test
SUITE
Expand All @@ -3619,6 +3632,19 @@ add_fp_unittest(
libc.src.__support.FPUtil.fp_bits
)

# Smoke test for log10f16 special values; does not require MPFR.
add_fp_unittest(
log10f16_test
SUITE
libc-math-smoke-tests
SRCS
log10f16_test.cpp
DEPENDS
libc.hdr.fenv_macros
libc.src.errno.errno
libc.src.math.log10f16
libc.src.__support.FPUtil.cast
)

add_fp_unittest(
log1p_test
SUITE
Expand Down
50 changes: 50 additions & 0 deletions libc/test/src/math/smoke/log10f16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//===-- Unittests for log10f16 --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/cast.h"
#include "src/errno/libc_errno.h"
#include "src/math/log10f16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"

using LlvmLibcLog10f16Test = LIBC_NAMESPACE::testing::FPTest<float16>;

// Checks the return value, errno and raised floating-point exceptions of
// log10f16 for NaNs, infinities, zeros, 1 and a negative finite input.
TEST_F(LlvmLibcLog10f16Test, SpecialNumbers) {
// Start with a clean errno before the first errno expectation.
LIBC_NAMESPACE::libc_errno = 0;

// Quiet NaN in, NaN out; errno is expected to be 0.
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::log10f16(aNaN));
EXPECT_MATH_ERRNO(0);

// Signaling NaN yields a NaN and raises FE_INVALID; errno is expected to be 0.
EXPECT_FP_EQ_WITH_EXCEPTION(aNaN, LIBC_NAMESPACE::log10f16(sNaN), FE_INVALID);
EXPECT_MATH_ERRNO(0);

// log10(+inf) = +inf
EXPECT_FP_EQ_ALL_ROUNDING(inf, LIBC_NAMESPACE::log10f16(inf));
EXPECT_MATH_ERRNO(0);

// log10(-inf) is a domain error: NaN result and errno == EDOM.
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::log10f16(neg_inf));
EXPECT_MATH_ERRNO(EDOM);

// log10(+0) = -inf with FE_DIVBYZERO; errno is expected to be 0.
EXPECT_FP_EQ_WITH_EXCEPTION_ALL_ROUNDING(
neg_inf, LIBC_NAMESPACE::log10f16(zero), FE_DIVBYZERO);
EXPECT_MATH_ERRNO(0);

// log10(-0) behaves like log10(+0).
EXPECT_FP_EQ_WITH_EXCEPTION_ALL_ROUNDING(
neg_inf, LIBC_NAMESPACE::log10f16(neg_zero), FE_DIVBYZERO);
EXPECT_MATH_ERRNO(0);

// log10(1) = +0
EXPECT_FP_EQ_ALL_ROUNDING(
zero,
LIBC_NAMESPACE::log10f16(LIBC_NAMESPACE::fputil::cast<float16>(1.0)));
EXPECT_MATH_ERRNO(0);

// A negative finite input is a domain error: NaN result and errno == EDOM.
EXPECT_FP_EQ_ALL_ROUNDING(
aNaN,
LIBC_NAMESPACE::log10f16(LIBC_NAMESPACE::fputil::cast<float16>(-1.0)));
EXPECT_MATH_ERRNO(EDOM);
}
50 changes: 50 additions & 0 deletions libc/test/src/math/smoke/log2f16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//===-- Unittests for log2f16 ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "hdr/fenv_macros.h"
#include "src/__support/FPUtil/cast.h"
#include "src/errno/libc_errno.h"
#include "src/math/log2f16.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"

using LlvmLibcLog2f16Test = LIBC_NAMESPACE::testing::FPTest<float16>;

// Checks the return value, errno and raised floating-point exceptions of
// log2f16 for NaNs, infinities, zeros, 1 and a negative finite input.
TEST_F(LlvmLibcLog2f16Test, SpecialNumbers) {
// Start with a clean errno before the first errno expectation.
LIBC_NAMESPACE::libc_errno = 0;

// Quiet NaN in, NaN out; errno is expected to be 0.
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::log2f16(aNaN));
EXPECT_MATH_ERRNO(0);

// Signaling NaN yields a NaN and raises FE_INVALID; errno is expected to be 0.
EXPECT_FP_EQ_WITH_EXCEPTION(aNaN, LIBC_NAMESPACE::log2f16(sNaN), FE_INVALID);
EXPECT_MATH_ERRNO(0);

// log2(+inf) = +inf
EXPECT_FP_EQ_ALL_ROUNDING(inf, LIBC_NAMESPACE::log2f16(inf));
EXPECT_MATH_ERRNO(0);

// log2(-inf) is a domain error: NaN result and errno == EDOM.
EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::log2f16(neg_inf));
EXPECT_MATH_ERRNO(EDOM);

// log2(+0) = -inf with FE_DIVBYZERO; errno is expected to be 0.
EXPECT_FP_EQ_WITH_EXCEPTION_ALL_ROUNDING(
neg_inf, LIBC_NAMESPACE::log2f16(zero), FE_DIVBYZERO);
EXPECT_MATH_ERRNO(0);

// log2(-0) behaves like log2(+0).
EXPECT_FP_EQ_WITH_EXCEPTION_ALL_ROUNDING(
neg_inf, LIBC_NAMESPACE::log2f16(neg_zero), FE_DIVBYZERO);
EXPECT_MATH_ERRNO(0);

// log2(1) = +0
EXPECT_FP_EQ_ALL_ROUNDING(
zero,
LIBC_NAMESPACE::log2f16(LIBC_NAMESPACE::fputil::cast<float16>(1.0)));
EXPECT_MATH_ERRNO(0);

// A negative finite input is a domain error: NaN result and errno == EDOM.
EXPECT_FP_EQ_ALL_ROUNDING(
aNaN,
LIBC_NAMESPACE::log2f16(LIBC_NAMESPACE::fputil::cast<float16>(-1.0)));
EXPECT_MATH_ERRNO(EDOM);
}
36 changes: 26 additions & 10 deletions libc/test/src/stdio/sscanf_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,25 @@ TEST(LlvmLibcSScanfTest, IntConvMaxLengthTests) {
EXPECT_EQ(ret_val, 1);
EXPECT_EQ(result, 0);

result = -999;

// 0x is a valid prefix, but not a valid number. This should be a matching
// failure and should not modify the values.
ret_val = LIBC_NAMESPACE::sscanf("0x1", "%2i", &result);
EXPECT_EQ(ret_val, 1);
EXPECT_EQ(result, 0);
EXPECT_EQ(ret_val, 0);
EXPECT_EQ(result, -999);

ret_val = LIBC_NAMESPACE::sscanf("-0x1", "%3i", &result);
EXPECT_EQ(ret_val, 0);
EXPECT_EQ(result, -999);

ret_val = LIBC_NAMESPACE::sscanf("0x1", "%3i", &result);
EXPECT_EQ(ret_val, 1);
EXPECT_EQ(result, 0);
EXPECT_EQ(result, 1);

ret_val = LIBC_NAMESPACE::sscanf("-0x1", "%4i", &result);
EXPECT_EQ(ret_val, 1);
EXPECT_EQ(result, -1);

ret_val = LIBC_NAMESPACE::sscanf("-0x123", "%4i", &result);
EXPECT_EQ(ret_val, 1);
Expand Down Expand Up @@ -212,7 +224,7 @@ TEST(LlvmLibcSScanfTest, IntConvNoWriteTests) {
EXPECT_EQ(result, 0);

ret_val = LIBC_NAMESPACE::sscanf("0x1", "%*2i", &result);
EXPECT_EQ(ret_val, 1);
EXPECT_EQ(ret_val, 0);
EXPECT_EQ(result, 0);

ret_val = LIBC_NAMESPACE::sscanf("a", "%*i", &result);
Expand Down Expand Up @@ -679,13 +691,17 @@ TEST(LlvmLibcSScanfTest, CombinedConv) {
EXPECT_EQ(result, 123);
ASSERT_STREQ(buffer, "abc");

result = -1;

// 0x is a valid prefix, but not a valid number. This should be a matching
// failure and should not modify the values.
ret_val = LIBC_NAMESPACE::sscanf("0xZZZ", "%i%s", &result, buffer);
EXPECT_EQ(ret_val, 2);
EXPECT_EQ(result, 0);
ASSERT_STREQ(buffer, "ZZZ");
EXPECT_EQ(ret_val, 0);
EXPECT_EQ(result, -1);
ASSERT_STREQ(buffer, "abc");

ret_val = LIBC_NAMESPACE::sscanf("0xZZZ", "%X%s", &result, buffer);
EXPECT_EQ(ret_val, 2);
EXPECT_EQ(result, 0);
ASSERT_STREQ(buffer, "ZZZ");
EXPECT_EQ(ret_val, 0);
EXPECT_EQ(result, -1);
ASSERT_STREQ(buffer, "abc");
}
2 changes: 1 addition & 1 deletion lld/test/MachO/objc-category-merging-minimal.s
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

############ Test merging skipped due to invalid category name ############
# Modify __OBJC_$_CATEGORY_MyBaseClass_$_Category01's name to point to L_OBJC_IMAGE_INFO+3
# RUN: sed -E '/^__OBJC_\$_CATEGORY_MyBaseClass_\$_Category01:/ { n; s/^[ \t]*\.quad[ \t]+l_OBJC_CLASS_NAME_$/\t.quad\tL_OBJC_IMAGE_INFO+3/}' merge_cat_minimal.s > merge_cat_minimal_bad_name.s
# RUN: awk '/^__OBJC_\$_CATEGORY_MyBaseClass_\$_Category01:/ { print; getline; sub(/^[ \t]*\.quad[ \t]+l_OBJC_CLASS_NAME_$/, "\t.quad\tL_OBJC_IMAGE_INFO+3"); print; next } { print }' merge_cat_minimal.s > merge_cat_minimal_bad_name.s

# Assemble the modified source
# RUN: llvm-mc -filetype=obj -triple=arm64-apple-macos -o merge_cat_minimal_bad_name.o merge_cat_minimal_bad_name.s
Expand Down
1 change: 1 addition & 0 deletions lldb/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ interesting areas to contribute to lldb.
use/intel_pt
use/ondemand
use/aarch64-linux
use/symbolfilejson
use/troubleshooting
use/links
Man Page <man/lldb>
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_set_use_source_cache_false(self):
self.set_use_source_cache_and_test(False)

@skipIf(hostoslist=no_match(["windows"]))
@skipIf(oslist=["windows"]) # Fails on windows 11
def test_set_use_source_cache_true(self):
"""Test that after 'set use-source-cache false', files are locked."""
"""Test that after 'set use-source-cache true', files are locked."""
self.set_use_source_cache_and_test(True)

def set_use_source_cache_and_test(self, is_cache_enabled):
Expand All @@ -46,23 +45,27 @@ def set_use_source_cache_and_test(self, is_cache_enabled):
# Show the source file contents to make sure LLDB loads src file.
self.runCmd("source list")

# Try deleting the source file.
is_file_removed = self.removeFile(src)
# Try overwriting the source file.
is_file_overwritten = self.overwriteFile(src)

if is_cache_enabled:
self.assertFalse(
is_file_removed, "Source cache is enabled, but delete file succeeded"
is_file_overwritten,
"Source cache is enabled, but writing to file succeeded",
)

if not is_cache_enabled:
self.assertTrue(
is_file_removed, "Source cache is disabled, but delete file failed"
is_file_overwritten,
"Source cache is disabled, but writing to file failed",
)

def removeFile(self, src):
"""Remove file and return true iff file was successfully removed."""
def overwriteFile(self, src):
"""Write to file and return true iff file was successfully written."""
try:
os.remove(src)
f = open(src, "w")
f.writelines(["// hello world\n"])
f.close()
return True
except Exception:
return False
10 changes: 5 additions & 5 deletions lldb/test/API/python_api/find_in_memory/TestFindInMemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_find_in_memory_ok(self):
error = lldb.SBError()
addr = self.process.FindInMemory(
SINGLE_INSTANCE_PATTERN_STACK,
GetStackRange(self),
GetStackRange(self, True),
1,
error,
)
Expand All @@ -70,7 +70,7 @@ def test_find_in_memory_double_instance_ok(self):
error = lldb.SBError()
addr = self.process.FindInMemory(
DOUBLE_INSTANCE_PATTERN_HEAP,
GetHeapRanges(self)[0],
GetHeapRanges(self, True)[0],
1,
error,
)
Expand All @@ -86,7 +86,7 @@ def test_find_in_memory_invalid_alignment(self):
error = lldb.SBError()
addr = self.process.FindInMemory(
SINGLE_INSTANCE_PATTERN_STACK,
GetStackRange(self),
GetStackRange(self, True),
0,
error,
)
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_find_in_memory_invalid_buffer(self):
error = lldb.SBError()
addr = self.process.FindInMemory(
"",
GetStackRange(self),
GetStackRange(self, True),
1,
error,
)
Expand All @@ -131,7 +131,7 @@ def test_find_in_memory_unaligned(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)
error = lldb.SBError()
range = GetAlignedRange(self)
range = GetAlignedRange(self, True)

# First we make sure the pattern is found with alignment 1
addr = self.process.FindInMemory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_find_ranges_in_memory_two_matches(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetHeapRanges(self)
addr_ranges = GetHeapRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
DOUBLE_INSTANCE_PATTERN_HEAP,
Expand All @@ -48,7 +48,7 @@ def test_find_ranges_in_memory_one_match(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetStackRanges(self)
addr_ranges = GetStackRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
SINGLE_INSTANCE_PATTERN_STACK,
Expand All @@ -66,7 +66,7 @@ def test_find_ranges_in_memory_one_match_multiple_ranges(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetRanges(self)
addr_ranges = GetRanges(self, True)
addr_ranges.Append(lldb.SBAddressRange())
self.assertGreater(addr_ranges.GetSize(), 2)
error = lldb.SBError()
Expand All @@ -86,7 +86,7 @@ def test_find_ranges_in_memory_one_match_max(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetHeapRanges(self)
addr_ranges = GetHeapRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
DOUBLE_INSTANCE_PATTERN_HEAP,
Expand All @@ -104,7 +104,7 @@ def test_find_ranges_in_memory_invalid_alignment(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetHeapRanges(self)
addr_ranges = GetHeapRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
DOUBLE_INSTANCE_PATTERN_HEAP,
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_find_ranges_in_memory_invalid_buffer(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetHeapRanges(self)
addr_ranges = GetHeapRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
"",
Expand All @@ -178,7 +178,7 @@ def test_find_ranges_in_memory_invalid_max_matches(self):
self.assertTrue(self.process, PROCESS_IS_VALID)
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = GetHeapRanges(self)
addr_ranges = GetHeapRanges(self, True)
error = lldb.SBError()
matches = self.process.FindRangesInMemory(
DOUBLE_INSTANCE_PATTERN_HEAP,
Expand All @@ -197,7 +197,7 @@ def test_find_in_memory_unaligned(self):
self.assertState(self.process.GetState(), lldb.eStateStopped, PROCESS_STOPPED)

addr_ranges = lldb.SBAddressRangeList()
addr_ranges.Append(GetAlignedRange(self))
addr_ranges.Append(GetAlignedRange(self, True))
error = lldb.SBError()

matches = self.process.FindRangesInMemory(
Expand Down
50 changes: 32 additions & 18 deletions lldb/test/API/python_api/find_in_memory/address_ranges_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,30 @@
UNALIGNED_INSTANCE_PATTERN_HEAP = ALIGNED_INSTANCE_PATTERN_HEAP[1:]


def GetAlignedRange(test_base):
def GetAlignedRange(test_base, shrink=False):
frame = test_base.thread.GetSelectedFrame()
ex = frame.EvaluateExpression("aligned_string_ptr")
test_base.assertTrue(ex.IsValid())
return GetRangeFromAddrValue(test_base, ex)
return GetRangeFromAddrValue(test_base, ex, shrink)


def GetStackRange(test_base):
def GetStackRange(test_base, shrink=False):
frame = test_base.thread.GetSelectedFrame()
ex = frame.EvaluateExpression("&stack_pointer")
test_base.assertTrue(ex.IsValid())
return GetRangeFromAddrValue(test_base, ex)
return GetRangeFromAddrValue(test_base, ex, shrink)


def GetStackRanges(test_base):
def GetStackRanges(test_base, shrink=False):
addr_ranges = lldb.SBAddressRangeList()
addr_ranges.Append(GetStackRange(test_base))
return addr_ranges


def GetRangeFromAddrValue(test_base, addr):
def GetRangeFromAddrValue(test_base, addr, shrink=False):
"""Returns a memory region containing 'addr'.
If 'shrink' is True, the address range will be reduced to not exceed 2K.
"""
region = lldb.SBMemoryRegionInfo()
test_base.assertTrue(
test_base.process.GetMemoryRegionInfo(
Expand All @@ -37,37 +40,48 @@ def GetRangeFromAddrValue(test_base, addr):
test_base.assertTrue(region.IsReadable())
test_base.assertFalse(region.IsExecutable())

address_start = lldb.SBAddress(region.GetRegionBase(), test_base.target)
stack_size = region.GetRegionEnd() - region.GetRegionBase()
return lldb.SBAddressRange(address_start, stack_size)
base = region.GetRegionBase()
end = region.GetRegionEnd()

if shrink:
addr2 = addr.GetValueAsUnsigned()
addr2 -= addr2 % 512
base = max(base, addr2 - 1024)
end = min(end, addr2 + 1024)

def IsWithinRange(addr, range, target):
start = lldb.SBAddress(base, test_base.target)
size = end - base

return lldb.SBAddressRange(start, size)


def IsWithinRange(addr, size, range, target):
start_addr = range.GetBaseAddress().GetLoadAddress(target)
end_addr = start_addr + range.GetByteSize()
addr = addr.GetValueAsUnsigned()
return addr >= start_addr and addr < end_addr
return addr >= start_addr and addr + size <= end_addr


def GetHeapRanges(test_base):
def GetHeapRanges(test_base, shrink=False):
frame = test_base.thread.GetSelectedFrame()

ex = frame.EvaluateExpression("heap_pointer1")
test_base.assertTrue(ex.IsValid())
range = GetRangeFromAddrValue(test_base, ex)
range = GetRangeFromAddrValue(test_base, ex, shrink)
addr_ranges = lldb.SBAddressRangeList()
addr_ranges.Append(range)

ex = frame.EvaluateExpression("heap_pointer2")
test_base.assertTrue(ex.IsValid())
if not IsWithinRange(ex, addr_ranges[0], test_base.target):
addr_ranges.Append(GetRangeFromAddrValue(test_base, ex))
size = len(DOUBLE_INSTANCE_PATTERN_HEAP)
if not IsWithinRange(ex, size, addr_ranges[0], test_base.target):
addr_ranges.Append(GetRangeFromAddrValue(test_base, ex, shrink))

return addr_ranges


def GetRanges(test_base):
ranges = GetHeapRanges(test_base)
ranges.Append(GetStackRanges(test_base))
def GetRanges(test_base, shrink=False):
ranges = GetHeapRanges(test_base, shrink)
ranges.Append(GetStackRanges(test_base, shrink))

return ranges
4 changes: 3 additions & 1 deletion llvm/include/llvm-c/Disassembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ int LLVMSetDisasmOptions(LLVMDisasmContextRef DC, uint64_t Options);
#define LLVMDisassembler_Option_AsmPrinterVariant 4
/* The option to set comment on instructions */
#define LLVMDisassembler_Option_SetInstrComments 8
/* The option to print latency information alongside instructions */
/* The option to print latency information alongside instructions */
#define LLVMDisassembler_Option_PrintLatency 16
/* The option to print in color */
#define LLVMDisassembler_Option_Color 32

/**
* Dispose of a disassembler context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,15 @@ class DGNode {
virtual ~DGNode() = default;
/// \Returns the number of unscheduled successors.
unsigned getNumUnscheduledSuccs() const { return UnscheduledSuccs; }
/// Decrements the counter of unscheduled successors by one.
/// Asserts that the counter is still positive before decrementing.
void decrUnscheduledSuccs() {
assert(UnscheduledSuccs > 0 && "Counting error!");
--UnscheduledSuccs;
}
/// \Returns true if all dependent successors have been scheduled.
bool ready() const { return UnscheduledSuccs == 0; }
/// \Returns true if this node has been scheduled.
bool scheduled() const { return Scheduled; }
/// Sets the "scheduled" flag of this node to \p NewVal.
void setScheduled(bool NewVal) { Scheduled = NewVal; }
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
using iterator = PredIterator;
Expand Down
126 changes: 126 additions & 0 deletions llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
//===- Scheduler.h ----------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the bottom-up list scheduler used by the vectorizer. It is used for
// checking the legality of vectorization and for scheduling instructions in
// such a way that makes vectorization possible, if legal.
//
// The legality check is performed by `trySchedule(Instrs)`, which will try to
// schedule the IR until all instructions in `Instrs` can be scheduled together
// back-to-back. If this fails then it is illegal to vectorize `Instrs`.
//
// Internally the scheduler uses the vectorizer-specific DependencyGraph class.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H

#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
#include <queue>

namespace llvm::sandboxir {

/// Ordering functor for nodes that are ready to be scheduled.
/// \p N1 compares lower than \p N2 when N1's instruction comes before N2's in
/// program order. When used with std::priority_queue (which surfaces the
/// greatest element) the node that is latest in program order is on top.
class PriorityCmp {
public:
  bool operator()(const DGNode *N1, const DGNode *N2) {
    // TODO: This should be a hierarchical comparator.
    auto *FirstI = N1->getInstruction();
    auto *SecondI = N2->getInstruction();
    return FirstI->comesBefore(SecondI);
  }
};

/// Container of the nodes that are ready to be scheduled. The scheduler
/// inserts nodes as they become ready and pops them in priority order.
class ReadyListContainer {
  /// The comparator instance handed to the priority queue below.
  PriorityCmp Cmp;
  /// The DAG does not model Control/Other dependencies in order to save
  /// memory, so for correctness they have to be modeled by the ready list
  /// instead. As a consequence the list will hold back nodes that still need
  /// to meet such unmodeled dependencies.
  std::priority_queue<DGNode *, std::vector<DGNode *>, PriorityCmp> List;

public:
  ReadyListContainer() : List(Cmp) {}
  /// Adds \p N to the list.
  void insert(DGNode *N) { List.push(N); }
  /// Removes the top-priority node from the list and \returns it.
  /// The list must not be empty.
  DGNode *pop() {
    DGNode *TopN = List.top();
    List.pop();
    return TopN;
  }
  /// \Returns true if the list holds no nodes.
  bool empty() const { return List.empty(); }
#ifndef NDEBUG
  void dump(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const;
#endif // NDEBUG
};

/// A SchedBundle groups the nodes that have to be scheduled back-to-back in
/// one single scheduling cycle.
class SchedBundle {
public:
  using ContainerTy = SmallVector<DGNode *, 4>;
  using iterator = ContainerTy::iterator;
  using const_iterator = ContainerTy::const_iterator;

private:
  /// The nodes that make up this bundle.
  ContainerTy Nodes;

public:
  SchedBundle() = default;
  /// Creates a bundle that takes over the contents of \p Nodes.
  SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}

  // Iteration over the bundle's nodes.
  iterator begin() { return Nodes.begin(); }
  iterator end() { return Nodes.end(); }
  const_iterator begin() const { return Nodes.begin(); }
  const_iterator end() const { return Nodes.end(); }

  /// \Returns the bundle node that comes before the others in program order.
  DGNode *getTop() const;
  /// \Returns the bundle node that comes after the others in program order.
  DGNode *getBot() const;
  /// Move all bundle instructions to \p Where back-to-back.
  void cluster(BasicBlock::iterator Where);
#ifndef NDEBUG
  void dump(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const;
#endif
};

/// The bottom-up list scheduler. Its entry point is trySchedule().
class Scheduler {
  /// The nodes that are currently ready to be scheduled.
  ReadyListContainer ReadyList;
  /// The dependency graph that the scheduler operates on.
  DependencyGraph DAG;
  /// The position where the next bundle gets placed. Empty until the first
  /// call to trySchedule().
  std::optional<BasicBlock::iterator> ScheduleTopItOpt;
  /// Owns all the bundles created by createBundle().
  SmallVector<std::unique_ptr<SchedBundle>> Bndls;

  /// \Returns a scheduling bundle containing \p Instrs.
  SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
  /// Schedule nodes until we can schedule \p Instrs back-to-back.
  bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
  /// Schedules all nodes in \p Bndl, marks them as scheduled, updates the
  /// UnscheduledSuccs counter of all dependency predecessors, and adds any of
  /// them that become ready to the ready list.
  void scheduleAndUpdateReadyList(SchedBundle &Bndl);

  // The scheduler is not copyable.
  Scheduler(const Scheduler &) = delete;
  Scheduler &operator=(const Scheduler &) = delete;

public:
  Scheduler(AAResults &AA) : DAG(AA) {}
  ~Scheduler() = default;

  /// Tries to schedule the IR until all instructions in \p Instrs can be
  /// scheduled together back-to-back.
  /// \Returns true on success, false if this is not possible.
  bool trySchedule(ArrayRef<Instruction *> Instrs);

#ifndef NDEBUG
  void dump(raw_ostream &OS) const;
  LLVM_DUMP_METHOD void dump() const;
#endif
};

} // namespace llvm::sandboxir

#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_SCHEDULER_H
11 changes: 11 additions & 0 deletions llvm/lib/MC/MCDisassembler/Disassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ size_t LLVMDisasmInstruction(LLVMDisasmContextRef DCR, uint8_t *Bytes,
SmallVector<char, 64> InsnStr;
raw_svector_ostream OS(InsnStr);
formatted_raw_ostream FormattedOS(OS);

if (DC->getOptions() & LLVMDisassembler_Option_Color) {
FormattedOS.enable_colors(true);
IP->setUseColor(true);
}

IP->printInst(&Inst, PC, AnnotationsStr, *DC->getSubtargetInfo(),
FormattedOS);

Expand Down Expand Up @@ -343,5 +349,10 @@ int LLVMSetDisasmOptions(LLVMDisasmContextRef DCR, uint64_t Options){
DC->addOptions(LLVMDisassembler_Option_PrintLatency);
Options &= ~LLVMDisassembler_Option_PrintLatency;
}
if (Options & LLVMDisassembler_Option_Color) {
LLVMDisasmContext *DC = static_cast<LLVMDisasmContext *>(DCR);
DC->addOptions(LLVMDisassembler_Option_Color);
Options &= ~LLVMDisassembler_Option_Color;
}
return (Options == 0);
}
6 changes: 3 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20385,11 +20385,11 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
if (VT.isVector())
break;
if (VT == MVT::f16 && Subtarget.hasStdExtZhinxmin())
return std::make_pair(0U, &RISCV::GPRF16RegClass);
return std::make_pair(0U, &RISCV::GPRF16NoX0RegClass);
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32RegClass);
return std::make_pair(0U, &RISCV::GPRF32NoX0RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRPairRegClass);
return std::make_pair(0U, &RISCV::GPRPairNoX0RegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVRegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ def GPRF16 : RISCVRegisterClass<[f16], 16, (add (sequence "X%u_H", 10, 17),
(sequence "X%u_H", 0, 4))>;
def GPRF16C : RISCVRegisterClass<[f16], 16, (add (sequence "X%u_H", 10, 15),
(sequence "X%u_H", 8, 9))>;
def GPRF16NoX0 : RISCVRegisterClass<[f16], 16, (sub GPRF16, X0_H)>;

def GPRF32 : RISCVRegisterClass<[f32], 32, (add (sequence "X%u_W", 10, 17),
(sequence "X%u_W", 5, 7),
Expand Down Expand Up @@ -721,6 +722,8 @@ def GPRPair : RISCVRegisterClass<[XLenPairFVT], 64, (add
def GPRPairC : RISCVRegisterClass<[XLenPairFVT], 64, (add
X10_X11, X12_X13, X14_X15, X8_X9
)>;

def GPRPairNoX0 : RISCVRegisterClass<[XLenPairFVT], 64, (sub GPRPair, X0_Pair)>;
} // let RegInfos = XLenPairRI, DecoderMethod = "DecodeGPRPairRegisterClass"

// The register class is added for inline assembly for vector mask types.
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_llvm_component_library(LLVMVectorize
SandboxVectorizer/Passes/RegionsFromMetadata.cpp
SandboxVectorizer/SandboxVectorizer.cpp
SandboxVectorizer/SandboxVectorizerPassBuilder.cpp
SandboxVectorizer/Scheduler.cpp
SandboxVectorizer/SeedCollector.cpp
SLPVectorizer.cpp
Vectorize.cpp
Expand Down
38 changes: 31 additions & 7 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2447,12 +2447,26 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
};

TailFoldingStyle Style = Cost->getTailFoldingStyle();
if (Style == TailFoldingStyle::None)
CheckMinIters =
Builder.CreateICmp(P, Count, CreateStep(), "min.iters.check");
else if (VF.isScalable() &&
!isIndvarOverflowCheckKnownFalse(Cost, VF, UF) &&
Style != TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
if (Style == TailFoldingStyle::None) {
Value *Step = CreateStep();
ScalarEvolution &SE = *PSE.getSE();
// TODO: Emit unconditional branch to vector preheader instead of
// conditional branch with known condition.
const SCEV *TripCountSCEV = SE.applyLoopGuards(SE.getSCEV(Count), OrigLoop);
// Check if the trip count is < the step.
if (SE.isKnownPredicate(P, TripCountSCEV, SE.getSCEV(Step))) {
// TODO: Ensure step is at most the trip count when determining max VF and
// UF, w/o tail folding.
CheckMinIters = Builder.getTrue();
} else if (!SE.isKnownPredicate(CmpInst::getInversePredicate(P),
TripCountSCEV, SE.getSCEV(Step))) {
// Generate the minimum iteration check only if we cannot prove the
// check is known to be true, or known to be false.
CheckMinIters = Builder.CreateICmp(P, Count, Step, "min.iters.check");
} // else step known to be < trip count, use CheckMinIters preset to false.
} else if (VF.isScalable() &&
!isIndvarOverflowCheckKnownFalse(Cost, VF, UF) &&
Style != TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck) {
// vscale is not necessarily a power-of-2, which means we cannot guarantee
// an overflow to zero when updating induction variables and so an
// additional overflow check is required before entering the vector loop.
Expand All @@ -2462,8 +2476,18 @@ void InnerLoopVectorizer::emitIterationCountCheck(BasicBlock *Bypass) {
ConstantInt::get(CountTy, cast<IntegerType>(CountTy)->getMask());
Value *LHS = Builder.CreateSub(MaxUIntTripCount, Count);

Value *Step = CreateStep();
#ifndef NDEBUG
ScalarEvolution &SE = *PSE.getSE();
const SCEV *TC2OverflowSCEV = SE.applyLoopGuards(SE.getSCEV(LHS), OrigLoop);
assert(
!isIndvarOverflowCheckKnownFalse(Cost, VF * UF) &&
!SE.isKnownPredicate(CmpInst::getInversePredicate(ICmpInst::ICMP_ULT),
TC2OverflowSCEV, SE.getSCEV(Step)) &&
"unexpectedly proved overflow check to be known");
#endif
// Don't execute the vector loop if (UMax - n) < (VF * UF).
CheckMinIters = Builder.CreateICmp(ICmpInst::ICMP_ULT, LHS, CreateStep());
CheckMinIters = Builder.CreateICmp(ICmpInst::ICMP_ULT, LHS, Step);
}

// Create new preheader for vector loop.
Expand Down
24 changes: 24 additions & 0 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14953,6 +14953,12 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
return E->VectorizedValue;
}

if (Op->getType() != VecTy) {
assert((It != MinBWs.end() || getOperandEntry(E, 0)->isGather() ||
MinBWs.contains(getOperandEntry(E, 0))) &&
"Expected item in MinBWs.");
Op = Builder.CreateIntCast(Op, VecTy, GetOperandSignedness(0));
}
Value *V = Builder.CreateFreeze(Op);
V = FinalShuffle(V, E);

Expand Down Expand Up @@ -17095,6 +17101,8 @@ bool BoUpSLP::collectValuesToDemote(
return TryProcessInstruction(
BitWidth, {getOperandEntry(&E, 0), getOperandEntry(&E, 1)});
}
case Instruction::Freeze:
return TryProcessInstruction(BitWidth, getOperandEntry(&E, 0));
case Instruction::Shl: {
// If we are truncating the result of this SHL, and if it's a shift of an
// inrange amount, we can always perform a SHL in a smaller type.
Expand Down Expand Up @@ -17216,9 +17224,25 @@ bool BoUpSLP::collectValuesToDemote(
MaskedValueIsZero(I->getOperand(1), Mask, SimplifyQuery(*DL)));
});
};
auto AbsChecker = [&](unsigned BitWidth, unsigned OrigBitWidth) {
assert(BitWidth <= OrigBitWidth && "Unexpected bitwidths!");
return all_of(E.Scalars, [&](Value *V) {
auto *I = cast<Instruction>(V);
unsigned SignBits = OrigBitWidth - BitWidth;
APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth - 1);
unsigned Op0SignBits =
ComputeNumSignBits(I->getOperand(0), *DL, 0, AC, nullptr, DT);
return SignBits <= Op0SignBits &&
((SignBits != Op0SignBits &&
!isKnownNonNegative(I->getOperand(0), SimplifyQuery(*DL))) ||
MaskedValueIsZero(I->getOperand(0), Mask, SimplifyQuery(*DL)));
});
};
if (ID != Intrinsic::abs) {
Operands.push_back(getOperandEntry(&E, 1));
CallChecker = CompChecker;
} else {
CallChecker = AbsChecker;
}
InstructionCost BestCost =
std::numeric_limits<InstructionCost::CostType>::max();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ bool PredIterator::operator==(const PredIterator &Other) const {

#ifndef NDEBUG
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
OS << *I << " USuccs:" << UnscheduledSuccs << "\n";
OS << *I << " USuccs:" << UnscheduledSuccs << " Sched:" << Scheduled << "\n";
}
void DGNode::dump() const { print(dbgs()); }
void MemDGNode::print(raw_ostream &OS, bool PrintDeps) const {
Expand Down Expand Up @@ -249,6 +249,10 @@ void DependencyGraph::setDefUseUnscheduledSuccs(
// Walk over all instructions in "BotInterval" and update the counter
// of operands that are in "TopInterval".
for (Instruction &BotI : BotInterval) {
auto *BotN = getNode(&BotI);
// Skip scheduled nodes.
if (BotN->scheduled())
continue;
for (Value *Op : BotI.operands()) {
auto *OpI = dyn_cast<Instruction>(Op);
if (OpI == nullptr)
Expand Down Expand Up @@ -286,7 +290,9 @@ void DependencyGraph::createNewNodes(const Interval<Instruction> &NewInterval) {
MemDGNodeIntervalBuilder::getBotMemDGNode(TopInterval, *this);
MemDGNode *LinkBotN =
MemDGNodeIntervalBuilder::getTopMemDGNode(BotInterval, *this);
assert(LinkTopN->comesBefore(LinkBotN) && "Wrong order!");
assert((LinkTopN == nullptr || LinkBotN == nullptr ||
LinkTopN->comesBefore(LinkBotN)) &&
"Wrong order!");
if (LinkTopN != nullptr && LinkBotN != nullptr) {
LinkTopN->setNextNode(LinkBotN);
LinkBotN->setPrevNode(LinkTopN);
Expand Down
169 changes: 169 additions & 0 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
//===- Scheduler.cpp ------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"

namespace llvm::sandboxir {

// TODO: Check if we can cache top/bottom to reduce compile-time.
/// Linear scan over the bundle for the node whose instruction precedes all
/// the others in program order.
DGNode *SchedBundle::getTop() const {
  DGNode *Earliest = Nodes.front();
  for (DGNode *Candidate : drop_begin(Nodes))
    if (Candidate->getInstruction()->comesBefore(Earliest->getInstruction()))
      Earliest = Candidate;
  return Earliest;
}

/// Linear scan over the bundle for the node whose instruction follows all the
/// others in program order.
DGNode *SchedBundle::getBot() const {
  DGNode *Latest = Nodes.front();
  for (DGNode *Candidate : drop_begin(Nodes))
    if (Latest->getInstruction()->comesBefore(Candidate->getInstruction()))
      Latest = Candidate;
  return Latest;
}

/// Moves every instruction of the bundle, in container order, to right before
/// \p Where.
void SchedBundle::cluster(BasicBlock::iterator Where) {
  for (DGNode *Node : Nodes) {
    Instruction *Instr = Node->getInstruction();
    // If the insertion point is the instruction itself, step past it first.
    if (Where == Instr->getIterator())
      ++Where; // Try to maintain bundle order.
    Instr->moveBefore(*Where.getNodeParent(), Where);
  }
}

#ifndef NDEBUG
/// Prints all the nodes of the bundle to \p OS.
void SchedBundle::dump(raw_ostream &OS) const {
  for (DGNode *Node : Nodes) {
    OS << *Node;
  }
}

/// Prints the bundle to dbgs(), followed by a newline.
void SchedBundle::dump() const {
  raw_ostream &OS = dbgs();
  dump(OS);
  dbgs() << "\n";
}
#endif // NDEBUG

#ifndef NDEBUG
/// Prints the nodes of the ready list to \p OS, one per line, in the order in
/// which they would get popped.
void ReadyListContainer::dump(raw_ostream &OS) const {
  // A priority queue cannot be iterated, so drain a copy of it instead.
  for (auto Scratch = List; !Scratch.empty(); Scratch.pop())
    OS << *Scratch.top() << "\n";
}

/// Prints the ready list to dbgs(), followed by a newline.
void ReadyListContainer::dump() const {
  raw_ostream &OS = dbgs();
  dump(OS);
  dbgs() << "\n";
}
#endif // NDEBUG

/// Places \p Bndl at the current top of the schedule, marks its nodes as
/// scheduled, and inserts into the ready list every dependency predecessor
/// that has no unscheduled successors left.
void Scheduler::scheduleAndUpdateReadyList(SchedBundle &Bndl) {
  assert(ScheduleTopItOpt && "Should have been set by now!");
  // Gather all the instructions of `Bndl` at the current schedule position.
  Bndl.cluster(*ScheduleTopItOpt);
  // The top of the bundle that we just placed is where the next one will go.
  ScheduleTopItOpt = Bndl.getTop()->getInstruction()->getIterator();
  // Flag the nodes as "scheduled" and decrement the UnscheduledSuccs counter
  // of all their dependency predecessors.
  for (DGNode *SchedN : Bndl) {
    SchedN->setScheduled(true);
    for (auto *PredN : SchedN->preds(DAG)) {
      // TODO: preds() should not return nullptr.
      if (PredN != nullptr) {
        PredN->decrUnscheduledSuccs();
        if (PredN->ready())
          ReadyList.insert(PredN);
      }
    }
  }
}

/// Builds a bundle out of the DAG nodes of \p Instrs. The bundle is owned by
/// `Bndls`; a non-owning pointer to it is returned.
SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
  SchedBundle::ContainerTy BundleNodes;
  BundleNodes.reserve(Instrs.size());
  for (Instruction *I : Instrs)
    BundleNodes.push_back(DAG.getNode(I));
  Bndls.push_back(std::make_unique<SchedBundle>(std::move(BundleNodes)));
  return Bndls.back().get();
}

/// Keeps scheduling ready nodes until either all of \p Instrs have become
/// ready, in which case they get scheduled together as one bundle and true is
/// returned, or the ready list runs empty, in which case false is returned.
bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
  // A set gives us fast membership checks for the instructions in `Instrs`.
  DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
  // The nodes of `Instrs` that have become ready so far. These are held back
  // instead of being scheduled as soon as they become ready.
  SmallVector<DGNode *, 8> DeferredNodes;

  while (!ReadyList.empty()) {
    DGNode *N = ReadyList.pop();
    Instruction *I = N->getInstruction();

    // Not one of `Instrs`: schedule it right away in a bundle of its own.
    if (!InstrsToDefer.contains(I)) {
      scheduleAndUpdateReadyList(*createBundle({I}));
      continue;
    }

    // One of `Instrs`: hold it back until all the rest are ready too, such
    // that they all get scheduled at the same time in a single bundle.
    DeferredNodes.push_back(N);
    if (DeferredNodes.size() != Instrs.size())
      continue;
    scheduleAndUpdateReadyList(*createBundle(Instrs));
    return true;
  }
  assert(DeferredNodes.size() != Instrs.size() &&
         "We should have succesfully scheduled and early-returned!");
  return false;
}

/// Tries to schedule the IR such that all of \p Instrs end up back-to-back.
/// Extends the DAG to cover \p Instrs, sets the schedule position to right
/// after the bottom-most instruction of \p Instrs, seeds the ready list with
/// the ready nodes of the DAG extension and then runs the scheduling loop.
/// \p Instrs must be non-empty and all in the same basic block.
/// \Returns true if \p Instrs got scheduled together, false otherwise.
bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
  assert(!Instrs.empty() && "Expected at least one instruction!");
  assert(all_of(drop_begin(Instrs),
                [Instrs](Instruction *I) {
                  return I->getParent() == (*Instrs.begin())->getParent();
                }) &&
         "Instrs not in the same BB!");
  // Extend the DAG to include Instrs.
  Interval<Instruction> Extension = DAG.extend(Instrs);
  // TODO: Set the window of the DAG that we are interested in.
  // We start scheduling at the bottom instr of Instrs, i.e., the one that
  // comes after all the others in program order. Note that picking the
  // *minimum* with a comesBefore() comparator would give us the top instead.
  Instruction *BottomI = Instrs.front();
  for (Instruction *CandidateI : drop_begin(Instrs))
    if (BottomI->comesBefore(CandidateI))
      BottomI = CandidateI;
  ScheduleTopItOpt = std::next(BottomI->getIterator());
  // Add nodes to ready list.
  for (auto &I : Extension) {
    auto *N = DAG.getNode(&I);
    if (N->ready())
      ReadyList.insert(N);
  }
  // Try schedule all nodes until we can schedule Instrs back-to-back.
  return tryScheduleUntil(Instrs);
}

#ifndef NDEBUG
/// Prints the scheduler's ready list to \p OS under a "ReadyList:" header.
void Scheduler::dump(raw_ostream &OS) const {
  OS << "ReadyList:\n";
  ReadyList.dump(OS);
}

/// Prints the scheduler state to dbgs().
void Scheduler::dump() const {
  raw_ostream &OS = dbgs();
  dump(OS);
}
#endif // NDEBUG

} // namespace llvm::sandboxir
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ define void @f1(ptr %A) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 4
Expand Down
44 changes: 9 additions & 35 deletions llvm/test/Transforms/LoopVectorize/AArch64/masked-call.ll
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ target triple = "aarch64-unknown-linux-gnu"
define void @test_widen(ptr noalias %a, ptr readnone %b) #4 {
; TFNONE-LABEL: @test_widen(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down Expand Up @@ -146,10 +143,7 @@ for.cond.cleanup:
define void @test_if_then(ptr noalias %a, ptr readnone %b) #4 {
; TFNONE-LABEL: @test_if_then(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down Expand Up @@ -310,10 +304,7 @@ for.cond.cleanup:
define void @test_widen_if_then_else(ptr noalias %a, ptr readnone %b) #4 {
; TFNONE-LABEL: @test_widen_if_then_else(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down Expand Up @@ -490,10 +481,7 @@ for.cond.cleanup:
define void @test_widen_nomask(ptr noalias %a, ptr readnone %b) #4 {
; TFNONE-LABEL: @test_widen_nomask(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down Expand Up @@ -548,11 +536,6 @@ define void @test_widen_nomask(ptr noalias %a, ptr readnone %b) #4 {
;
; TFFALLBACK-LABEL: @test_widen_nomask(
; TFFALLBACK-NEXT: entry:
; TFFALLBACK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFFALLBACK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFFALLBACK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFFALLBACK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFFALLBACK: vector.ph:
; TFFALLBACK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFFALLBACK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
; TFFALLBACK-NEXT: [[N_MOD_VF:%.*]] = urem i64 1025, [[TMP3]]
Expand All @@ -561,20 +544,17 @@ define void @test_widen_nomask(ptr noalias %a, ptr readnone %b) #4 {
; TFFALLBACK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 2
; TFFALLBACK-NEXT: br label [[VECTOR_BODY:%.*]]
; TFFALLBACK: vector.body:
; TFFALLBACK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; TFFALLBACK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
; TFFALLBACK-NEXT: [[TMP6:%.*]] = getelementptr i64, ptr [[B:%.*]], i64 [[INDEX]]
; TFFALLBACK-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 2 x i64>, ptr [[TMP6]], align 8
; TFFALLBACK-NEXT: [[TMP7:%.*]] = call <vscale x 2 x i64> @foo_vector_nomask(<vscale x 2 x i64> [[WIDE_LOAD]])
; TFFALLBACK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[A:%.*]], i64 [[INDEX]]
; TFFALLBACK-NEXT: store <vscale x 2 x i64> [[TMP7]], ptr [[TMP8]], align 8
; TFFALLBACK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
; TFFALLBACK-NEXT: [[TMP9:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
; TFFALLBACK-NEXT: br i1 [[TMP9]], label [[SCALAR_PH]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; TFFALLBACK: scalar.ph:
; TFFALLBACK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[N_VEC]], [[VECTOR_BODY]] ]
; TFFALLBACK-NEXT: br label [[FOR_BODY:%.*]]
; TFFALLBACK-NEXT: br i1 [[TMP9]], label [[FOR_BODY:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP5:![0-9]+]]
; TFFALLBACK: for.body:
; TFFALLBACK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ]
; TFFALLBACK-NEXT: [[INDVARS_IV:%.*]] = phi i64 [ [[INDVARS_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[N_VEC]], [[VECTOR_BODY]] ]
; TFFALLBACK-NEXT: [[GEP:%.*]] = getelementptr i64, ptr [[B]], i64 [[INDVARS_IV]]
; TFFALLBACK-NEXT: [[LOAD:%.*]] = load i64, ptr [[GEP]], align 8
; TFFALLBACK-NEXT: [[CALL:%.*]] = call i64 @foo(i64 [[LOAD]]) #[[ATTR5:[0-9]+]]
Expand Down Expand Up @@ -626,10 +606,7 @@ for.cond.cleanup:
define void @test_widen_optmask(ptr noalias %a, ptr readnone %b) #4 {
; TFNONE-LABEL: @test_widen_optmask(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down Expand Up @@ -791,10 +768,7 @@ for.cond.cleanup:
define double @test_widen_fmuladd_and_call(ptr noalias %a, ptr readnone %b, double %m) #4 {
; TFNONE-LABEL: @test_widen_fmuladd_and_call(
; TFNONE-NEXT: entry:
; TFNONE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 2
; TFNONE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; TFNONE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; TFNONE: vector.ph:
; TFNONE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; TFNONE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ define void @test_invar_gep(ptr %dst) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 100, [[TMP1]]
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,7 @@ define void @simple_memset_trip1024(i32 %val, ptr %ptr, i64 %n) #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1024, [[TMP1]]
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ target triple = "aarch64-unknown-linux-gnu"
define void @test_widen(ptr noalias %a, ptr readnone %b) #1 {
; WIDE-LABEL: @test_widen(
; WIDE-NEXT: entry:
; WIDE-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; WIDE-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
; WIDE-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 1025, [[TMP1]]
; WIDE-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; WIDE-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; WIDE: vector.ph:
; WIDE-NEXT: [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
; WIDE-NEXT: [[TMP3:%.*]] = mul i64 [[TMP2]], 4
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/LoopVectorize/if-reduction.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,7 @@ for.end: ; preds = %for.body, %entry
ret i64 %1
}

; FIXME: %indvars.iv.next is poison on first iteration due to sub nuw 0, 1.
define i32 @fcmp_0_sub_select1(ptr noalias %x, i32 %N) nounwind readonly {
; CHECK-LABEL: define i32 @fcmp_0_sub_select1(
; CHECK-SAME: ptr noalias [[X:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
Expand All @@ -1668,8 +1669,7 @@ define i32 @fcmp_0_sub_select1(ptr noalias %x, i32 %N) nounwind readonly {
; CHECK: [[FOR_HEADER]]:
; CHECK-NEXT: [[ZEXT:%.*]] = zext i32 [[N]] to i64
; CHECK-NEXT: [[TMP0:%.*]] = sub i64 0, [[ZEXT]]
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP0]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP0]], 4
; CHECK-NEXT: [[N_VEC:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ exit:

; Test case to make sure that uses of versioned strides of type i1 are properly
; extended. From https://github.com/llvm/llvm-project/issues/91369.
; TODO: Better check (udiv i64 15, %g.64) after checking if %g == 1.
define void @zext_of_i1_stride(i1 %g, ptr %dst) mustprogress {
; CHECK-LABEL: define void @zext_of_i1_stride(
; CHECK-SAME: i1 [[G:%.*]], ptr [[DST:%.*]]) #[[ATTR0:[0-9]+]] {
Expand All @@ -423,8 +424,7 @@ define void @zext_of_i1_stride(i1 %g, ptr %dst) mustprogress {
; CHECK-NEXT: [[G_64:%.*]] = zext i1 [[G]] to i64
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 15, [[G_64]]
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[TMP0]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
; CHECK-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
; CHECK: vector.scevcheck:
; CHECK-NEXT: [[IDENT_CHECK:%.*]] = icmp ne i1 [[G]], true
; CHECK-NEXT: br i1 [[IDENT_CHECK]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -489,8 +489,7 @@ define void @sext_of_i1_stride(i1 %g, ptr %dst) mustprogress {
; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[UMAX]], -1
; CHECK-NEXT: [[TMP1:%.*]] = udiv i64 [[TMP0]], [[G_64]]
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP2]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
; CHECK-NEXT: br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_SCEVCHECK:%.*]]
; CHECK: vector.scevcheck:
; CHECK-NEXT: [[IDENT_CHECK:%.*]] = icmp ne i1 [[G]], true
; CHECK-NEXT: br i1 [[IDENT_CHECK]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S --passes=slp-vectorizer < %s | FileCheck %s

; Two identical scalar chains (add, zext nneg, mul nuw nsw, llvm.abs, trunc)
; are applied to %n + 1 and %n + 2 and their results are summed.
; The autogenerated assertions below expect the SLP vectorizer to merge both
; chains into one <2 x ...> chain that is widened to <2 x i64> for the
; multiply and abs, truncated back to <2 x i32>, and then extracted lane by
; lane for the final scalar add. Note that the expected vector zext carries
; no nneg flag, while the vector multiply keeps nuw nsw.
define i32 @test(i32 %n) {
; CHECK-LABEL: define i32 @test(
; CHECK-SAME: i32 [[N:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = insertelement <2 x i32> poison, i32 [[N]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <2 x i32> [[TMP0]], <2 x i32> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i32> [[TMP1]], <i32 1, i32 2>
; CHECK-NEXT: [[TMP3:%.*]] = zext <2 x i32> [[TMP2]] to <2 x i64>
; CHECK-NEXT: [[TMP7:%.*]] = mul nuw nsw <2 x i64> [[TMP3]], <i64 273837369, i64 273837369>
; CHECK-NEXT: [[TMP8:%.*]] = call <2 x i64> @llvm.abs.v2i64(<2 x i64> [[TMP7]], i1 true)
; CHECK-NEXT: [[TMP4:%.*]] = trunc <2 x i64> [[TMP8]] to <2 x i32>
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x i32> [[TMP4]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <2 x i32> [[TMP4]], i32 1
; CHECK-NEXT: [[RES1:%.*]] = add i32 [[TMP5]], [[TMP6]]
; CHECK-NEXT: ret i32 [[RES1]]
;
entry:
%n1 = add i32 %n, 1
%zn1 = zext nneg i32 %n1 to i64
%m1 = mul nuw nsw i64 %zn1, 273837369
%a1 = call i64 @llvm.abs.i64(i64 %m1, i1 true)
%t1 = trunc i64 %a1 to i32
%n2 = add i32 %n, 2
%zn2 = zext nneg i32 %n2 to i64
%m2 = mul nuw nsw i64 %zn2, 273837369
%a2 = call i64 @llvm.abs.i64(i64 %m2, i1 true)
%t2 = trunc i64 %a2 to i32
%res1 = add i32 %t1, %t2
ret i32 %res1
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ define i32 @test(i1 %.b, i8 %conv18, i32 %k.promoted61) {
; CHECK-NEXT: [[TMP3:%.*]] = xor <4 x i1> [[TMP2]], <i1 true, i1 true, i1 true, i1 true>
; CHECK-NEXT: [[TMP4:%.*]] = zext <4 x i1> [[TMP3]] to <4 x i8>
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq <4 x i8> [[TMP4]], zeroinitializer
; CHECK-NEXT: [[TMP6:%.*]] = freeze <4 x i1> [[TMP3]]
; CHECK-NEXT: [[TMP7:%.*]] = sext <4 x i1> [[TMP6]] to <4 x i8>
; CHECK-NEXT: [[TMP6:%.*]] = zext <4 x i1> [[TMP3]] to <4 x i8>
; CHECK-NEXT: [[TMP7:%.*]] = freeze <4 x i8> [[TMP6]]
; CHECK-NEXT: [[TMP8:%.*]] = insertelement <4 x i8> poison, i8 [[CONV18]], i32 0
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <4 x i8> [[TMP8]], <4 x i8> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP10:%.*]] = icmp ugt <4 x i8> [[TMP7]], [[TMP9]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ add_llvm_unittest(SandboxVectorizerTests
DependencyGraphTest.cpp
IntervalTest.cpp
LegalityTest.cpp
SchedulerTest.cpp
SeedCollectorTest.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,18 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_EQ(N0->getNumUnscheduledSuccs(), 1u); // N1
EXPECT_EQ(N1->getNumUnscheduledSuccs(), 0u);
EXPECT_EQ(N2->getNumUnscheduledSuccs(), 0u);

// Check decrUnscheduledSuccs.
N0->decrUnscheduledSuccs();
EXPECT_EQ(N0->getNumUnscheduledSuccs(), 0u);
#ifndef NDEBUG
EXPECT_DEATH(N0->decrUnscheduledSuccs(), ".*Counting.*");
#endif // NDEBUG

// Check scheduled(), setScheduled().
EXPECT_FALSE(N0->scheduled());
N0->setScheduled(true);
EXPECT_TRUE(N0->scheduled());
}

TEST_F(DependencyGraphTest, Preds) {
Expand Down Expand Up @@ -773,4 +785,16 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %v3, i8 %v4, i8 %v5) {
EXPECT_EQ(S4N->getNumUnscheduledSuccs(), 1u); // S5N
EXPECT_EQ(S5N->getNumUnscheduledSuccs(), 0u);
}

{
// Check UnscheduledSuccs when a node is scheduled
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({S2, S2});
auto *S2N = cast<sandboxir::MemDGNode>(DAG.getNode(S2));
S2N->setScheduled(true);

DAG.extend({S1, S1});
auto *S1N = cast<sandboxir::MemDGNode>(DAG.getNode(S1));
EXPECT_EQ(S1N->getNumUnscheduledSuccs(), 0u); // S1 is scheduled
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
//===- SchedulerTest.cpp --------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Dominators.h"
#include "llvm/SandboxIR/Context.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/SourceMgr.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

using namespace llvm;

/// Test fixture shared by the scheduler unit tests. It parses textual LLVM IR
/// into a module and builds the alias-analysis objects that are handed to the
/// DependencyGraph / Scheduler under test.
struct SchedulerTest : public testing::Test {
LLVMContext C;
// Module produced by parseIR(); stays null if parsing failed.
std::unique_ptr<Module> M;
// Analyses created by getAA(). They are held as members so that they outlive
// the call and remain reachable through the returned AAResults reference.
std::unique_ptr<AssumptionCache> AC;
std::unique_ptr<DominatorTree> DT;
std::unique_ptr<BasicAAResult> BAA;
std::unique_ptr<AAResults> AA;

/// Parses the assembly text \p IR in context \p C and stores the result in
/// `M`. On a parse failure the diagnostic is printed to errs() and `M` is
/// left null; no test failure is raised here.
/// Note: the parameter \p C shadows the fixture member of the same name.
void parseIR(LLVMContext &C, const char *IR) {
SMDiagnostic Err;
M = parseAssemblyString(IR, Err, C);
if (!M)
Err.print("SchedulerTest", errs());
}

/// Builds a fresh AssumptionCache, DominatorTree and BasicAAResult for
/// \p LLVMF, registers the BasicAAResult with a new AAResults, and returns
/// that AAResults.
/// Every call replaces the member objects, so a reference obtained from an
/// earlier call must not be used afterwards. Requires parseIR() to have
/// succeeded, because `M` is dereferenced for the data layout.
AAResults &getAA(llvm::Function &LLVMF) {
TargetLibraryInfoImpl TLII;
TargetLibraryInfo TLI(TLII);
AA = std::make_unique<AAResults>(TLI);
AC = std::make_unique<AssumptionCache>(LLVMF);
DT = std::make_unique<DominatorTree>(LLVMF);
BAA = std::make_unique<BasicAAResult>(M->getDataLayout(), LLVMF, TLI, *AC,
DT.get());
AA->addAAResult(*BAA);
return *AA;
}
};

/// Exercises sandboxir::SchedBundle: getTop()/getBot(), cluster() at several
/// insertion points, and iteration over the bundle's nodes.
/// The block starts in the order S0, Other, S1, Ret, and the bundle holds the
/// two stores (S0, S1).
TEST_F(SchedulerTest, SchedBundle) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
store i8 %v0, ptr %ptr
%other = add i8 %v0, %v1
store i8 %v1, ptr %ptr
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
// Grab the four instructions in program order.
auto It = BB->begin();
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *Other = &*It++;
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Build a dependency graph over the whole block and bundle the two stores.
sandboxir::DependencyGraph DAG(getAA(*LLVMF));
DAG.extend({&*BB->begin(), BB->getTerminator()});
auto *SN0 = DAG.getNode(S0);
auto *SN1 = DAG.getNode(S1);
sandboxir::SchedBundle Bndl({SN0, SN1});

// Check getTop().
EXPECT_EQ(Bndl.getTop(), SN0);
// Check getBot().
EXPECT_EQ(Bndl.getBot(), SN1);
// Check cluster().
// Each case below clusters the bundle at a different position, checks the
// resulting instruction order, and then moves instructions back so that the
// block is again S0, Other, S1, Ret for the next case.
// Case 1: cluster at S1; S0 is expected directly above S1.
Bndl.cluster(S1->getIterator());
{
auto It = BB->begin();
EXPECT_EQ(&*It++, Other);
EXPECT_EQ(&*It++, S0);
EXPECT_EQ(&*It++, S1);
EXPECT_EQ(&*It++, Ret);
S0->moveBefore(Other);
}

// Case 2: cluster at S0; S1 is expected directly below S0.
Bndl.cluster(S0->getIterator());
{
auto It = BB->begin();
EXPECT_EQ(&*It++, S0);
EXPECT_EQ(&*It++, S1);
EXPECT_EQ(&*It++, Other);
EXPECT_EQ(&*It++, Ret);
S1->moveAfter(Other);
}

// Case 3: cluster at Other (not a bundle member); both stores are expected
// directly above it.
Bndl.cluster(Other->getIterator());
{
auto It = BB->begin();
EXPECT_EQ(&*It++, S0);
EXPECT_EQ(&*It++, S1);
EXPECT_EQ(&*It++, Other);
EXPECT_EQ(&*It++, Ret);
S1->moveAfter(Other);
}

// Case 4: cluster at the terminator; both stores are expected directly above
// Ret.
Bndl.cluster(Ret->getIterator());
{
auto It = BB->begin();
EXPECT_EQ(&*It++, Other);
EXPECT_EQ(&*It++, S0);
EXPECT_EQ(&*It++, S1);
EXPECT_EQ(&*It++, Ret);
Other->moveBefore(S1);
}

// Case 5: cluster at the end of the block; both stores are expected after
// Ret. Two moves are needed to restore the original order.
Bndl.cluster(BB->end());
{
auto It = BB->begin();
EXPECT_EQ(&*It++, Other);
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(&*It++, S0);
EXPECT_EQ(&*It++, S1);
Ret->moveAfter(S1);
Other->moveAfter(S0);
}
// Check iterators.
// Both the mutable and the const range must yield the nodes in insertion
// order.
EXPECT_THAT(Bndl, testing::ElementsAre(SN0, SN1));
EXPECT_THAT((const sandboxir::SchedBundle &)Bndl,
testing::ElementsAre(SN0, SN1));
}

/// Exercises Scheduler::trySchedule() with single-instruction requests on a
/// block of two stores to the same pointer followed by a return. Each scope
/// uses a fresh Scheduler so the three scenarios do not influence each other.
TEST_F(SchedulerTest, Basic) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
store i8 %v0, ptr %ptr
store i8 %v1, ptr %ptr
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
// Grab the instructions in program order: S0, S1, Ret.
auto It = BB->begin();
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

{
// Schedule all instructions in sequence.
// Requests are issued from the bottom of the block upwards (Ret, S1, S0).
sandboxir::Scheduler Sched(getAA(*LLVMF));
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S1}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Skip instructions.
// S1 is never requested explicitly; scheduling S0 must still succeed.
sandboxir::Scheduler Sched(getAA(*LLVMF));
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
}
{
// Try invalid scheduling
// After S0 has been scheduled, a request for S1 (which sits below S0 in the
// block) is expected to be rejected.
// NOTE(review): presumably S1 was already scheduled implicitly on the way to
// S0 -- confirm against Scheduler.h.
sandboxir::Scheduler Sched(getAA(*LLVMF));
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0}));
EXPECT_FALSE(Sched.trySchedule({S1}));
}
}

/// Exercises Scheduler::trySchedule() with multi-instruction requests: the
/// two stores are requested together, then the two loads are requested
/// together. The two pointer arguments are marked noalias in the IR.
TEST_F(SchedulerTest, Bundles) {
parseIR(C, R"IR(
define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
%ld0 = load i8, ptr %ptr0
%ld1 = load i8, ptr %ptr1
store i8 %ld0, ptr %ptr0
store i8 %ld1, ptr %ptr1
ret void
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(LLVMF);
auto *BB = &*F->begin();
// Grab the instructions in program order: L0, L1, S0, S1, Ret.
auto It = BB->begin();
auto *L0 = cast<sandboxir::LoadInst>(&*It++);
auto *L1 = cast<sandboxir::LoadInst>(&*It++);
auto *S0 = cast<sandboxir::StoreInst>(&*It++);
auto *S1 = cast<sandboxir::StoreInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Requests go from the bottom of the block upwards: the terminator first,
// then the store pair, then the load pair. All three must succeed.
sandboxir::Scheduler Sched(getAA(*LLVMF));
EXPECT_TRUE(Sched.trySchedule({Ret}));
EXPECT_TRUE(Sched.trySchedule({S0, S1}));
EXPECT_TRUE(Sched.trySchedule({L0, L1}));
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ static_library("Vectorize") {
"SandboxVectorizer/Passes/RegionsFromMetadata.cpp",
"SandboxVectorizer/SandboxVectorizer.cpp",
"SandboxVectorizer/SandboxVectorizerPassBuilder.cpp",
"SandboxVectorizer/Scheduler.cpp",
"SandboxVectorizer/SeedCollector.cpp",
"VPlan.cpp",
"VPlanAnalysis.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ unittest("SandboxVectorizerTests") {
"DependencyGraphTest.cpp",
"IntervalTest.cpp",
"LegalityTest.cpp",
"SchedulerTest.cpp",
]
}
2 changes: 1 addition & 1 deletion mlir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ endif()

# Must go below project(..)
include(GNUInstallDirs)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "C++ standard to conform to")

if(MLIR_STANDALONE_BUILD)
find_package(LLVM CONFIG REQUIRED)
Expand Down
81 changes: 41 additions & 40 deletions mlir/lib/TableGen/AttrOrTypeDef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DefInit;
using llvm::Init;
using llvm::ListInit;
using llvm::Record;
using llvm::RecordVal;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// AttrOrTypeBuilder
Expand All @@ -35,14 +41,13 @@ bool AttrOrTypeBuilder::hasInferredContextParameter() const {
// AttrOrTypeDef
//===----------------------------------------------------------------------===//

AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
AttrOrTypeDef::AttrOrTypeDef(const Record *def) : def(def) {
// Populate the builders.
auto *builderList =
dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
const auto *builderList =
dyn_cast_or_null<ListInit>(def->getValueInit("builders"));
if (builderList && !builderList->empty()) {
for (const llvm::Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
def->getLoc());
for (const Init *init : builderList->getValues()) {
AttrOrTypeBuilder builder(cast<DefInit>(init)->getDef(), def->getLoc());

// Ensure that all parameters have names.
for (const AttrOrTypeBuilder::Parameter &param :
Expand All @@ -56,16 +61,16 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {

// Populate the traits.
if (auto *traitList = def->getValueAsListInit("traits")) {
SmallPtrSet<const llvm::Init *, 32> traitSet;
SmallPtrSet<const Init *, 32> traitSet;
traits.reserve(traitSet.size());
llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
[&](const llvm::ListInit *traitList) {
llvm::unique_function<void(const ListInit *)> processTraitList =
[&](const ListInit *traitList) {
for (auto *traitInit : *traitList) {
if (!traitSet.insert(traitInit).second)
continue;

// If this is an interface, add any bases to the trait list.
auto *traitDef = cast<llvm::DefInit>(traitInit)->getDef();
auto *traitDef = cast<DefInit>(traitInit)->getDef();
if (traitDef->isSubClassOf("Interface")) {
if (auto *bases = traitDef->getValueAsListInit("baseInterfaces"))
processTraitList(bases);
Expand Down Expand Up @@ -111,7 +116,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
}

Dialect AttrOrTypeDef::getDialect() const {
auto *dialect = dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
const auto *dialect = dyn_cast<DefInit>(def->getValue("dialect")->getValue());
return Dialect(dialect ? dialect->getDef() : nullptr);
}

Expand All @@ -126,17 +131,17 @@ StringRef AttrOrTypeDef::getCppBaseClassName() const {
}

bool AttrOrTypeDef::hasDescription() const {
const llvm::RecordVal *desc = def->getValue("description");
return desc && isa<llvm::StringInit>(desc->getValue());
const RecordVal *desc = def->getValue("description");
return desc && isa<StringInit>(desc->getValue());
}

StringRef AttrOrTypeDef::getDescription() const {
return def->getValueAsString("description");
}

bool AttrOrTypeDef::hasSummary() const {
const llvm::RecordVal *summary = def->getValue("summary");
return summary && isa<llvm::StringInit>(summary->getValue());
const RecordVal *summary = def->getValue("summary");
return summary && isa<StringInit>(summary->getValue());
}

StringRef AttrOrTypeDef::getSummary() const {
Expand Down Expand Up @@ -249,9 +254,9 @@ StringRef TypeDef::getTypeName() const {
template <typename InitT>
auto AttrOrTypeParameter::getDefValue(StringRef name) const {
std::optional<decltype(std::declval<InitT>().getValue())> result;
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (auto *init = param->getDef()->getValue(name))
if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (const auto *init = param->getDef()->getValue(name))
if (const auto *value = dyn_cast_or_null<InitT>(init->getValue()))
result = value->getValue();
return result;
}
Expand All @@ -270,20 +275,20 @@ std::string AttrOrTypeParameter::getAccessorName() const {
}

std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<llvm::StringInit>("allocator");
return getDefValue<StringInit>("allocator");
}

StringRef AttrOrTypeParameter::getComparator() const {
return getDefValue<llvm::StringInit>("comparator").value_or("$_lhs == $_rhs");
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
}

StringRef AttrOrTypeParameter::getCppType() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
auto cppType = getDefValue<llvm::StringInit>("cppType");
auto cppType = getDefValue<StringInit>("cppType");
if (cppType)
return *cppType;
if (auto *init = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *init = dyn_cast<DefInit>(getDef()))
llvm::PrintFatalError(
init->getDef()->getLoc(),
Twine("Missing `cppType` field in Attribute/Type parameter: ") +
Expand All @@ -295,52 +300,48 @@ StringRef AttrOrTypeParameter::getCppType() const {
}

StringRef AttrOrTypeParameter::getCppAccessorType() const {
return getDefValue<llvm::StringInit>("cppAccessorType")
.value_or(getCppType());
return getDefValue<StringInit>("cppAccessorType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getCppStorageType() const {
return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
return getDefValue<StringInit>("cppStorageType").value_or(getCppType());
}

StringRef AttrOrTypeParameter::getConvertFromStorage() const {
return getDefValue<llvm::StringInit>("convertFromStorage").value_or("$_self");
return getDefValue<StringInit>("convertFromStorage").value_or("$_self");
}

std::optional<StringRef> AttrOrTypeParameter::getParser() const {
return getDefValue<llvm::StringInit>("parser");
return getDefValue<StringInit>("parser");
}

std::optional<StringRef> AttrOrTypeParameter::getPrinter() const {
return getDefValue<llvm::StringInit>("printer");
return getDefValue<StringInit>("printer");
}

std::optional<StringRef> AttrOrTypeParameter::getSummary() const {
return getDefValue<llvm::StringInit>("summary");
return getDefValue<StringInit>("summary");
}

StringRef AttrOrTypeParameter::getSyntax() const {
if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
if (auto *stringType = dyn_cast<StringInit>(getDef()))
return stringType->getValue();
return getDefValue<llvm::StringInit>("syntax").value_or(getCppType());
return getDefValue<StringInit>("syntax").value_or(getCppType());
}

bool AttrOrTypeParameter::isOptional() const {
return getDefaultValue().has_value();
}

std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
std::optional<StringRef> result =
getDefValue<llvm::StringInit>("defaultValue");
std::optional<StringRef> result = getDefValue<StringInit>("defaultValue");
return result && !result->empty() ? result : std::nullopt;
}

const llvm::Init *AttrOrTypeParameter::getDef() const {
return def->getArg(index);
}
const Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }

std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
if (const auto *param = dyn_cast<DefInit>(getDef()))
if (param->getDef()->isSubClassOf("Constraint"))
return Constraint(param->getDef());
return std::nullopt;
Expand All @@ -351,8 +352,8 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
//===----------------------------------------------------------------------===//

bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
const llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
const Init *paramDef = param->getDef();
if (const auto *paramDefInit = dyn_cast<DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
}
25 changes: 12 additions & 13 deletions mlir/lib/TableGen/Attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ StringRef Attribute::getReturnType() const {
// Return the type constraint corresponding to the type of this attribute, or
// std::nullopt if this is not a TypedAttr.
std::optional<Type> Attribute::getValueType() const {
if (auto *defInit = dyn_cast<llvm::DefInit>(def->getValueInit("valueType")))
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("valueType")))
return Type(defInit->getDef());
return std::nullopt;
}
Expand All @@ -92,8 +92,7 @@ StringRef Attribute::getConstBuilderTemplate() const {
}

Attribute Attribute::getBaseAttr() const {
if (const auto *defInit =
llvm::dyn_cast<llvm::DefInit>(def->getValueInit("baseAttr"))) {
if (const auto *defInit = dyn_cast<DefInit>(def->getValueInit("baseAttr"))) {
return Attribute(defInit).getBaseAttr();
}
return *this;
Expand Down Expand Up @@ -132,7 +131,7 @@ Dialect Attribute::getDialect() const {
return Dialect(nullptr);
}

const llvm::Record &Attribute::getDef() const { return *def; }
const Record &Attribute::getDef() const { return *def; }

ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
Expand All @@ -147,12 +146,12 @@ StringRef ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}

EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrCaseInfo") &&
"must be subclass of TableGen 'EnumAttrInfo' class");
}

EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
EnumAttrCase::EnumAttrCase(const DefInit *init)
: EnumAttrCase(init->getDef()) {}

StringRef EnumAttrCase::getSymbol() const {
Expand All @@ -163,16 +162,16 @@ StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }

int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }

const llvm::Record &EnumAttrCase::getDef() const { return *def; }
const Record &EnumAttrCase::getDef() const { return *def; }

EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
assert(isSubClassOf("EnumAttrInfo") &&
"must be subclass of TableGen 'EnumAttr' class");
}

EnumAttr::EnumAttr(const llvm::Record &record) : Attribute(&record) {}
EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}

EnumAttr::EnumAttr(const llvm::DefInit *init) : EnumAttr(init->getDef()) {}
EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}

bool EnumAttr::classof(const Attribute *attr) {
return attr->isSubClassOf("EnumAttrInfo");
Expand Down Expand Up @@ -218,8 +217,8 @@ std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
std::vector<EnumAttrCase> cases;
cases.reserve(inits->size());

for (const llvm::Init *init : *inits) {
cases.emplace_back(cast<llvm::DefInit>(init));
for (const Init *init : *inits) {
cases.emplace_back(cast<DefInit>(init));
}

return cases;
Expand All @@ -229,7 +228,7 @@ bool EnumAttr::genSpecializedAttr() const {
return def->getValueAsBit("genSpecializedAttr");
}

const llvm::Record *EnumAttr::getBaseAttrClass() const {
const Record *EnumAttr::getBaseAttrClass() const {
return def->getValueAsDef("baseAttrClass");
}

Expand Down
24 changes: 14 additions & 10 deletions mlir/lib/TableGen/Builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Init;
using llvm::Record;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// Builder::Parameter
//===----------------------------------------------------------------------===//

/// Return a string containing the C++ type of this parameter.
StringRef Builder::Parameter::getCppType() const {
if (const auto *stringInit = dyn_cast<llvm::StringInit>(def))
if (const auto *stringInit = dyn_cast<StringInit>(def))
return stringInit->getValue();
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
// Inlining the first part of `Record::getValueAsString` to give better
// error messages.
const llvm::RecordVal *type = record->getValue("type");
Expand All @@ -35,9 +40,9 @@ StringRef Builder::Parameter::getCppType() const {
/// Return an optional string containing the default value to use for this
/// parameter.
std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
if (isa<llvm::StringInit>(def))
if (isa<StringInit>(def))
return std::nullopt;
const llvm::Record *record = cast<llvm::DefInit>(def)->getDef();
const Record *record = cast<DefInit>(def)->getDef();
std::optional<StringRef> value =
record->getValueAsOptionalString("defaultValue");
return value && !value->empty() ? value : std::nullopt;
Expand All @@ -47,18 +52,17 @@ std::optional<StringRef> Builder::Parameter::getDefaultValue() const {
// Builder
//===----------------------------------------------------------------------===//

Builder::Builder(const llvm::Record *record, ArrayRef<SMLoc> loc)
: def(record) {
Builder::Builder(const Record *record, ArrayRef<SMLoc> loc) : def(record) {
// Initialize the parameters of the builder.
const llvm::DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<llvm::DefInit>(dag->getOperator());
const DagInit *dag = def->getValueAsDag("dagParams");
auto *defInit = dyn_cast<DefInit>(dag->getOperator());
if (!defInit || defInit->getDef()->getName() != "ins")
PrintFatalError(def->getLoc(), "expected 'ins' in builders");

bool seenDefaultValue = false;
for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) {
const llvm::StringInit *paramName = dag->getArgName(i);
const llvm::Init *paramValue = dag->getArg(i);
const StringInit *paramName = dag->getArgName(i);
const Init *paramValue = dag->getArg(i);
Parameter param(paramName ? paramName->getValue()
: std::optional<StringRef>(),
paramValue);
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,32 @@ using namespace mlir::tblgen;

/// Generate a unique label based on the current file name to prevent name
/// collisions if multiple generated files are included at once.
static std::string getUniqueOutputLabel(const llvm::RecordKeeper &records,
static std::string getUniqueOutputLabel(const RecordKeeper &records,
StringRef tag) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();

// Drop all but the base filename.
StringRef nameRef = llvm::sys::path::filename(inputFilename);
StringRef nameRef = sys::path::filename(inputFilename);
nameRef.consume_back(".td");

// Sanitize any invalid characters.
std::string uniqueName(tag);
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
if (isAlnum(c) || c == '_')
uniqueName.push_back(c);
else
uniqueName.append(llvm::utohexstr((unsigned char)c));
uniqueName.append(utohexstr((unsigned char)c));
}
return uniqueName;
}

StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
raw_ostream &os, const llvm::RecordKeeper &records, StringRef tag)
raw_ostream &os, const RecordKeeper &records, StringRef tag)
: os(os), uniqueOutputLabel(getUniqueOutputLabel(records, tag)) {}

void StaticVerifierFunctionEmitter::emitOpConstraints(
ArrayRef<const llvm::Record *> opDefs) {
ArrayRef<const Record *> opDefs) {
NamespaceEmitter namespaceEmitter(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraints();
emitAttrConstraints();
Expand All @@ -58,7 +58,7 @@ void StaticVerifierFunctionEmitter::emitOpConstraints(
}

void StaticVerifierFunctionEmitter::emitPatternConstraints(
const llvm::ArrayRef<DagLeaf> constraints) {
const ArrayRef<DagLeaf> constraints) {
collectPatternConstraints(constraints);
emitPatternConstraints();
}
Expand Down Expand Up @@ -298,7 +298,7 @@ void StaticVerifierFunctionEmitter::collectOpConstraints(
}

void StaticVerifierFunctionEmitter::collectPatternConstraints(
const llvm::ArrayRef<DagLeaf> constraints) {
const ArrayRef<DagLeaf> constraints) {
for (auto &leaf : constraints) {
assert(leaf.isOperandMatcher() || leaf.isAttrMatcher());
collectConstraint(
Expand All @@ -313,7 +313,7 @@ void StaticVerifierFunctionEmitter::collectPatternConstraints(

std::string mlir::tblgen::escapeString(StringRef value) {
std::string ret;
llvm::raw_string_ostream os(ret);
raw_string_ostream os(ret);
os.write_escaped(value);
return ret;
}
30 changes: 17 additions & 13 deletions mlir/lib/TableGen/Interfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@

using namespace mlir;
using namespace mlir::tblgen;
using llvm::DagInit;
using llvm::DefInit;
using llvm::Init;
using llvm::ListInit;
using llvm::Record;
using llvm::StringInit;

//===----------------------------------------------------------------------===//
// InterfaceMethod
//===----------------------------------------------------------------------===//

InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
const llvm::DagInit *args = def->getValueAsDag("arguments");
InterfaceMethod::InterfaceMethod(const Record *def) : def(def) {
const DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
}
}

Expand Down Expand Up @@ -72,18 +77,17 @@ bool InterfaceMethod::arg_empty() const { return arguments.empty(); }
// Interface
//===----------------------------------------------------------------------===//

Interface::Interface(const llvm::Record *def) : def(def) {
Interface::Interface(const Record *def) : def(def) {
assert(def->isSubClassOf("Interface") &&
"must be subclass of TableGen 'Interface' class");

// Initialize the interface methods.
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (const llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods"));
for (const Init *init : listInit->getValues())
methods.emplace_back(cast<DefInit>(init)->getDef());

// Initialize the interface base classes.
auto *basesInit =
dyn_cast<llvm::ListInit>(def->getValueInit("baseInterfaces"));
auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces"));
// Chained inheritance will produce duplicates in the base interface set.
StringSet<> basesAdded;
llvm::unique_function<void(Interface)> addBaseInterfaceFn =
Expand All @@ -98,8 +102,8 @@ Interface::Interface(const llvm::Record *def) : def(def) {
baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
basesAdded.insert(baseInterface.getName());
};
for (const llvm::Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
for (const Init *init : basesInit->getValues())
addBaseInterfaceFn(Interface(cast<DefInit>(init)->getDef()));
}

// Return the name of this interface.
Expand Down
Loading