diff --git a/src/analyze/TypeInferencePass.cpp b/src/analyze/TypeInferencePass.cpp index 01ed4f827..b83d06366 100644 --- a/src/analyze/TypeInferencePass.cpp +++ b/src/analyze/TypeInferencePass.cpp @@ -26,7 +26,7 @@ #include "TypeInferencePass.h" #include "../Logger.h" -#include "../Namespace.h" +#include "../analyze/SymbolResolverPass.h" #include "../ast/RecursiveVisitor.h" #include "../casm-ir/src/Builtin.h" @@ -47,7 +47,7 @@ static libpass::PassRegistration< TypeInferencePass > PASS( class TypeCheckVisitor final : public RecursiveVisitor { public: - TypeCheckVisitor( Logger& log ); + TypeCheckVisitor( Logger& log, Namespace& symboltable ); void visit( BasicType& node ) override; void visit( ComposedType& node ) override; @@ -58,6 +58,7 @@ class TypeCheckVisitor final : public RecursiveVisitor private: Logger& m_log; u64 m_err; + Namespace& m_symboltable; }; static const std::unordered_map< std::string, libcasm_ir::Type::Ptr > basicTypes @@ -68,13 +69,12 @@ static const std::unordered_map< std::string, libcasm_ir::Type::Ptr > basicTypes { "Bit", libstdhl::get< libcasm_ir::BitType >( 1 ) }, { "String", libstdhl::get< libcasm_ir::StringType >() }, { "Floating", libstdhl::get< libcasm_ir::FloatingType >() }, - // enumeration - // agent }; -TypeCheckVisitor::TypeCheckVisitor( Logger& log ) +TypeCheckVisitor::TypeCheckVisitor( Logger& log, Namespace& symboltable ) : m_log( log ) , m_err( 0 ) +, m_symboltable( symboltable ) { } @@ -91,16 +91,44 @@ void TypeCheckVisitor::visit( BasicType& node ) } else if( name.compare( "Agent" ) == 0 ) { - // TODO: handle the agent case + m_log.debug( "TODO: FIXME: handle 'Agent' case!" ); // TODO: PPA: } else { - // m_err++; // TODO: enable this line! - m_log.error( { node.sourceLocation() }, - "unknown type '" + name + "' found" ); + try + { + auto symbol = m_symboltable.find( node ); + + assert( symbol.targetType() + == CallExpression::TargetType::ENUMERATION ); + + auto& definition = static_cast< EnumerationDefinition& >( + symbol.definition() ); + + if( not definition.type() ) + { + m_log.debug( "enum IR type not set yet" ); + + auto kind + = libstdhl::make< libcasm_ir::Enumeration >( name ); + for( auto e : *definition.enumerators() ) + { + kind->add( e->identifier() ); + } - // TODO: it could still be a enumeration type etc., check this in - // the symbol table! + auto type + = libstdhl::make< libcasm_ir::EnumerationType >( kind ); + definition.setType( type ); + } + + node.setType( definition.type() ); + } + catch( const std::domain_error& e ) + { + // m_err++; // TODO: enable this line! + m_log.error( { node.sourceLocation() }, + "unknown type '" + name + "' found" ); + } } } @@ -118,14 +146,22 @@ void TypeCheckVisitor::visit( FixedSizedType& node ) if( not node.type() ) { const auto& name = node.name()->identifier(); - auto const& expr = *node.size(); + const auto& expr = *node.size(); if( name.compare( "Bit" ) == 0 ) { if( expr.id() == Node::ID::VALUE_ATOM and expr.type()->isInteger() ) { - // TODO: handle bit vector case, check if integer > 0 and set - // the IR::Bit(n) type + const auto& atom = static_cast< const ValueAtom& >( expr ); + + const auto value + = std::static_pointer_cast< libcasm_ir::IntegerConstant >( + atom.value() ); + + auto type + = libstdhl::get< libcasm_ir::BitType >( value->value() ); + + node.setType( type ); } else { @@ -225,8 +261,9 @@ u1 TypeInferencePass::run( libpass::PassResult& pr ) const auto data = pr.result< SymbolResolverPass >(); const auto specification = data->specification(); + const auto symboltable = data->symboltable(); - TypeCheckVisitor typChkVisitor( log ); + TypeCheckVisitor typChkVisitor( log, *symboltable ); specification->accept( typChkVisitor ); const auto typChkErr = typChkVisitor.errors(); diff --git a/src/analyze/TypeInferencePass.h b/src/analyze/TypeInferencePass.h index f384c34e5..102054695 100644 --- a/src/analyze/TypeInferencePass.h +++ b/src/analyze/TypeInferencePass.h @@ -26,10 +26,7 @@ #ifndef _LIB_CASMFE_TYPE_INFERENCE_PASS_H_ #define _LIB_CASMFE_TYPE_INFERENCE_PASS_H_ -#include "../analyze/SymbolResolverPass.h" - -#include "../ast/RecursiveVisitor.h" -#include "../ast/Specification.h" +#include "../transform/SourceToAstPass.h" namespace libcasm_fe { @@ -45,7 +42,7 @@ namespace libcasm_fe bool run( libpass::PassResult& pr ) override; - using Data = SymbolResolverPass::Data; + using Data = SourceToAstPass::Data; }; }