diff --git a/src/analyze/SymbolResolverPass.cpp b/src/analyze/SymbolResolverPass.cpp index 4ba2e1782..de65098df 100644 --- a/src/analyze/SymbolResolverPass.cpp +++ b/src/analyze/SymbolResolverPass.cpp @@ -26,6 +26,7 @@ #include "SymbolResolverPass.h" #include "../Logger.h" +#include "../Namespace.h" #include "../ast/RecursiveVisitor.h" #include "../casm-ir/src/Builtin.h" @@ -40,12 +41,14 @@ static libpass::PassRegistration< SymbolResolverPass > PASS( "resolves all AST identifiers and creates a symbol table", "ast-symtbl", 0 ); -class SymbolResolverVisitor final : public RecursiveVisitor +// +// SymbolTableVisitor +// + +class SymbolTableVisitor final : public RecursiveVisitor { public: - SymbolResolverVisitor( Logger& log ); - - u64 errors( void ) const; + SymbolTableVisitor( Logger& m_log ); void visit( Specification& node ) override; @@ -54,246 +57,196 @@ class SymbolResolverVisitor final : public RecursiveVisitor void visit( RuleDefinition& node ) override; void visit( EnumerationDefinition& node ) override; - void visit( DirectCallExpression& node ) override; - - void visit( UniversalQuantifierExpression& node ) override; - void visit( ExistentialQuantifierExpression& node ) override; + u64 errors( void ) const; - void visit( LetRule& node ) override; - void visit( ForallRule& node ) override; + Namespace::Ptr symboltable( void ) const; private: - void registerSymbol( const IdentifierNode& node, - const CallExpression::TargetType targetType, - const std::size_t arity = 0 ); - - void unregisterSymbol( - const IdentifierNode& node, const std::size_t arity = 0 ); - - std::string key( - const IdentifierNode& node, const std::size_t arity ) const; - - Logger& log; - - u64 err; - - std::unordered_map< std::string, CallExpression::TargetType > m_symbolTable; - - std::unordered_map< std::string, DirectCallExpression* > m_late_resolve; + Logger& m_log; + u64 m_err; + Namespace::Ptr m_symboltable; }; -SymbolResolverVisitor::SymbolResolverVisitor( Logger& log ) -: log( log ) -, err( 0 ) +SymbolTableVisitor::SymbolTableVisitor( Logger& log ) +: m_log( log ) +, m_err( 0 ) { } -void SymbolResolverVisitor::visit( Specification& node ) +void SymbolTableVisitor::visit( Specification& node ) { + m_symboltable = libstdhl::make< Namespace >(); RecursiveVisitor::visit( node ); - - log.debug( "symbol_table:" ); - for( const auto& v : m_symbolTable ) - { - const auto& identifier = v.first; - const auto targetType = v.second; - - log.debug( " '" + identifier + "' --> " - + CallExpression::targetTypeString( targetType ) ); - } - - for( const auto& v : m_late_resolve ) - { - const auto& identifier = v.first; - const auto node = v.second; - - auto result = m_symbolTable.find( identifier ); - if( result != m_symbolTable.end() ) - { - node->setTargetType( result->second ); - } - else - { - err++; - log.error( { node->sourceLocation() }, - "symbol '" + identifier + "' cannot be resolved" ); - } - } } -void SymbolResolverVisitor::visit( FunctionDefinition& node ) +void SymbolTableVisitor::visit( FunctionDefinition& node ) { - registerSymbol( *node.identifier(), CallExpression::TargetType::FUNCTION, - node.argumentTypes()->size() ); + m_err += m_symboltable->registerSymbol( m_log, node ); RecursiveVisitor::visit( node ); } -void SymbolResolverVisitor::visit( DerivedDefinition& node ) +void SymbolTableVisitor::visit( DerivedDefinition& node ) { - registerSymbol( *node.identifier(), CallExpression::TargetType::DERIVED, - node.arguments()->size() ); + m_err += m_symboltable->registerSymbol( m_log, node ); + RecursiveVisitor::visit( node ); +} - for( auto e : *node.arguments() ) - { - registerSymbol( - *e->identifier(), CallExpression::TargetType::VARIABLE ); - } +void SymbolTableVisitor::visit( RuleDefinition& node ) +{ + m_err += m_symboltable->registerSymbol( m_log, node ); + RecursiveVisitor::visit( node ); +} +void SymbolTableVisitor::visit( EnumerationDefinition& node ) +{ + m_err += m_symboltable->registerSymbol( m_log, node ); RecursiveVisitor::visit( node ); +} - for( auto e : *node.arguments() ) - { - unregisterSymbol( *e->identifier() ); - } +Namespace::Ptr SymbolTableVisitor::symboltable( void ) const +{ + return m_symboltable; } -void SymbolResolverVisitor::visit( RuleDefinition& node ) +u64 SymbolTableVisitor::errors( void ) const { - registerSymbol( *node.identifier(), CallExpression::TargetType::RULE, - node.arguments()->size() ); + return m_err; +} - for( auto e : *node.arguments() ) - { - registerSymbol( - *e->identifier(), CallExpression::TargetType::VARIABLE ); - } +// +// SymbolResolveVisitor +// - RecursiveVisitor::visit( node ); +class SymbolResolveVisitor final : public RecursiveVisitor +{ + public: + SymbolResolveVisitor( Logger& log, Namespace& symboltable ); - for( auto e : *node.arguments() ) - { - unregisterSymbol( *e->identifier() ); - } -} + void visit( DirectCallExpression& node ) override; -void SymbolResolverVisitor::visit( EnumerationDefinition& node ) -{ - registerSymbol( - *node.identifier(), CallExpression::TargetType::ENUMERATION ); + void visit( UniversalQuantifierExpression& node ) override; + void visit( ExistentialQuantifierExpression& node ) override; - for( auto e : *node.enumerators() ) - { - registerSymbol( *e, CallExpression::TargetType::CONSTANT ); - } + void visit( LetRule& node ) override; + void visit( ForallRule& node ) override; - RecursiveVisitor::visit( node ); + u64 errors( void ) const; + + private: + void push( const VariableDefinition& identifier ); + void pop( const VariableDefinition& identifier ); + + Logger& m_log; + u64 m_err; + Namespace& m_symboltable; + std::unordered_set< std::string > m_variables; +}; + +SymbolResolveVisitor::SymbolResolveVisitor( + Logger& log, Namespace& symboltable ) +: m_log( log ) +, m_err( 0 ) +, m_symboltable( symboltable ) +{ } -void SymbolResolverVisitor::visit( DirectCallExpression& node ) +void SymbolResolveVisitor::visit( DirectCallExpression& node ) { - const auto arity = node.arguments()->size(); - const auto _key = key( *node.identifier(), arity ); + const auto targetType = m_symboltable.find( node ); + + const auto name = node.identifier()->identifier(); - auto result = m_symbolTable.find( _key ); - if( result != m_symbolTable.end() ) + if( targetType != CallExpression::TargetType::UNKNOWN ) { - node.setTargetType( result->second ); + node.setTargetType( targetType ); } else { - const auto name = node.identifier()->identifier(); + const auto arity = node.arguments()->size(); if( libcasm_ir::Builtin::available( name, arity ) ) { - registerSymbol( *node.identifier(), + m_err += m_symboltable.registerSymbol( m_log, *node.identifier(), CallExpression::TargetType::BUILTIN, arity ); node.setTargetType( CallExpression::TargetType::BUILTIN ); } + else if( m_variables.find( name ) != m_variables.end() ) + { + node.setTargetType( CallExpression::TargetType::VARIABLE ); + } else { - log.debug( "memorize '" + _key + "'" ); - m_late_resolve.emplace( _key, &node ); + m_err++; + m_log.error( { node.sourceLocation() }, + "invalid symbol '" + name + "' found" ); } } - log.debug( "call: " + _key + "{ " + node.targetTypeName() + " }" ); + m_log.debug( "call: " + name + "{ " + node.targetTypeName() + " }" ); RecursiveVisitor::visit( node ); } -void SymbolResolverVisitor::visit( UniversalQuantifierExpression& node ) +void SymbolResolveVisitor::visit( UniversalQuantifierExpression& node ) { - const auto& id = *node.predicateVariable()->identifier(); - - registerSymbol( id, CallExpression::TargetType::VARIABLE ); + push( *node.predicateVariable() ); RecursiveVisitor::visit( node ); - unregisterSymbol( id ); + pop( *node.predicateVariable() ); } -void SymbolResolverVisitor::visit( ExistentialQuantifierExpression& node ) +void SymbolResolveVisitor::visit( ExistentialQuantifierExpression& node ) { - const auto& id = *node.predicateVariable()->identifier(); - - registerSymbol( id, CallExpression::TargetType::VARIABLE ); + push( *node.predicateVariable() ); RecursiveVisitor::visit( node ); - unregisterSymbol( id ); + pop( *node.predicateVariable() ); } -void SymbolResolverVisitor::visit( LetRule& node ) +void SymbolResolveVisitor::visit( LetRule& node ) { - const auto& id = *node.variable()->identifier(); - - registerSymbol( id, CallExpression::TargetType::VARIABLE ); + push( *node.variable() ); RecursiveVisitor::visit( node ); - unregisterSymbol( id ); + pop( *node.variable() ); } -void SymbolResolverVisitor::visit( ForallRule& node ) +void SymbolResolveVisitor::visit( ForallRule& node ) { - const auto& id = *node.variable()->identifier(); - - registerSymbol( id, CallExpression::TargetType::VARIABLE ); + push( *node.variable() ); RecursiveVisitor::visit( node ); - unregisterSymbol( id ); + pop( *node.variable() ); } -void SymbolResolverVisitor::registerSymbol( const IdentifierNode& node, - const CallExpression::TargetType targetType, const std::size_t arity ) +void SymbolResolveVisitor::push( const VariableDefinition& node ) { - const auto _key = key( node, arity ); - - auto result = m_symbolTable.emplace( _key, targetType ); + const auto& name = node.identifier()->identifier(); + auto result = m_variables.emplace( name ); if( not result.second ) { - err++; - log.error( { node.sourceLocation() }, - "symbol '" + result.first->first + "' already defined as '" - + CallExpression::targetTypeString( result.first->second ) - + "'" ); + m_err++; + m_log.error( { node.sourceLocation() }, + "symbol '" + name + "' already defined" ); } - - log.debug( "registered new symbol '" + result.first->first + "' as '" - + CallExpression::targetTypeString( result.first->second ) - + "'" ); } -void SymbolResolverVisitor::unregisterSymbol( - const IdentifierNode& node, const std::size_t arity ) +void SymbolResolveVisitor::pop( const VariableDefinition& node ) { - const auto _key = key( node, arity ); + const auto& name = node.identifier()->identifier(); - if( m_symbolTable.erase( _key ) != 1 ) + if( m_variables.erase( name ) != 1 ) { - throw std::domain_error( - "symbol '" + _key + "' was erased more than once" ); + assert( !" internal error! " ); } - - log.debug( "unregistered symbol '" + _key + "'" ); } -std::string SymbolResolverVisitor::key( - const IdentifierNode& node, const std::size_t arity ) const +u64 SymbolResolveVisitor::errors( void ) const { - const auto identifier = node.identifier(); - return std::to_string( arity ) + "@" + identifier; + return m_err; } -u64 SymbolResolverVisitor::errors( void ) const -{ - return err; -} +// +// SymbolResolverPass +// void SymbolResolverPass::usage( libpass::PassUsage& pu ) { @@ -307,15 +260,31 @@ u1 SymbolResolverPass::run( libpass::PassResult& pr ) const auto data = pr.result< AttributionPass >(); const auto specification = data->specification(); - SymbolResolverVisitor visitor( log ); - specification->accept( visitor ); + SymbolTableVisitor symTblVisitor( log ); + specification->accept( symTblVisitor ); - if( visitor.errors() ) + const auto symTblErr = symTblVisitor.errors(); + if( symTblErr ) { - log.debug( "%lu error(s)", visitor.errors() ); + log.debug( + "found %lu error(s) during symbol table creation", symTblErr ); return false; } + SymbolResolveVisitor symResVisitor( log, *symTblVisitor.symboltable() ); + specification->accept( symResVisitor ); + + const auto symResErr = symResVisitor.errors(); + if( symResErr ) + { + log.debug( "found %lu error(s) during symbol resolving", symResErr ); + return false; + } + +#ifndef NDEBUG + log.debug( "symbol table = \n" + symTblVisitor.symboltable()->dump() ); +#endif + pr.setResult< SymbolResolverPass >( libstdhl::make< Data >( specification ) );