diff --git a/src/analyze/TypeInferencePass.cpp b/src/analyze/TypeInferencePass.cpp index 27772cae3..0033f3c80 100644 --- a/src/analyze/TypeInferencePass.cpp +++ b/src/analyze/TypeInferencePass.cpp @@ -657,6 +657,25 @@ void TypeInferenceVisitor::visit( RuleDefinition& node ) void TypeInferenceVisitor::visit( ValueAtom& node ) { RecursiveVisitor::visit( node ); + + auto result = m_resultTypes.find( &node ); + assert( node.value() and node.type() ); + + std::vector< libcasm_ir::Type::ID > ty = { node.type()->id() }; + std::vector< libcasm_ir::Type::ID > tmp = {}; + + if( result != m_resultTypes.end() ) + { + std::set_intersection( result->second.begin(), result->second.end(), + ty.begin(), ty.end(), std::back_inserter( tmp ) ); + } + else + { + std::set_intersection( tmp.begin(), tmp.end(), ty.begin(), ty.end(), + std::back_inserter( tmp ) ); + } + + m_resultTypes[&node ] = std::move( tmp ); } void TypeInferenceVisitor::visit( ReferenceAtom& node ) @@ -710,7 +729,8 @@ void TypeInferenceVisitor::visit( ReferenceAtom& node ) { // TODO - // node.setReferenceType( ReferenceAtom::ReferenceType::BUILTIN + // node.setReferenceType( + // ReferenceAtom::ReferenceType::BUILTIN // ); // node.setBuiltinId( annotation.id() ); break; @@ -804,44 +824,31 @@ void TypeInferenceVisitor::visit( DirectCallExpression& node ) } case CallExpression::TargetType::BUILTIN: { - assert( not node.type() ); + const auto description = "built-in '" + path.path() + "'"; + inference( + description, annotation, node, node.arguments()->data() ); - try + std::vector< libcasm_ir::Type::Ptr > argTypeList; + for( auto argumentType : *node.arguments() ) { - const auto description = "built-in '" + path.path() + "'"; - inference( - description, annotation, node, node.arguments()->data() ); - - if( node.type() ) + if( not argumentType->type() ) { - std::vector< libcasm_ir::Type::Ptr > argTypeList; - for( auto argumentType : *node.arguments() ) - { - if( not argumentType->type() ) - { - m_log.debug( { argumentType->sourceLocation() }, - "TODO: '" + path.path() - + "' has a non-typed argument(s)" ); - return; - } - - argTypeList.emplace_back( - argumentType->type()->ptr_result() ); - } - - const auto type - = libstdhl::make< libcasm_ir::RelationType >( - node.type(), argTypeList ); - - node.setType( type ); + m_log.debug( { argumentType->sourceLocation() }, + "TODO: '" + path.path() + + "' has a non-typed argument(s)" ); + return; } + + argTypeList.emplace_back( argumentType->type()->ptr_result() ); } - catch( const std::domain_error& e ) + + if( node.type() ) { - m_log.error( { node.sourceLocation() }, e.what() ); - return; - } + const auto type = libstdhl::make< libcasm_ir::RelationType >( + node.type(), argTypeList ); + node.setType( type ); + } break; } case CallExpression::TargetType::DERIVED: @@ -1125,6 +1132,37 @@ const libcasm_ir::Annotation* TypeInferenceVisitor::annotate( annotation = &builtin_annotation; directCall.setTargetBuiltinId( builtin_annotation.id() ); + + if( builtin_annotation.id() + == libcasm_ir::Value::AS_BIT_BUILTIN ) + { + const auto& asbit_args = directCall.arguments()->data(); + assert( asbit_args.size() == 2 ); + const auto& asbit_size + = static_cast< const ValueAtom& >( + *asbit_args[ 1 ] ); + if( asbit_size.id() == Node::ID::VALUE_ATOM + and asbit_size.type() + and asbit_size.type()->id() + == libcasm_ir::Type::INTEGER ) + { + const auto asbit_size_value + = std::static_pointer_cast< libcasm_ir:: + IntegerConstant >( asbit_size.value() ); + + const auto type + = libstdhl::get< libcasm_ir::BitType >( + asbit_size_value ); + directCall.setType( type ); + } + else + { + m_log.error( { directCall.arguments()->data()[ 1 ]->sourceLocation() }, + "2nd argument of built-in '" + + path.path() + + "' is required to be a compile time 'Integer' constant value" ); + } + } } catch( const std::domain_error& e ) { @@ -1299,22 +1337,27 @@ void TypeInferenceVisitor::inference( const std::string& description, { const auto relation = annotation->resultTypeForRelation( argTypes ); - const std::vector< libcasm_ir::Type::ID > inf - = { relation->result }; - std::vector< libcasm_ir::Type::ID > tmp = {}; + if( relation ) + { + const std::vector< libcasm_ir::Type::ID > inf + = { relation->result }; + std::vector< libcasm_ir::Type::ID > tmp = {}; - std::set_intersection( result->second.begin(), result->second.end(), - inf.begin(), inf.end(), std::back_inserter( tmp ) ); + std::set_intersection( result->second.begin(), + result->second.end(), inf.begin(), inf.end(), + std::back_inserter( tmp ) ); - m_resultTypes[&node ] = std::move( tmp ); + m_resultTypes[&node ] = std::move( tmp ); - if( pos != -1 and arguments[ pos ]->id() == Node::ID::UNDEF_ATOM ) - { - auto undefAtom - = static_cast< UndefAtom* >( arguments[ pos ].get() ); + if( pos != -1 + and arguments[ pos ]->id() == Node::ID::UNDEF_ATOM ) + { + auto undefAtom + = static_cast< UndefAtom* >( arguments[ pos ].get() ); - m_resultTypes[ undefAtom ] = { relation->argument[ pos ] }; - inference( "undef", 0, *undefAtom ); + m_resultTypes[ undefAtom ] = { relation->argument[ pos ] }; + inference( "undef", 0, *undefAtom ); + } } } catch( const std::invalid_argument& e ) @@ -1343,39 +1386,6 @@ void TypeInferenceVisitor::inference( const std::string& description, return; } - if( arguments.size() > 0 ) - { - u1 invalid = false; - - for( auto argument : arguments ) - { - const auto argTypes = m_resultTypes[&( *argument ) ]; - if( argTypes.size() != 1 ) - { - invalid = true; - - u1 first = true; - std::string tmp = " from multiple possible types: "; - for( auto t : resTypes ) - { - tmp += ( first ? "" : ", " ); - tmp += "'" + libcasm_ir::Type::token( t ) + "'"; - first = false; - } - - m_log.error( { node.sourceLocation() }, - "unable to infer result type of " + description - + ( resTypes.size() > 0 ? tmp : "" ) ); - } - } - - if( invalid ) - { - m_resultTypes[&node ].clear(); - return; - } - } - switch( *resTypes.begin() ) { case libcasm_ir::Type::VOID: @@ -1390,14 +1400,37 @@ void TypeInferenceVisitor::inference( const std::string& description, } case libcasm_ir::Type::INTEGER: { - // TODO: PPA: check for ranged integers - node.setType( INTEGER ); + node.setType( INTEGER ); // TODO: PPA: check for ranged integers break; } case libcasm_ir::Type::BIT: { - // node.setType( Bit(n) ); // depends on other bit sizes - assert( 0 ); // TODO: PPA: + if( node.type() ) + { + return; + } + + assert( arguments.size() > 0 ); + assert( annotation ); + + std::vector< libcasm_ir::Type::Ptr > argTypes = {}; + + for( auto argument : arguments ) + { + argTypes.emplace_back( argument->type() ); + } + + try + { + const auto type = annotation->inference( argTypes ); + node.setType( type ); + } + catch( const std::domain_error& e ) + { + m_log.error( { node.sourceLocation() }, + "unable to infer result type of " + description + ": " + + e.what() ); + } break; } case libcasm_ir::Type::STRING: @@ -1434,12 +1467,6 @@ void TypeInferenceVisitor::inference( const std::string& description, assert( 0 ); } } - - if( not node.type() ) - { - m_log.error( - { node.sourceLocation() }, "unable to resolve type of expression" ); - } } void TypeInferenceVisitor::inference( FunctionDefinition& node )