90 changes: 90 additions & 0 deletions libc/test/src/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ add_fp_unittest(
TruncTest.h
DEPENDS
libc.src.math.trunc
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -155,6 +156,7 @@ add_fp_unittest(
TruncTest.h
DEPENDS
libc.src.math.truncf
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -169,6 +171,22 @@ add_fp_unittest(
TruncTest.h
DEPENDS
libc.src.math.truncl
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

# Registers truncf16_test in the libc-math-unittests suite. The test is flagged
# NEED_MPFR, is built from truncf16_test.cpp and uses the shared TruncTest.h
# header. libc.src.__support.CPP.algorithm is listed because TruncTest.h
# includes src/__support/CPP/algorithm.h.
add_fp_unittest(
truncf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
truncf16_test.cpp
HDRS
TruncTest.h
DEPENDS
libc.src.math.truncf16
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -183,6 +201,7 @@ add_fp_unittest(
CeilTest.h
DEPENDS
libc.src.math.ceil
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -197,6 +216,7 @@ add_fp_unittest(
CeilTest.h
DEPENDS
libc.src.math.ceilf
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -211,6 +231,22 @@ add_fp_unittest(
CeilTest.h
DEPENDS
libc.src.math.ceill
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

# Registers ceilf16_test in the libc-math-unittests suite. The test is flagged
# NEED_MPFR, is built from ceilf16_test.cpp and uses the shared CeilTest.h
# header. libc.src.__support.CPP.algorithm is listed because CeilTest.h
# includes src/__support/CPP/algorithm.h.
add_fp_unittest(
ceilf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
ceilf16_test.cpp
HDRS
CeilTest.h
DEPENDS
libc.src.math.ceilf16
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -225,6 +261,7 @@ add_fp_unittest(
FloorTest.h
DEPENDS
libc.src.math.floor
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -239,6 +276,7 @@ add_fp_unittest(
FloorTest.h
DEPENDS
libc.src.math.floorf
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -253,6 +291,22 @@ add_fp_unittest(
FloorTest.h
DEPENDS
libc.src.math.floorl
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

# Registers floorf16_test in the libc-math-unittests suite. The test is flagged
# NEED_MPFR, is built from floorf16_test.cpp and uses the shared FloorTest.h
# header. libc.src.__support.CPP.algorithm is listed because FloorTest.h
# includes src/__support/CPP/algorithm.h.
add_fp_unittest(
floorf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
floorf16_test.cpp
HDRS
FloorTest.h
DEPENDS
libc.src.math.floorf16
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -267,6 +321,7 @@ add_fp_unittest(
RoundTest.h
DEPENDS
libc.src.math.round
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -281,6 +336,7 @@ add_fp_unittest(
RoundTest.h
DEPENDS
libc.src.math.roundf
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -295,6 +351,22 @@ add_fp_unittest(
RoundTest.h
DEPENDS
libc.src.math.roundl
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

# Registers roundf16_test in the libc-math-unittests suite. The test is flagged
# NEED_MPFR, is built from roundf16_test.cpp and uses the shared RoundTest.h
# header. libc.src.__support.CPP.algorithm is listed because RoundTest.h
# includes src/__support/CPP/algorithm.h.
add_fp_unittest(
roundf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
roundf16_test.cpp
HDRS
RoundTest.h
DEPENDS
libc.src.math.roundf16
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -309,6 +381,7 @@ add_fp_unittest(
RoundEvenTest.h
DEPENDS
libc.src.math.roundeven
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -323,6 +396,7 @@ add_fp_unittest(
RoundEvenTest.h
DEPENDS
libc.src.math.roundevenf
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand All @@ -337,6 +411,22 @@ add_fp_unittest(
RoundEvenTest.h
DEPENDS
libc.src.math.roundevenl
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

# Registers roundevenf16_test in the libc-math-unittests suite. The test is
# flagged NEED_MPFR, is built from roundevenf16_test.cpp and uses the shared
# RoundEvenTest.h header. libc.src.__support.CPP.algorithm is listed because
# RoundEvenTest.h includes src/__support/CPP/algorithm.h.
add_fp_unittest(
roundevenf16_test
NEED_MPFR
SUITE
libc-math-unittests
SRCS
roundevenf16_test.cpp
HDRS
RoundEvenTest.h
DEPENDS
libc.src.math.roundevenf16
libc.src.__support.CPP.algorithm
libc.src.__support.FPUtil.fp_bits
)

Expand Down
27 changes: 18 additions & 9 deletions libc/test/src/math/CeilTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_TEST_SRC_MATH_CEILTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_CEILTEST_H

#include "src/__support/CPP/algorithm.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
Expand Down Expand Up @@ -59,18 +63,21 @@ class CeilTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
EXPECT_FP_EQ(T(-10.0), func(T(-10.32)));
EXPECT_FP_EQ(T(11.0), func(T(10.65)));
EXPECT_FP_EQ(T(-10.0), func(T(-10.65)));
EXPECT_FP_EQ(T(1235.0), func(T(1234.38)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.38)));
EXPECT_FP_EQ(T(1235.0), func(T(1234.96)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.96)));
EXPECT_FP_EQ(T(124.0), func(T(123.38)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.38)));
EXPECT_FP_EQ(T(124.0), func(T(123.96)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.96)));
}

void testRange(CeilFunc func) {
constexpr StorageType COUNT = 100'000;
constexpr StorageType STEP = STORAGE_MAX / COUNT;
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = FPBits(v).get_val();
if (isnan(x) || isinf(x))
constexpr int COUNT = 100'000;
constexpr StorageType STEP = LIBC_NAMESPACE::cpp::max(
static_cast<StorageType>(STORAGE_MAX / COUNT), StorageType(1));
StorageType v = 0;
for (int i = 0; i <= COUNT; ++i, v += STEP) {
FPBits xbits(v);
T x = xbits.get_val();
if (xbits.is_inf_or_nan())
continue;

ASSERT_MPFR_MATCH(mpfr::Operation::Ceil, x, func(x), 0.0);
Expand All @@ -84,3 +91,5 @@ class CeilTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
TEST_F(LlvmLibcCeilTest, RoundedNubmers) { testRoundedNumbers(&func); } \
TEST_F(LlvmLibcCeilTest, Fractions) { testFractions(&func); } \
TEST_F(LlvmLibcCeilTest, Range) { testRange(&func); }

#endif // LLVM_LIBC_TEST_SRC_MATH_CEILTEST_H
22 changes: 13 additions & 9 deletions libc/test/src/math/FloorTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_LIBC_TEST_SRC_MATH_FLOORTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_FLOORTEST_H

#include "src/__support/CPP/algorithm.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
Expand Down Expand Up @@ -62,18 +63,21 @@ class FloorTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
EXPECT_FP_EQ(T(-11.0), func(T(-10.32)));
EXPECT_FP_EQ(T(10.0), func(T(10.65)));
EXPECT_FP_EQ(T(-11.0), func(T(-10.65)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.38)));
EXPECT_FP_EQ(T(-1235.0), func(T(-1234.38)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.96)));
EXPECT_FP_EQ(T(-1235.0), func(T(-1234.96)));
EXPECT_FP_EQ(T(123.0), func(T(123.38)));
EXPECT_FP_EQ(T(-124.0), func(T(-123.38)));
EXPECT_FP_EQ(T(123.0), func(T(123.96)));
EXPECT_FP_EQ(T(-124.0), func(T(-123.96)));
}

void testRange(FloorFunc func) {
constexpr StorageType COUNT = 100'000;
constexpr StorageType STEP = STORAGE_MAX / COUNT;
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = FPBits(v).get_val();
if (isnan(x) || isinf(x))
constexpr int COUNT = 100'000;
constexpr StorageType STEP = LIBC_NAMESPACE::cpp::max(
static_cast<StorageType>(STORAGE_MAX / COUNT), StorageType(1));
StorageType v = 0;
for (int i = 0; i <= COUNT; ++i, v += STEP) {
FPBits xbits(v);
T x = xbits.get_val();
if (xbits.is_inf_or_nan())
continue;

ASSERT_MPFR_MATCH(mpfr::Operation::Floor, x, func(x), 0.0);
Expand Down
30 changes: 17 additions & 13 deletions libc/test/src/math/RoundEvenTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_LIBC_TEST_SRC_MATH_ROUNDEVENTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_ROUNDEVENTEST_H

#include "src/__support/CPP/algorithm.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
Expand Down Expand Up @@ -60,22 +61,25 @@ class RoundEvenTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
EXPECT_FP_EQ(T(-2.0), func(T(-1.75)));
EXPECT_FP_EQ(T(11.0), func(T(10.65)));
EXPECT_FP_EQ(T(-11.0), func(T(-10.65)));
EXPECT_FP_EQ(T(1233.0), func(T(1233.25)));
EXPECT_FP_EQ(T(1234.0), func(T(1233.50)));
EXPECT_FP_EQ(T(1234.0), func(T(1233.75)));
EXPECT_FP_EQ(T(-1233.0), func(T(-1233.25)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1233.50)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1233.75)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.50)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.50)));
EXPECT_FP_EQ(T(123.0), func(T(123.25)));
EXPECT_FP_EQ(T(124.0), func(T(123.50)));
EXPECT_FP_EQ(T(124.0), func(T(123.75)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.25)));
EXPECT_FP_EQ(T(-124.0), func(T(-123.50)));
EXPECT_FP_EQ(T(-124.0), func(T(-123.75)));
EXPECT_FP_EQ(T(124.0), func(T(124.50)));
EXPECT_FP_EQ(T(-124.0), func(T(-124.50)));
}

void testRange(RoundEvenFunc func) {
constexpr StorageType COUNT = 100'000;
constexpr StorageType STEP = STORAGE_MAX / COUNT;
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = FPBits(v).get_val();
if (isnan(x) || isinf(x))
constexpr int COUNT = 100'000;
constexpr StorageType STEP = LIBC_NAMESPACE::cpp::max(
static_cast<StorageType>(STORAGE_MAX / COUNT), StorageType(1));
StorageType v = 0;
for (int i = 0; i <= COUNT; ++i, v += STEP) {
FPBits xbits(v);
T x = xbits.get_val();
if (xbits.is_inf_or_nan())
continue;

ASSERT_MPFR_MATCH(mpfr::Operation::RoundEven, x, func(x), 0.0);
Expand Down
22 changes: 13 additions & 9 deletions libc/test/src/math/RoundTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_LIBC_TEST_SRC_MATH_ROUNDTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_ROUNDTEST_H

#include "src/__support/CPP/algorithm.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
Expand Down Expand Up @@ -62,18 +63,21 @@ class RoundTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
EXPECT_FP_EQ(T(-10.0), func(T(-10.32)));
EXPECT_FP_EQ(T(11.0), func(T(10.65)));
EXPECT_FP_EQ(T(-11.0), func(T(-10.65)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.38)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.38)));
EXPECT_FP_EQ(T(1235.0), func(T(1234.96)));
EXPECT_FP_EQ(T(-1235.0), func(T(-1234.96)));
EXPECT_FP_EQ(T(123.0), func(T(123.38)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.38)));
EXPECT_FP_EQ(T(124.0), func(T(123.96)));
EXPECT_FP_EQ(T(-124.0), func(T(-123.96)));
}

void testRange(RoundFunc func) {
constexpr StorageType COUNT = 100'000;
constexpr StorageType STEP = STORAGE_MAX / COUNT;
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = FPBits(v).get_val();
if (isnan(x) || isinf(x))
constexpr int COUNT = 100'000;
constexpr StorageType STEP = LIBC_NAMESPACE::cpp::max(
static_cast<StorageType>(STORAGE_MAX / COUNT), StorageType(1));
StorageType v = 0;
for (int i = 0; i <= COUNT; ++i, v += STEP) {
FPBits xbits(v);
T x = xbits.get_val();
if (xbits.is_inf_or_nan())
continue;

ASSERT_MPFR_MATCH(mpfr::Operation::Round, x, func(x), 0.0);
Expand Down
22 changes: 13 additions & 9 deletions libc/test/src/math/TruncTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef LLVM_LIBC_TEST_SRC_MATH_TRUNCTEST_H
#define LLVM_LIBC_TEST_SRC_MATH_TRUNCTEST_H

#include "src/__support/CPP/algorithm.h"
#include "test/UnitTest/FEnvSafeTest.h"
#include "test/UnitTest/FPMatcher.h"
#include "test/UnitTest/Test.h"
Expand Down Expand Up @@ -62,18 +63,21 @@ class TruncTest : public LIBC_NAMESPACE::testing::FEnvSafeTest {
EXPECT_FP_EQ(T(-10.0), func(T(-10.32)));
EXPECT_FP_EQ(T(10.0), func(T(10.65)));
EXPECT_FP_EQ(T(-10.0), func(T(-10.65)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.38)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.38)));
EXPECT_FP_EQ(T(1234.0), func(T(1234.96)));
EXPECT_FP_EQ(T(-1234.0), func(T(-1234.96)));
EXPECT_FP_EQ(T(123.0), func(T(123.38)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.38)));
EXPECT_FP_EQ(T(123.0), func(T(123.96)));
EXPECT_FP_EQ(T(-123.0), func(T(-123.96)));
}

void testRange(TruncFunc func) {
constexpr StorageType COUNT = 100'000;
constexpr StorageType STEP = STORAGE_MAX / COUNT;
for (StorageType i = 0, v = 0; i <= COUNT; ++i, v += STEP) {
T x = FPBits(v).get_val();
if (isnan(x) || isinf(x))
constexpr int COUNT = 100'000;
constexpr StorageType STEP = LIBC_NAMESPACE::cpp::max(
static_cast<StorageType>(STORAGE_MAX / COUNT), StorageType(1));
StorageType v = 0;
for (int i = 0; i <= COUNT; ++i, v += STEP) {
FPBits xbits(v);
T x = xbits.get_val();
if (xbits.is_inf_or_nan())
continue;

ASSERT_MPFR_MATCH(mpfr::Operation::Trunc, x, func(x), 0.0);
Expand Down
13 changes: 13 additions & 0 deletions libc/test/src/math/ceilf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for ceilf16 ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CeilTest.h"

#include "src/math/ceilf16.h"

// Instantiate the ceil test cases for float16, with ceilf16 as the function
// under test. NOTE(review): LIST_CEIL_TESTS is presumably the macro provided by
// CeilTest.h (included above) -- confirm.
LIST_CEIL_TESTS(float16, LIBC_NAMESPACE::ceilf16)
13 changes: 13 additions & 0 deletions libc/test/src/math/floorf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for floorf16 --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "FloorTest.h"

#include "src/math/floorf16.h"

// Instantiate the floor test cases for float16, with floorf16 as the function
// under test. NOTE(review): LIST_FLOOR_TESTS is presumably the macro provided
// by FloorTest.h (included above) -- confirm.
LIST_FLOOR_TESTS(float16, LIBC_NAMESPACE::floorf16)
13 changes: 13 additions & 0 deletions libc/test/src/math/roundevenf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for roundevenf16 ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RoundEvenTest.h"

#include "src/math/roundevenf16.h"

// Instantiate the roundeven test cases for float16, with roundevenf16 as the
// function under test. NOTE(review): LIST_ROUNDEVEN_TESTS is presumably the
// macro provided by RoundEvenTest.h (included above) -- confirm.
LIST_ROUNDEVEN_TESTS(float16, LIBC_NAMESPACE::roundevenf16)
13 changes: 13 additions & 0 deletions libc/test/src/math/roundf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for roundf16 --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RoundTest.h"

#include "src/math/roundf16.h"

// Instantiate the round test cases for float16, with roundf16 as the function
// under test. NOTE(review): LIST_ROUND_TESTS is presumably the macro provided
// by RoundTest.h (included above) -- confirm.
LIST_ROUND_TESTS(float16, LIBC_NAMESPACE::roundf16)
13 changes: 13 additions & 0 deletions libc/test/src/math/truncf16_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- Unittests for truncf16 --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TruncTest.h"

#include "src/math/truncf16.h"

// Instantiate the trunc test cases for float16, with truncf16 as the function
// under test. NOTE(review): LIST_TRUNC_TESTS is presumably the macro provided
// by TruncTest.h (included above) -- confirm.
LIST_TRUNC_TESTS(float16, LIBC_NAMESPACE::truncf16)
39 changes: 35 additions & 4 deletions libc/utils/MPFRWrapper/MPFRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "src/__support/CPP/string_view.h"
#include "src/__support/FPUtil/FPBits.h"
#include "src/__support/FPUtil/fpbits_str.h"
#include "src/__support/macros/properties/types.h"
#include "test/UnitTest/FPMatcher.h"

#include "hdr/math_macros.h"
Expand All @@ -30,6 +31,12 @@ namespace mpfr {
// precision compared to the floating point precision.
template <typename T> struct ExtraPrecision;

#ifdef LIBC_TYPES_HAS_FLOAT16
// float16 specialization, compiled only when LIBC_TYPES_HAS_FLOAT16 is
// defined. VALUE is the default `precision` argument of the MPFRNumber
// constructors taking this type; it is set to 128, the same value the float
// specialization uses.
template <> struct ExtraPrecision<float16> {
static constexpr unsigned int VALUE = 128;
};
#endif

template <> struct ExtraPrecision<float> {
static constexpr unsigned int VALUE = 128;
};
Expand Down Expand Up @@ -85,9 +92,16 @@ class MPFRNumber {

// We use explicit EnableIf specializations to disallow implicit
// conversions. Implicit conversions can potentially lead to loss of
// precision.
// precision. We exceptionally allow implicit conversions from float16
// to float, as the MPFR API does not support float16, thus requiring
// conversion to a higher-precision format.
template <typename XType,
cpp::enable_if_t<cpp::is_same_v<float, XType>, int> = 0>
cpp::enable_if_t<cpp::is_same_v<float, XType>
#ifdef LIBC_TYPES_HAS_FLOAT16
|| cpp::is_same_v<float16, XType>
#endif
,
int> = 0>
explicit MPFRNumber(XType x,
unsigned int precision = ExtraPrecision<XType>::VALUE,
RoundingMode rounding = RoundingMode::Nearest)
Expand Down Expand Up @@ -529,8 +543,8 @@ class MPFRNumber {
// If the control reaches here, it means that this number and input are
// of the same sign but different exponent. In such a case, ULP error is
// calculated as sum of two parts.
thisAsT = std::abs(thisAsT);
input = std::abs(input);
thisAsT = FPBits<T>(thisAsT).abs().get_val();
input = FPBits<T>(input).abs().get_val();
T min = thisAsT > input ? input : thisAsT;
T max = thisAsT > input ? thisAsT : input;
int minExponent = FPBits<T>(min).get_exponent();
Expand Down Expand Up @@ -585,6 +599,14 @@ template <> long double MPFRNumber::as<long double>() const {
return mpfr_get_ld(value, mpfr_rounding);
}

#ifdef LIBC_TYPES_HAS_FLOAT16
// Returns the stored MPFR value as a float16. The value is first read out as a
// double using this number's rounding mode and then narrowed with a cast.
template <> float16 MPFRNumber::as<float16>() const {
  // TODO: Either prove that this cast won't cause double-rounding errors, or
  // find a better way to get a float16.
  const double widened = mpfr_get_d(value, mpfr_rounding);
  return static_cast<float16>(widened);
}
#endif

namespace internal {

template <typename InputType>
Expand Down Expand Up @@ -763,6 +785,10 @@ template void explain_unary_operation_single_output_error<double>(
Operation op, double, double, double, RoundingMode);
template void explain_unary_operation_single_output_error<long double>(
Operation op, long double, long double, double, RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
// Explicit instantiation for float16, alongside the double and long double
// instantiations above; only emitted when float16 is available.
template void explain_unary_operation_single_output_error<float16>(
Operation op, float16, float16, double, RoundingMode);
#endif

template <typename T>
void explain_unary_operation_two_outputs_error(
Expand Down Expand Up @@ -942,6 +968,11 @@ template bool compare_unary_operation_single_output<double>(Operation, double,
RoundingMode);
template bool compare_unary_operation_single_output<long double>(
Operation, long double, long double, double, RoundingMode);
#ifdef LIBC_TYPES_HAS_FLOAT16
// Explicit instantiation for float16, alongside the double and long double
// instantiations above; only emitted when float16 is available.
template bool compare_unary_operation_single_output<float16>(Operation, float16,
float16, double,
RoundingMode);
#endif

template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/MC/MCContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ class MCContext {
std::function<void(SMDiagnostic &, const SourceMgr *)>);

MCSymbol *createSymbolImpl(const StringMapEntry<bool> *Name,
bool CanBeUnnamed);
bool IsTemporary);
MCSymbol *createSymbol(StringRef Name, bool AlwaysAddSuffix,
bool IsTemporary);

Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/MC/MCContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,9 @@ MCSymbol *MCContext::createSymbolImpl(const StringMapEntry<bool> *Name,
}

MCSymbol *MCContext::createSymbol(StringRef Name, bool AlwaysAddSuffix,
bool CanBeUnnamed) {
if (CanBeUnnamed && !UseNamesOnTempLabels)
return createSymbolImpl(nullptr, true);

bool IsTemporary) {
// Determine whether this is a user written assembler temporary or normal
// label, if used.
bool IsTemporary = CanBeUnnamed;
if (AllowTemporaryLabels && !IsTemporary)
IsTemporary = Name.starts_with(MAI->getPrivateGlobalPrefix());

Expand Down Expand Up @@ -298,6 +294,9 @@ MCSymbol *MCContext::createSymbol(StringRef Name, bool AlwaysAddSuffix,
}

MCSymbol *MCContext::createTempSymbol(const Twine &Name, bool AlwaysAddSuffix) {
if (!UseNamesOnTempLabels)
return createSymbolImpl(nullptr, /*IsTemporary=*/true);

SmallString<128> NameSV;
raw_svector_ostream(NameSV) << MAI->getPrivateGlobalPrefix() << Name;
return createSymbol(NameSV, AlwaysAddSuffix, true);
Expand Down
5 changes: 4 additions & 1 deletion llvm/lib/MC/MCObjectStreamer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,11 @@ void MCObjectStreamer::emitAbsoluteSymbolDiffAsULEB128(const MCSymbol *Hi,
}

void MCObjectStreamer::reset() {
if (Assembler)
if (Assembler) {
Assembler->reset();
if (getContext().getTargetOptions())
Assembler->setRelaxAll(getContext().getTargetOptions()->MCRelaxAll);
}
CurInsertionPoint = MCSection::iterator();
EmitEHFrame = true;
EmitDebugFrame = false;
Expand Down
78 changes: 1 addition & 77 deletions llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3287,76 +3287,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

// Build a constant `<Width x i1>` vector from `Mask`: element `i` is true
// exactly when bit `i` of `Mask` is set.
Constant *createDppMask(unsigned Width, unsigned Mask) {
  SmallVector<Constant *, 4> Bits;
  Bits.reserve(Width);
  // Consume the mask one bit per lane, lowest bit first.
  for (unsigned Lane = 0; Lane != Width; ++Lane, Mask >>= 1) {
    Constant *Bit = (Mask & 1) ? ConstantInt::getTrue(F.getContext())
                               : ConstantInt::getFalse(F.getContext());
    Bits.push_back(Bit);
  }
  return ConstantVector::get(Bits);
}

// Compute the output shadow as a vector of booleans `<n x i1>`. The dot
// product is treated as fully poisoned if any input lane selected by `SrcMask`
// is poisoned; in that case every output lane selected by `DstMask` is marked
// poisoned, otherwise the result is all-clean.
Value *makeDppShadowI1(IRBuilder<> &IRB, Value *S, unsigned SrcMask,
                       unsigned DstMask) {
  Type *ShadowTy = S->getType();
  const unsigned NumElts = cast<FixedVectorType>(ShadowTy)->getNumElements();

  // Zero out the shadow of lanes that do not take part in the dot product.
  Value *Selected = IRB.CreateSelect(createDppMask(NumElts, SrcMask), S,
                                     Constant::getNullValue(ShadowTy));
  // OR all participating lanes together; a zero result means nothing is
  // poisoned.
  Value *Reduced = IRB.CreateOrReduce(Selected);
  Value *IsClean = IRB.CreateIsNull(Reduced, "_msdpp");

  Value *DstLanes = createDppMask(NumElts, DstMask);
  Value *NoLanes = Constant::getNullValue(DstLanes->getType());
  return IRB.CreateSelect(IsClean, NoLanes, DstLanes);
}

// Instrument the `_dp_p*` dot-product intrinsics (see the `Intel Intrinsics
// Guide`).
//
// The 2 and 4 element versions produce a single scalar dot product and then
// write it into the elements of the output vector selected by the 4 lowest
// bits of the mask. The top 4 bits of the mask control which input elements
// are used for the dot product.
//
// The 8 element version's mask still has only 4 bits for the input and 4 bits
// for the output. According to the spec it operates as the 4 element version
// on the first 4 elements of the inputs and output, and then again on the last
// 4 elements of the inputs and output.
void handleDppIntrinsic(IntrinsicInst &I) {
  IRBuilder<> IRB(&I);

  // A poisoned lane in either operand poisons the corresponding product.
  Value *Shadow0 = getShadow(&I, 0);
  Value *Shadow1 = getShadow(&I, 1);
  Value *Combined = IRB.CreateOr(Shadow0, Shadow1);
  Type *ShadowTy = Combined->getType();

  const unsigned NumElts = cast<FixedVectorType>(ShadowTy)->getNumElements();
  assert(NumElts == 2 || NumElts == 4 || NumElts == 8);

  // Operand 2 is the immediate: high nibble selects inputs, low nibble selects
  // outputs.
  const unsigned Imm = cast<ConstantInt>(I.getArgOperand(2))->getZExtValue();
  const unsigned SrcMask = Imm >> 4;
  const unsigned DstMask = Imm & 0xf;

  // Calculate the shadow as `<n x i1>`.
  Value *ShadowI1 = makeDppShadowI1(IRB, Combined, SrcMask, DstMask);
  if (NumElts == 8) {
    // The first 4 elements of the shadow are already calculated.
    // `makeDppShadowI1` operates on 32 bit masks, so we can just shift the
    // masks and repeat for the upper 4 elements.
    Value *UpperI1 =
        makeDppShadowI1(IRB, Combined, SrcMask << 4, DstMask << 4);
    ShadowI1 = IRB.CreateOr(ShadowI1, UpperI1);
  }

  // Sign-extend to the real shadow type so that a poisoned element has all of
  // its bits poisoned.
  Value *Result = IRB.CreateSExt(ShadowI1, ShadowTy, "_msdpp");

  setShadow(&I, Result);
  setOriginForNaryOp(I);
}

// Instrument sum-of-absolute-differences intrinsic.
void handleVectorSadIntrinsic(IntrinsicInst &I) {
const unsigned SignificantBitsPerResultElement = 16;
Expand Down Expand Up @@ -3712,7 +3642,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
setOriginForNaryOp(I);
}

static SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) {
SmallVector<int, 8> getPclmulMask(unsigned Width, bool OddElements) {
SmallVector<int, 8> Mask;
for (unsigned X = OddElements ? 1 : 0; X < Width; X += 2) {
Mask.append(2, X);
Expand Down Expand Up @@ -4028,12 +3958,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
handleVectorPackIntrinsic(I);
break;

case Intrinsic::x86_avx_dp_ps_256:
case Intrinsic::x86_sse41_dppd:
case Intrinsic::x86_sse41_dpps:
handleDppIntrinsic(I);
break;

case Intrinsic::x86_mmx_packsswb:
case Intrinsic::x86_mmx_packuswb:
handleVectorPackIntrinsic(I, 16);
Expand Down
72 changes: 58 additions & 14 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup = nullptr);
Instruction *Fixup = nullptr,
int64_t ScalableOffset = 0);

static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
Expand Down Expand Up @@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup/*= nullptr*/) {
Instruction *Fixup /* = nullptr */,
int64_t ScalableOffset) {
switch (Kind) {
case LSRUse::Address:
return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
HasBaseReg, Scale, AccessTy.AddrSpace,
Fixup, ScalableOffset);

case LSRUse::ICmpZero:
// There's not even a target hook for querying whether it would be legal to
// fold a GV into an ICmp.
if (BaseGV)
if (BaseGV || ScalableOffset != 0)
return false;

// ICmp only has two operands; don't allow more than two non-trivial parts.
Expand Down Expand Up @@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,

case LSRUse::Basic:
// Only handle single-register values.
return !BaseGV && Scale == 0 && BaseOffset == 0;
return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;

case LSRUse::Special:
// Special case Basic to handle -1 scales.
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
ScalableOffset == 0;
}

llvm_unreachable("Invalid LSRUse Kind!");
Expand Down Expand Up @@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg) {
bool HasBaseReg, int64_t ScalableOffset = 0) {
// Fast-path: zero is always foldable.
if (BaseOffset == 0 && !BaseGV) return true;

Expand All @@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
}

return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
HasBaseReg, Scale);
HasBaseReg, Scale, nullptr, ScalableOffset);
}

static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
Expand Down Expand Up @@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
Value *Operand, const TargetTransformInfo &TTI) {
const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
return false;
int64_t IncOffset = 0;
int64_t ScalableOffset = 0;
if (IncConst) {
if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
return false;
IncOffset = IncConst->getValue()->getSExtValue();
} else {
// Look for mul(vscale, constant), to detect ScalableOffset.
auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
if (!IncVScale || IncVScale->getNumOperands() != 2 ||
!isa<SCEVVScale>(IncVScale->getOperand(1)))
return false;
auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
return false;
ScalableOffset = Scale->getValue()->getSExtValue();
}

if (IncConst->getAPInt().getSignificantBits() > 64)
if (!isAddressUse(TTI, UserInst, Operand))
return false;

MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
int64_t IncOffset = IncConst->getValue()->getSExtValue();
if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
IncOffset, /*HasBaseReg=*/false))
IncOffset, /*HasBaseReg=*/false, ScalableOffset))
return false;

return true;
Expand Down Expand Up @@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
Type *IVTy = IVSrc->getType();
Type *IntTy = SE.getEffectiveSCEVType(IVTy);
const SCEV *LeftOverExpr = nullptr;
const SCEV *Accum = SE.getZero(IntTy);
SmallVector<std::pair<const SCEV *, Value *>> Bases;
Bases.emplace_back(Accum, IVSrc);

for (const IVInc &Inc : Chain) {
Instruction *InsertPt = Inc.UserInst;
if (isa<PHINode>(InsertPt))
Expand All @@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// IncExpr was the result of subtraction of two narrow values, so must
// be signed.
const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
Accum = SE.getAddExpr(Accum, IncExpr);
LeftOverExpr = LeftOverExpr ?
SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
}
if (LeftOverExpr && !LeftOverExpr->isZero()) {

// Look through each base to see if any can produce a nice addressing mode.
bool FoundBase = false;
for (auto [MapScev, MapIVOper] : reverse(Bases)) {
const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
if (!Remainder->isZero()) {
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
const SCEV *IVOperExpr =
SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
} else {
IVOper = MapIVOper;
}

FoundBase = true;
break;
}
}
if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
// Expand the IV increment.
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
Expand All @@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// If an IV increment can't be folded, use it as the next IV value.
if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
Bases.emplace_back(Accum, IVOper);
IVSrc = IVOper;
LeftOverExpr = nullptr;
}
Expand Down
86 changes: 40 additions & 46 deletions llvm/test/CodeGen/AArch64/sve-lsrchain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: // %bb.2: // %for.body.us.preheader
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: add x11, x2, x11, lsl #1
; CHECK-NEXT: mov x12, #-16 // =0xfffffffffffffff0
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov w8, wzr
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov x9, xzr
; CHECK-NEXT: mov w10, wzr
; CHECK-NEXT: addvl x12, x12, #1
; CHECK-NEXT: mov x13, #4 // =0x4
; CHECK-NEXT: mov x14, #8 // =0x8
; CHECK-NEXT: mov x12, #4 // =0x4
; CHECK-NEXT: mov x13, #8 // =0x8
; CHECK-NEXT: .LBB0_3: // %for.body.us
; CHECK-NEXT: // =>This Loop Header: Depth=1
; CHECK-NEXT: // Child Loop BB0_4 Depth 2
; CHECK-NEXT: add x15, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x16, x8, #1, #32
; CHECK-NEXT: mov x17, x2
; CHECK-NEXT: ldp s0, s1, [x15]
; CHECK-NEXT: add x16, x16, #8
; CHECK-NEXT: ldp s2, s3, [x15, #8]
; CHECK-NEXT: ubfiz x15, x8, #1, #32
; CHECK-NEXT: add x14, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x15, x8, #1, #32
; CHECK-NEXT: mov x16, x2
; CHECK-NEXT: ldp s0, s1, [x14]
; CHECK-NEXT: add x15, x15, #8
; CHECK-NEXT: ldp s2, s3, [x14, #8]
; CHECK-NEXT: ubfiz x14, x8, #1, #32
; CHECK-NEXT: fcvt h0, s0
; CHECK-NEXT: fcvt h1, s1
; CHECK-NEXT: fcvt h2, s2
Expand All @@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: .LBB0_4: // %for.cond.i.preheader.us
; CHECK-NEXT: // Parent Loop BB0_3 Depth=1
; CHECK-NEXT: // => This Inner Loop Header: Depth=2
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x17, x15]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17]
; CHECK-NEXT: add x18, x17, x16
; CHECK-NEXT: add x3, x17, x15
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x16, x14]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16]
; CHECK-NEXT: add x17, x16, x15
; CHECK-NEXT: add x18, x16, x14
; CHECK-NEXT: add x3, x17, #8
; CHECK-NEXT: add x4, x17, #16
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x17, x16]
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x16, x15]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #1, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #2, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #3, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #3, mul vl]
; CHECK-NEXT: addvl x17, x17, #4
; CHECK-NEXT: cmp x17, x11
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #3, mul vl]
; CHECK-NEXT: addvl x16, x16, #4
; CHECK-NEXT: cmp x16, x11
; CHECK-NEXT: b.lo .LBB0_4
; CHECK-NEXT: // %bb.5: // %while.cond.i..exit_crit_edge.us
; CHECK-NEXT: // in Loop: Header=BB0_3 Depth=1
Expand Down
23 changes: 11 additions & 12 deletions llvm/test/Instrumentation/MemorySanitizer/X86/avx-intrinsics-x86.ll
Original file line number Diff line number Diff line change
Expand Up @@ -389,19 +389,18 @@ define <8 x float> @test_x86_avx_dp_ps_256(<8 x float> %a0, <8 x float> %a1) #0
; CHECK-NEXT: [[TMP1:%.*]] = load <8 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <8 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 32) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP3:%.*]] = or <8 x i32> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = select <8 x i1> <i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>, <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TMP4]])
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i32 [[TMP5]], 0
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <8 x i1> zeroinitializer, <8 x i1> <i1 false, i1 true, i1 true, i1 true, i1 false, i1 false, i1 false, i1 false>
; CHECK-NEXT: [[TMP7:%.*]] = select <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true>, <8 x i32> [[TMP3]], <8 x i32> zeroinitializer
; CHECK-NEXT: [[TMP8:%.*]] = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> [[TMP7]])
; CHECK-NEXT: [[_MSDPP1:%.*]] = icmp eq i32 [[TMP8]], 0
; CHECK-NEXT: [[TMP9:%.*]] = select i1 [[_MSDPP1]], <8 x i1> zeroinitializer, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 true, i1 true, i1 true>
; CHECK-NEXT: [[TMP10:%.*]] = or <8 x i1> [[TMP6]], [[TMP9]]
; CHECK-NEXT: [[_MSDPP2:%.*]] = sext <8 x i1> [[TMP10]] to <8 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i32> [[TMP1]] to i256
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i256 [[TMP3]], 0
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP2]] to i256
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i256 [[TMP4]], 0
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
; CHECK: 5:
; CHECK-NEXT: call void @__msan_warning_noreturn()
; CHECK-NEXT: unreachable
; CHECK: 6:
; CHECK-NEXT: [[RES:%.*]] = call <8 x float> @llvm.x86.avx.dp.ps.256(<8 x float> [[A0:%.*]], <8 x float> [[A1:%.*]], i8 -18)
; CHECK-NEXT: store <8 x i32> [[_MSDPP2]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <8 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <8 x float> [[RES]]
;
%res = call <8 x float> @llvm.x86.avx.dp.ps.256(<8 x float> %a0, <8 x float> %a1, i8 -18) ; <<8 x float>> [#uses=1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ define <2 x double> @test_x86_sse41_dppd(<2 x double> %a0, <2 x double> %a1) #0
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x i64>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 16) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP3:%.*]] = or <2 x i64> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = select <2 x i1> <i1 false, i1 true>, <2 x i64> [[TMP3]], <2 x i64> zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = call i64 @llvm.vector.reduce.or.v2i64(<2 x i64> [[TMP4]])
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i64 [[TMP5]], 0
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <2 x i1> zeroinitializer, <2 x i1> <i1 false, i1 true>
; CHECK-NEXT: [[_MSDPP1:%.*]] = sext <2 x i1> [[TMP6]] to <2 x i64>
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <2 x i64> [[TMP1]] to i128
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <2 x i64> [[TMP2]] to i128
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK: 5:
; CHECK-NEXT: call void @__msan_warning_noreturn()
; CHECK-NEXT: unreachable
; CHECK: 6:
; CHECK-NEXT: [[RES:%.*]] = call <2 x double> @llvm.x86.sse41.dppd(<2 x double> [[A0:%.*]], <2 x double> [[A1:%.*]], i8 -18)
; CHECK-NEXT: store <2 x i64> [[_MSDPP1]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <2 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <2 x double> [[RES]]
;
%res = call <2 x double> @llvm.x86.sse41.dppd(<2 x double> %a0, <2 x double> %a1, i8 -18) ; <<2 x double>> [#uses=1]
Expand All @@ -66,14 +70,18 @@ define <4 x float> @test_x86_sse41_dpps(<4 x float> %a0, <4 x float> %a1) #0 {
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr @__msan_param_tls, align 8
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 16) to ptr), align 8
; CHECK-NEXT: call void @llvm.donothing()
; CHECK-NEXT: [[TMP3:%.*]] = or <4 x i32> [[TMP1]], [[TMP2]]
; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> <i1 false, i1 true, i1 true, i1 true>, <4 x i32> [[TMP3]], <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> [[TMP4]])
; CHECK-NEXT: [[_MSDPP:%.*]] = icmp eq i32 [[TMP5]], 0
; CHECK-NEXT: [[TMP6:%.*]] = select i1 [[_MSDPP]], <4 x i1> zeroinitializer, <4 x i1> <i1 false, i1 true, i1 true, i1 true>
; CHECK-NEXT: [[_MSDPP1:%.*]] = sext <4 x i1> [[TMP6]] to <4 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i32> [[TMP1]] to i128
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i128 [[TMP3]], 0
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP2]] to i128
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
; CHECK: 5:
; CHECK-NEXT: call void @__msan_warning_noreturn()
; CHECK-NEXT: unreachable
; CHECK: 6:
; CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.x86.sse41.dpps(<4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], i8 -18)
; CHECK-NEXT: store <4 x i32> [[_MSDPP1]], ptr @__msan_retval_tls, align 8
; CHECK-NEXT: store <4 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
; CHECK-NEXT: ret <4 x float> [[RES]]
;
%res = call <4 x float> @llvm.x86.sse41.dpps(<4 x float> %a0, <4 x float> %a1, i8 -18) ; <<4 x float>> [#uses=1]
Expand All @@ -92,7 +100,7 @@ define <4 x float> @test_x86_sse41_insertps(<4 x float> %a0, <4 x float> %a1) #0
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i32> [[TMP2]] to i128
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i128 [[TMP4]], 0
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK-NEXT: br i1 [[_MSOR]], label [[TMP5:%.*]], label [[TMP6:%.*]], !prof [[PROF0]]
; CHECK: 5:
; CHECK-NEXT: call void @__msan_warning_noreturn()
; CHECK-NEXT: unreachable
Expand Down
30 changes: 17 additions & 13 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -372,19 +372,23 @@ class TypeConverter {
From the perspective of type conversion, the types of block arguments are a bit
special. Throughout the conversion process, blocks may move between regions of
different operations. Given this, the conversion of the types for blocks must be
done explicitly via a conversion pattern. To convert the types of block
arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
be invoked; `convertRegionTypes`. This hook uses a provided type converter to
apply type conversions to all blocks within a given region, and all blocks that
move into that region. As noted above, the conversions performed by this method
use the argument materialization hook on the `TypeConverter`. This hook also
takes an optional `TypeConverter::SignatureConversion` parameter that applies a
custom conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
AffineForOp, etc. To convert the signature of just the region entry block, and
not any other blocks within the region, the `applySignatureConversion` hook may
be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
can be built programmatically:
done explicitly via a conversion pattern.

To convert the types of block arguments within a Region, a custom hook on the
`ConversionPatternRewriter` must be invoked: `convertRegionTypes`. This hook
uses a provided type converter to apply type conversions to all blocks of a
given region. As noted above, the conversions performed by this method use the
argument materialization hook on the `TypeConverter`. This hook also takes an
optional `TypeConverter::SignatureConversion` parameter that applies a custom
conversion to the entry block of the region. The types of the entry block
arguments are often tied semantically to the operation, e.g.,
`func::FuncOp`, `AffineForOp`, etc.

To convert the signature of just one given block, the
`applySignatureConversion` hook can be used.

A signature conversion, `TypeConverter::SignatureConversion`, can be built
programmatically:

```c++
class SignatureConversion {
Expand Down
49 changes: 25 additions & 24 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ class TypeConverter {
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
template <typename TargetType> TargetType convertType(Type t) const {
template <typename TargetType>
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}

Expand Down Expand Up @@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;

/// Apply a signature conversion to the entry block of the given region. This
/// replaces the entry block with a new block containing the updated
/// signature. The new entry block to the region is returned for convenience.
/// If no block argument types are changing, the entry original block will be
/// Apply a signature conversion to the given block. This replaces the block
/// with a new block containing the updated signature. The operations of the
/// given block are inlined into the newly-created block, which is returned.
///
/// If no block argument types are changing, the original block will be
/// left in place and returned.
///
/// If provided, `converter` will be used for any materializations.
/// A signature conversion must be provided. (Type converters can construct
/// a signature conversion with `convertBlockSignature`.)
///
/// Optionally, a type converter can be provided to build materializations.
/// Note: If no type converter was provided or the type converter does not
/// specify any suitable argument/target materialization rules, the dialect
/// conversion may fail to legalize unresolved materializations.
Block *
applySignatureConversion(Region *region,
applySignatureConversion(Block *block,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter = nullptr);

/// Convert the types of block arguments within the given region. This
/// Apply a signature conversion to each block in the given region. This
/// replaces each block with a new block containing the updated signature. If
/// an updated signature would match the current signature, the respective
/// block is left in place as is.
/// block is left in place as is. (See `applySignatureConversion` for
/// details.) The new entry block of the region is returned.
///
/// SignatureConversions are computed with the specified type converter.
/// This function returns "failure" if the type converter failed to compute
/// a SignatureConversion for at least one block.
///
/// The entry block may have a special conversion if `entryConversion` is
/// provided. On success, the new entry block to the region is returned for
/// convenience. Otherwise, failure is returned.
/// Optionally, a special SignatureConversion can be specified for the entry
/// block. This is because the types of the entry block arguments are often
/// tied semantically to the operation.
FailureOr<Block *> convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);

/// Convert the types of block arguments within the given region except for
/// the entry region. This replaces each non-entry block with a new block
/// containing the updated signature. If an updated signature would match the
/// current signature, the respective block is left in place as is.
///
/// If special conversion behavior is needed for the non-entry blocks (for
/// example, we need to convert only a subset of a BB arguments), such
/// behavior can be specified in blockConversions.
LogicalResult convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions);

/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getRegion(),
body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
signatureConverter);

// Move the blocks from the forOp into the loopOp. This is the body of the
Expand Down
20 changes: 8 additions & 12 deletions mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
ConversionPatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
Region &region = op.getFunctionBody();
SmallVector<TypeConverter::SignatureConversion, 2> conversions;

for (Block &block : llvm::drop_begin(region, 1)) {
conversions.emplace_back(block.getNumArguments());
TypeConverter::SignatureConversion &back = conversions.back();
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
TypeConverter::SignatureConversion conversion(
/*numOrigInputs=*/block.getNumArguments());

for (BlockArgument blockArgument : block.getArguments()) {
int idx = blockArgument.getArgNumber();

if (blockArgsToDetensor.count(blockArgument))
back.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
conversion.addInputs(idx, {getTypeConverter()->convertType(
block.getArgumentTypes()[idx])});
else
back.addInputs(idx, {block.getArgumentTypes()[idx]});
conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
}
}

if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
conversions))) {
rewriter.cancelOpModification(op);
return failure();
rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}

rewriter.finalizeOpModification(op);
Expand Down
123 changes: 27 additions & 96 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//

/// Attempt to convert the signature of the given block, if successful a new
/// block is returned containing the new arguments. Returns `block` if it did
/// not require conversion.
FailureOr<Block *> convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion = nullptr);

/// Convert the types of non-entry block arguments within the given region.
LogicalResult convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});

/// Apply a signature conversion on the given region, using `converter` for
/// materializations if not null.
Block *
applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter);

/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
Expand Down Expand Up @@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
// Type Conversion

FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion *conversion) {
if (conversion)
return applySignatureConversion(rewriter, block, converter, *conversion);

// If a converter wasn't provided, and the block wasn't already converted,
// there is nothing we can do.
if (!converter)
return failure();

// Try to convert the signature for the block with the provided converter.
if (auto conversion = converter->convertBlockSignature(block))
return applySignatureConversion(rewriter, block, converter, *conversion);
return failure();
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Region *region,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
if (!region->empty())
return *convertBlockSignature(rewriter, &region->front(), converter,
&conversion);
return nullptr;
}

FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
Expand All @@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (region->empty())
return nullptr;

if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
return failure();

FailureOr<Block *> newEntry = convertBlockSignature(
rewriter, &region->front(), &converter, entryConversion);
return newEntry;
}

LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
regionToConverter[region] = &converter;
if (region->empty())
return success();

// Convert the arguments of each block within the region.
int blockIdx = 0;
assert((blockConversions.empty() ||
blockConversions.size() == region->getBlocks().size() - 1) &&
"expected either to provide no SignatureConversions at all or to "
"provide a SignatureConversion for each non-entry block");

// Convert the arguments of each non-entry block within the region.
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
TypeConverter::SignatureConversion *blockConversion =
blockConversions.empty()
? nullptr
: const_cast<TypeConverter::SignatureConversion *>(
&blockConversions[blockIdx++]);

if (failed(convertBlockSignature(rewriter, &block, &converter,
blockConversion)))
// Compute the signature for the block with the provided converter.
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&block);
if (!conversion)
return failure();
}
return success();
// Convert the block with the computed signature.
applySignatureConversion(rewriter, &block, &converter, *conversion);
}

// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
return applySignatureConversion(rewriter, &region->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(&region->front());
if (!conversion)
return failure();
return applySignatureConversion(rewriter, &region->front(), &converter,
*conversion);
}

Block *ConversionPatternRewriterImpl::applySignatureConversion(
Expand Down Expand Up @@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
}

Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion,
Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->applySignatureConversion(*this, region, conversion, converter);
return impl->applySignatureConversion(*this, block, converter, conversion);
}

FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
Expand All @@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}

LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
Region *region, const TypeConverter &converter,
ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
return impl->convertNonEntryRegionTypes(*this, region, converter,
blockConversions);
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
Expand Down Expand Up @@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
impl.applySignatureConversion(rewriter, block, converter, *conversion);
continue;
}

Expand Down
5 changes: 3 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
rewriter.modifyOpInPlace(
op, [&] { rewriter.applySignatureConversion(&region, result); });
rewriter.modifyOpInPlace(op, [&] {
rewriter.applySignatureConversion(&region.front(), result);
});
return success();
}

Expand Down