Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Add FIRRTLTypeSwitch #5456

Merged
merged 10 commits into from
Jun 29, 2023
91 changes: 91 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "circt/Support/LLVM.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/TypeSwitch.h"

namespace circt {
namespace firrtl {
Expand Down Expand Up @@ -448,6 +449,96 @@ type_dyn_cast_or_null(Type type) { // NOLINT(readability-identifier-naming)
return {};
}

//===--------------------------------------------------------------------===//
// Type alias aware TypeSwitch.
//===--------------------------------------------------------------------===//

/// This class implements the same functionality as TypeSwitch except that
/// it uses firrtl::type_dyn_cast for dynamic cast. llvm::TypeSwitch is not
/// customizable so this class currently duplicates the code.
template <typename T, typename ResultT = void>
class FIRRTLTypeSwitch
: public llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, ResultT>, T> {
public:
using BaseT = llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, ResultT>, T>;
using BaseT::BaseT;
using BaseT::Case;
FIRRTLTypeSwitch(FIRRTLTypeSwitch &&other) = default;

/// Add a case on the given type.
template <typename CaseT, typename CallableT>
FIRRTLTypeSwitch<T, ResultT> &Case(CallableT &&caseFn) {
uenoku marked this conversation as resolved.
Show resolved Hide resolved
if (result)
return *this;

// Check to see if CaseT applies to 'value'. Use `type_dyn_cast` here.
if (auto caseValue = circt::firrtl::type_dyn_cast<CaseT>(this->value))
result.emplace(caseFn(caseValue));
return *this;
}

/// As a default, invoke the given callable within the root value.
template <typename CallableT>
[[nodiscard]] ResultT Default(CallableT &&defaultFn) {
uenoku marked this conversation as resolved.
Show resolved Hide resolved
if (result)
return std::move(*result);
return defaultFn(this->value);
}

/// As a default, return the given value.
[[nodiscard]] ResultT Default(ResultT defaultResult) {
uenoku marked this conversation as resolved.
Show resolved Hide resolved
if (result)
return std::move(*result);
return defaultResult;
}

[[nodiscard]] operator ResultT() {
assert(result && "Fell off the end of a type-switch");
return std::move(*result);
}

private:
/// The pointer to the result of this switch statement, once known,
/// null before that.
std::optional<ResultT> result;
};

/// Specialization of FIRRTLTypeSwitch for void returning callables.
template <typename T>
class FIRRTLTypeSwitch<T, void>
: public llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, void>, T> {
public:
using BaseT = llvm::detail::TypeSwitchBase<FIRRTLTypeSwitch<T, void>, T>;
using BaseT::BaseT;
using BaseT::Case;
FIRRTLTypeSwitch(FIRRTLTypeSwitch &&other) = default;

/// Add a case on the given type.
template <typename CaseT, typename CallableT>
FIRRTLTypeSwitch<T, void> &Case(CallableT &&caseFn) {
uenoku marked this conversation as resolved.
Show resolved Hide resolved
if (foundMatch)
return *this;

// Check to see if any of the types apply to 'value'.
if (auto caseValue = circt::firrtl::type_dyn_cast<CaseT>(this->value)) {
caseFn(caseValue);
foundMatch = true;
}
return *this;
}

/// As a default, invoke the given callable within the root value.
template <typename CallableT>
void Default(CallableT &&defaultFn) {
uenoku marked this conversation as resolved.
Show resolved Hide resolved
if (!foundMatch)
defaultFn(this->value);
}

private:
/// A flag detailing if we have already found a match.
bool foundMatch = false;
};

} // namespace firrtl
} // namespace circt

Expand Down