diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h index 5bbbdf23b257e..5657303b0a1f2 100644 --- a/llvm/include/llvm/ADT/TypeSwitch.h +++ b/llvm/include/llvm/ADT/TypeSwitch.h @@ -17,6 +17,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include namespace llvm { @@ -117,11 +118,16 @@ class TypeSwitch : public detail::TypeSwitchBase, T> { return defaultResult; } - [[nodiscard]] operator ResultT() { - assert(result && "Fell off the end of a type-switch"); - return std::move(*result); + /// Declare default as unreachable, making sure that all cases were handled. + [[nodiscard]] ResultT DefaultUnreachable( + const char *message = "Fell off the end of a type-switch") { + if (result) + return std::move(*result); + llvm_unreachable(message); } + [[nodiscard]] operator ResultT() { return DefaultUnreachable(); } + private: /// The pointer to the result of this switch statement, once known, /// null before that. @@ -158,6 +164,13 @@ class TypeSwitch defaultFn(this->value); } + /// Declare default as unreachable, making sure that all cases were handled. + void DefaultUnreachable( + const char *message = "Fell off the end of a type-switch") { + if (!foundMatch) + llvm_unreachable(message); + } + private: /// A flag detailing if we have already found a match. bool foundMatch = false; diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp index c54b7987edf7e..a7d934265c5f0 100644 --- a/llvm/unittests/ADT/TypeSwitchTest.cpp +++ b/llvm/unittests/ADT/TypeSwitchTest.cpp @@ -114,3 +114,31 @@ TEST(TypeSwitchTest, CasesOptional) { EXPECT_EQ(std::nullopt, translate(DerivedC())); EXPECT_EQ(-1, translate(DerivedD())); } + +TEST(TypeSwitchTest, DefaultUnreachableWithValue) { + auto translate = [](auto value) { + return TypeSwitch(&value) + .Case([](DerivedA *) { return 0; }) + .DefaultUnreachable("Unhandled type"); + }; + EXPECT_EQ(0, translate(DerivedA())); + +#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) + EXPECT_DEATH((void)translate(DerivedD()), "Unhandled type"); +#endif +} + +TEST(TypeSwitchTest, DefaultUnreachableWithVoid) { + auto translate = [](auto value) { + int result = -1; + TypeSwitch(&value) + .Case([&result](DerivedA *) { result = 0; }) + .DefaultUnreachable("Unhandled type"); + return result; + }; + EXPECT_EQ(0, translate(DerivedA())); + +#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG) + EXPECT_DEATH((void)translate(DerivedD()), "Unhandled type"); +#endif +}