Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIRRTL] Implement alias-aware type casts #5417

Merged
merged 5 commits into from
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class RefType;
class PropertyType;
class StringType;
class BigIntType;
class BaseTypeAliasType;

/// A collection of bits indicating the recursive properties of a type.
struct RecursiveTypeProperties {
Expand Down Expand Up @@ -367,4 +368,87 @@ struct DenseMapInfo<circt::firrtl::FIRRTLType> {

} // namespace llvm

namespace circt {
namespace firrtl {
//===--------------------------------------------------------------------===//
// Utility for type aliases
//===--------------------------------------------------------------------===//

/// A struct to check if there is a type derived from FIRRTLBaseType.
/// `ContainBaseSubTypes<BaseTy>::value` returns true if `BaseTy` is a derived
uenoku marked this conversation as resolved.
Show resolved Hide resolved
/// from `FIRRTLBaseType` and not `FIRRTLBaseType` itself.
template <typename head, typename... tail>
struct ContainBaseSubTypes {
static constexpr bool value =
ContainBaseSubTypes<head>::value || ContainBaseSubTypes<tail...>::value;
};

template <typename BaseTy>
struct ContainBaseSubTypes<BaseTy> {
static constexpr bool value =
std::is_base_of<FIRRTLBaseType, BaseTy>::value &&
!std::is_same_v<FIRRTLBaseType, BaseTy>;
};

template <typename... BaseTy>
bool type_isa(Type type) { // NOLINT(readability-identifier-naming)
// First check if the type is the requested type.
if (isa<BaseTy...>(type))
return true;

// If the requested type is a subtype of FIRRTLBaseType, then check if it is a
// type alias wrapping the requested type.
if constexpr (ContainBaseSubTypes<BaseTy...>::value) {
if (auto alias = dyn_cast<BaseTypeAliasType>(type))
return type_isa<BaseTy...>(alias.getInnerType());
}

return false;
}

// type_isa for a nullable argument.
template <typename... BaseTy>
bool type_isa_and_nonnull(Type type) { // NOLINT(readability-identifier-naming)
if (!type)
return false;
return type_isa<BaseTy...>(type);
}

template <typename BaseTy>
BaseTy type_cast(Type type) { // NOLINT(readability-identifier-naming)
assert(type_isa<BaseTy>(type) && "type must convert to requested type");

// If the type is the requested type, return it.
if (isa<BaseTy>(type))
return cast<BaseTy>(type);

// Otherwise, it must be a type alias wrapping the requested type.
if constexpr (ContainBaseSubTypes<BaseTy>::value) {
if (auto alias = dyn_cast<BaseTypeAliasType>(type))
return type_cast<BaseTy>(alias.getInnerType());
}

// Otherwise, it should fail. `cast` should cause a better assertion failure,
// so just use it.
return cast<BaseTy>(type);
}

template <typename BaseTy>
BaseTy type_dyn_cast(Type type) { // NOLINT(readability-identifier-naming)
if (type_isa<BaseTy>(type))
return type_cast<BaseTy>(type);
return {};
}

template <typename BaseTy>
BaseTy
type_dyn_cast_or_null(Type type) { // NOLINT(readability-identifier-naming)
if (type_isa_and_nonnull<BaseTy>(type))
return type_cast<BaseTy>(type);
return {};
}

} // namespace firrtl
} // namespace circt

#endif // CIRCT_DIALECT_FIRRTL_TYPES_H
23 changes: 23 additions & 0 deletions unittests/Dialect/FIRRTL/TypesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,27 @@ TEST(TypesTest, AnalogContainsAnalog) {
ASSERT_TRUE(AnalogType::get(&context).containsAnalog());
}

TEST(TypesTest, TypeAliasCast) {
MLIRContext context;
context.loadDialect<FIRRTLDialect>();
// Check containBaseSubTypes.
static_assert(!ContainBaseSubTypes<FIRRTLType>::value);
// Return false for FIRRTLBaseType.
static_assert(!ContainBaseSubTypes<FIRRTLBaseType>::value);
static_assert(!ContainBaseSubTypes<StringType>::value);
static_assert(ContainBaseSubTypes<FVectorType>::value);
static_assert(ContainBaseSubTypes<UIntType, StringType>::value);
AnalogType analog = AnalogType::get(&context);
BaseTypeAliasType alias1 =
BaseTypeAliasType::get(StringAttr::get(&context, "foo"), analog);
BaseTypeAliasType alias2 =
BaseTypeAliasType::get(StringAttr::get(&context, "bar"), alias1);
ASSERT_TRUE(!type_isa<FVectorType>(analog));
ASSERT_TRUE(type_isa<AnalogType>(analog));
ASSERT_TRUE(type_isa<AnalogType>(alias1));
ASSERT_TRUE(type_isa<AnalogType>(alias2));
ASSERT_TRUE(!type_isa<FVectorType>(alias2));
ASSERT_TRUE((type_isa<AnalogType, StringType>(alias2)));
}

} // namespace