Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add FIRRTLTypeSwitch #5456

Merged
merged 10 commits into from
Jun 29, 2023
95 changes: 95 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "circt/Support/LLVM.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/TypeSwitch.h"

namespace circt {
namespace firrtl {
Expand Down Expand Up @@ -448,6 +449,100 @@ type_dyn_cast_or_null(Type type) { // NOLINT(readability-identifier-naming)
return {};
}

//===--------------------------------------------------------------------===//
// Type alias aware TypeSwitch.
//===--------------------------------------------------------------------===//

/// This class implements the same functionality as TypeSwitch except that
/// it uses firrtl::type_dyn_cast for dynamic cast. llvm::TypeSwitch is not
/// customizable so this class currently duplicates the code.
template <typename T, typename ResultT = void>
class FIRRTLTypeSwitch
: public llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, ResultT>, T> {
public:
using BaseT = llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, ResultT>, T>;
using BaseT::BaseT;
using BaseT::Case;
FIRRTLTypeSwitch(FIRRTLTypeSwitch &&other) = default;

/// Add a case on the given type.
template <typename CaseT, typename CallableT>
FIRRTLTypeSwitch<T, ResultT> &
Case(CallableT &&caseFn) { // NOLINT(readability-identifier-naming)
if (result)
return *this;

// Check to see if CaseT applies to 'value'. Use `type_dyn_cast` here.
if (auto caseValue = circt::firrtl::type_dyn_cast<CaseT>(this->value))
result.emplace(caseFn(caseValue));
return *this;
}

/// As a default, invoke the given callable within the root value.
template <typename CallableT>
[[nodiscard]] ResultT
Default(CallableT &&defaultFn) { // NOLINT(readability-identifier-naming)
if (result)
return std::move(*result);
return defaultFn(this->value);
}

/// As a default, return the given value.
[[nodiscard]] ResultT
Default(ResultT defaultResult) { // NOLINT(readability-identifier-naming)
if (result)
return std::move(*result);
return defaultResult;
}

[[nodiscard]] operator ResultT() {
assert(result && "Fell off the end of a type-switch");
return std::move(*result);
}

private:
/// The pointer to the result of this switch statement, once known,
/// null before that.
std::optional<ResultT> result;
};

/// Specialization of FIRRTLTypeSwitch for void returning callables.
template <typename T>
class FIRRTLTypeSwitch<T, void>
: public llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, void>, T> {
public:
using BaseT = llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, void>, T>;
using BaseT::BaseT;
using BaseT::Case;
FIRRTLTypeSwitch(FIRRTLTypeSwitch &&other) = default;

/// Add a case on the given type.
template <typename CaseT, typename CallableT>
FIRRTLTypeSwitch<T, void> &
Case(CallableT &&caseFn) { // NOLINT(readability-identifier-naming)
if (foundMatch)
return *this;

// Check to see if any of the types apply to 'value'.
if (auto caseValue = circt::firrtl::type_dyn_cast<CaseT>(this->value)) {
caseFn(caseValue);
foundMatch = true;
}
return *this;
}

/// As a default, invoke the given callable within the root value.
template <typename CallableT>
void Default(CallableT &&defaultFn) { // NOLINT(readability-identifier-naming)
if (!foundMatch)
defaultFn(this->value);
}

private:
/// A flag detailing if we have already found a match.
bool foundMatch = false;
};

} // namespace firrtl
} // namespace circt

Expand Down