From 092064b8014e7bd05792f782de709c52f7f7255d Mon Sep 17 00:00:00 2001 From: Eduardo Bart Date: Sat, 15 Aug 2020 16:31:24 -0300 Subject: [PATCH] Refactoring on types file Add lots of comments in the compiler sources --- examples/record_inheretance.nelua | 2 +- lib/allocators/gc.nelua | 2 +- nelua/analyzer.lua | 106 ++--- nelua/analyzercontext.lua | 13 + nelua/attr.lua | 49 ++- nelua/ccontext.lua | 6 +- nelua/cgenerator.lua | 18 +- nelua/scope.lua | 44 +- nelua/syntaxdefs.lua | 4 +- nelua/typedefs.lua | 6 +- nelua/types.lua | 696 +++++++++++++++++++----------- nelua/utils/shaper.lua | 39 ++ nelua/utils/tabler.lua | 6 +- nelua/utils/traits.lua | 8 + rockspecs/nelua-dev-1.rockspec | 1 + spec/03-typechecker_spec.lua | 12 +- spec/05-cgenerator_spec.lua | 5 +- spec/tools/assert.lua | 1 + 18 files changed, 655 insertions(+), 363 deletions(-) create mode 100644 nelua/utils/shaper.lua diff --git a/examples/record_inheretance.nelua b/examples/record_inheretance.nelua index 0a233937..7f65e535 100644 --- a/examples/record_inheretance.nelua +++ b/examples/record_inheretance.nelua @@ -89,7 +89,7 @@ local function class(recordsym, basesym) end rectype.base = basetype else -- base record - assert(rectype:get_field('__kind'), 'missing __kind field') + assert(rectype.fields.__kind, 'missing __kind field') rectype.classes = {} rectype.methods = {} kindid = 0 diff --git a/lib/allocators/gc.nelua b/lib/allocators/gc.nelua index dd308bbf..b7eff3b3 100644 --- a/lib/allocators/gc.nelua +++ b/lib/allocators/gc.nelua @@ -515,7 +515,7 @@ function GC:_mark_statics() for i=1,#scope.symbols do local sym = scope.symbols[i] local symtype = sym.type or primtypes.any - if sym:is_static_vardecl() and + if sym:is_on_static_storage() and symtype:has_pointer() and not sym.cimport and sym ~= gc then diff --git a/nelua/analyzer.lua b/nelua/analyzer.lua index fafbfbdb..f4a74093 100644 --- a/nelua/analyzer.lua +++ b/nelua/analyzer.lua @@ -130,8 +130,8 @@ local function visitor_convert(context, parent, 
parentindex, vartype, valnode, v end local objsym local mtname - local varobjtype = vartype:auto_deref_type() - local valobjtype = valtype:auto_deref_type() + local varobjtype = vartype:implict_deref_type() + local valobjtype = valtype:implict_deref_type() local objtype if valobjtype.is_record then if vartype.is_cstring then @@ -250,7 +250,7 @@ local function visitor_Record_literal(context, node, littype) if not traits.is_string(fieldname) then childnode:raisef("only string literals are allowed in record's field names") end - field = littype:get_field(fieldname) + field = littype.fields[fieldname] fieldindex = field and field.index or nil parent = childnode parentindex = 2 @@ -314,7 +314,7 @@ function visitors.Table(context, node) local desiredtype = node.desiredtype node.attr.literal = true if desiredtype then - local objtype = desiredtype:auto_deref_type() + local objtype = desiredtype:implict_deref_type() if objtype.is_record and objtype.choose_braces_type then local err desiredtype, err = objtype.choose_braces_type(node[1]) @@ -342,18 +342,6 @@ function visitors.PragmaCall(_, node) node.done = true end -local function choose_type_symbol_names(context, symbol) - local type = symbol.value - if type:suggest_nickname(symbol.name) then - if symbol.staticstorage and symbol.codename then - type:set_codename(symbol.codename) - else - local codename = context:choose_codename(symbol.name) - type:set_codename(codename) - end - end -end - function visitors.Annotation(context, node, symbol) assert(symbol) local name = node[1] @@ -440,7 +428,7 @@ function visitors.Annotation(context, node, symbol) type:set_codename(codename) elseif name == 'packed' or name == 'aligned' then if objattr._type then - objattr:_update_sizealign() + objattr:update_fields() end end @@ -556,20 +544,11 @@ function visitors.TypeInstance(context, node, symbol) if symbol then local type = attr.value symbol.value = type - choose_type_symbol_names(context, symbol) - type.symbol = symbol + 
context:choose_type_symbol_names(symbol) end node.done = true end -local function retnodes_to_rettypes(retnodes) - local rettypes = {} - for i=1,#retnodes do - rettypes[i] = retnodes[i].attr.value - end - return rettypes -end - function visitors.FuncType(context, node) local attr = node.attr local argnodes, retnodes = node[1], node[2] @@ -585,7 +564,7 @@ function visitors.FuncType(context, node) argattrs[i] = Attr{type = argnode.attr.value} end end - local rettypes = retnodes_to_rettypes(retnodes) + local rettypes = types.typenodes_to_types(retnodes) local type = types.FunctionType(argattrs, rettypes, node) type.sideeffect = true attr.type = primtypes.type @@ -607,6 +586,7 @@ end function visitors.RecordType(context, node, symbol) local attr = node.attr local recordtype = types.RecordType({}, node) + recordtype.node = node attr.type = primtypes.type attr.value = recordtype if symbol then @@ -614,7 +594,7 @@ function visitors.RecordType(context, node, symbol) assert((not symbol.type or symbol.type == primtypes.type) and not symbol.value) symbol.type = primtypes.type symbol.value = recordtype - choose_type_symbol_names(context, symbol) + context:choose_type_symbol_names(symbol) recordtype.symbol = symbol end local fieldnodes = node[1] @@ -641,8 +621,6 @@ function visitors.EnumFieldType(context, node) numnode:raisef("in enum field '%s': %s", name, err) end field.value = numnode.attr.value - field.comptime = true - field.type = desiredtype end node.done = field return field @@ -666,8 +644,6 @@ function visitors.EnumType(context, node) fnode:raisef("first enum field requires an initial value", field.name) else field.value = fields[i-1].value + 1 - field.comptime = true - field.type = subtype end end if not subtype:is_inrange(field.value) then @@ -713,9 +689,10 @@ function visitors.PointerType(context, node) if subtypenode then context:traverse_node(subtypenode) local subtype = subtypenode.attr.value - attr.value = types.get_pointer_type(subtype) - if not attr.value then 
- node:raisef("subtype '%s' is invalid for 'pointer' type", subtype) + if not subtype.is_unpointable then + attr.value = types.PointerType(subtype) + else + node:raisef("subtype '%s' is not addressable thus cannot have a pointer", subtype) end else attr.value = primtypes.pointer @@ -1114,7 +1091,7 @@ end local function visitor_Record_FieldIndex(_, node, objtype, name) local attr = node.attr - local field = objtype:get_field(name) + local field = objtype.fields[name] local type = field and field.type if not type then node:raisef("cannot index field '%s' on record '%s'", name, objtype) @@ -1126,7 +1103,7 @@ end local function visitor_EnumType_FieldIndex(_, node, objtype, name) local attr = node.attr - local field = objtype:get_field(name) + local field = objtype.fields[name] if not field then node:raisef("cannot index field '%s' on enum '%s'", name, objtype) end @@ -1156,7 +1133,7 @@ local function visitor_RecordType_FieldIndex(context, node, objtype, name) -- declaration of record global function symbol.metafunc = true if node.tag == 'ColonIndex' then - symbol.metafuncselftype = types.get_pointer_type(objtype) + symbol.metafuncselftype = types.PointerType(objtype) end elseif inglobaldecl then -- declaration of record global variable @@ -1182,7 +1159,7 @@ local function visitor_RecordType_FieldIndex(context, node, objtype, name) end local function visitor_Type_FieldIndex(context, node, objtype, name) - objtype = objtype:auto_deref_type() + objtype = objtype:implict_deref_type() node.indextype = objtype if objtype.is_enum then return visitor_EnumType_FieldIndex(context, node, objtype, name) @@ -1202,7 +1179,7 @@ local function visitor_FieldIndex(context, node) local ret if objtype then local attr = node.attr - objtype = objtype:auto_deref_type() + objtype = objtype:implict_deref_type() if objtype.is_record then ret = visitor_Record_FieldIndex(context, node, objtype, name) elseif objtype.is_type then @@ -1293,7 +1270,7 @@ function visitors.ArrayIndex(context, node) if 
node.checked then return end local objtype = objnode.attr.type if objtype then - objtype = objtype:auto_deref_type() + objtype = objtype:implict_deref_type() if objtype.is_array then visitor_Array_ArrayIndex(context, node, objtype, objnode, indexnode) elseif objtype.is_record then @@ -1692,8 +1669,7 @@ function visitors.VarDecl(context, node) assert(valnode and valnode.attr.value) assignvaltype = vartype ~= valtype symbol.value = valnode.attr.value - choose_type_symbol_names(context, symbol) - symbol.value.symbol = symbol + context:choose_type_symbol_names(symbol) end if vartype and vartype.is_auto then @@ -1786,8 +1762,8 @@ function visitors.Return(context, node) local retnodes = node[1] context:traverse_nodes(retnodes) local funcscope = context.scope:get_parent_of_kind('function') or context.rootscope - if funcscope.returntypes then - for i,funcrettype,retnode,rettype in izipargnodes(funcscope.returntypes, retnodes) do + if funcscope.rettypes then + for i,funcrettype,retnode,rettype in izipargnodes(funcscope.rettypes, retnodes) do if rettype then if funcrettype then if rettype.is_niltype and not funcrettype.is_nilable then @@ -1885,14 +1861,14 @@ local function block_endswith_return(blocknode) return false end -local function check_function_returns(node, returntypes, blocknode) +local function check_function_returns(node, rettypes, blocknode) local attr = node.attr local functype = attr.type if not functype or functype.is_lazyfunction or attr.nodecl or attr.cimport or attr.hookmain then return end - if #returntypes > 0 then - local canbeempty = tabler.iall(returntypes, 'is_nilable') + if #rettypes > 0 then + local canbeempty = tabler.iall(rettypes, 'is_nilable') if not canbeempty and not block_endswith_return(blocknode) then node:raisef("a return statement is missing before function end") end @@ -1924,21 +1900,21 @@ local function visitor_FuncDef_variable(context, varscope, varnode) end local function visitor_FuncDef_returns(context, functype, retnodes) - local 
returntypes + local rettypes context:traverse_nodes(retnodes) if #retnodes > 0 then -- returns types are pre declared - returntypes = retnodes_to_rettypes(retnodes) + rettypes = types.typenodes_to_types(retnodes) - if #returntypes == 1 and returntypes[1].is_void then + if #rettypes == 1 and rettypes[1].is_void then -- single void type means no returns - returntypes = {} + rettypes = {} end - elseif functype and functype.is_procedure and not functype.returntypes.has_unknown then + elseif functype and functype.is_procedure and not functype.rettypes.has_unknown then -- use return types from previous traversal only if fully resolved - returntypes = functype.returntypes + rettypes = functype.rettypes end - return returntypes + return rettypes end function visitors.FuncDef(context, node, lazysymbol) @@ -1954,7 +1930,7 @@ function visitors.FuncDef(context, node, lazysymbol) end context:pop_state() - local returntypes = visitor_FuncDef_returns(context, node.attr.type, retnodes) + local rettypes = visitor_FuncDef_returns(context, node.attr.type, retnodes) -- repeat scope to resolve function variables and return types local islazyparent, argtypes, argattrs @@ -1963,7 +1939,7 @@ function visitors.FuncDef(context, node, lazysymbol) repeat funcscope = context:push_forked_cleaned_scope('function', node) - funcscope.returntypes = returntypes + funcscope.rettypes = rettypes context:traverse_nodes(argnodes) for i=1,#argnodes do local argnode = argnodes[i] @@ -1982,8 +1958,8 @@ function visitors.FuncDef(context, node, lazysymbol) context:pop_scope() until resolutions_count == 0 - if not islazyparent and not returntypes then - returntypes = funcscope.resolved_returntypes + if not islazyparent and not rettypes then + rettypes = funcscope.resolved_rettypes end -- set the function type @@ -1991,10 +1967,10 @@ function visitors.FuncDef(context, node, lazysymbol) if islazyparent then assert(not lazysymbol) if not type then - type = types.LazyFunctionType(argattrs, returntypes, node) + 
type = types.LazyFunctionType(argattrs, rettypes, node) end - elseif not returntypes.has_unknown then - type = types.FunctionType(argattrs, returntypes, node) + elseif not rettypes.has_unknown then + type = types.FunctionType(argattrs, rettypes, node) end if symbol then -- symbol may be nil in case of array/dot index @@ -2025,7 +2001,7 @@ function visitors.FuncDef(context, node, lazysymbol) end -- type checking for returns - check_function_returns(node, returntypes, blocknode) + check_function_returns(node, rettypes, blocknode) do -- handle attributes and annotations local attr = node.attr @@ -2118,7 +2094,7 @@ local function override_unary_op(context, node, opname, objnode, objtype) if not overridable_operators[opname] then return end if opname == 'len' then -- allow calling len on pointers for arrays/records - objtype = objtype:auto_deref_type() + objtype = objtype:implict_deref_type() end if not objtype.is_record then return end local mtname = '__' .. opname diff --git a/nelua/analyzercontext.lua b/nelua/analyzercontext.lua index afcf9b58..707fe17a 100644 --- a/nelua/analyzercontext.lua +++ b/nelua/analyzercontext.lua @@ -105,6 +105,19 @@ function AnalyzerContext:choose_codename(name) return name end +function AnalyzerContext:choose_type_symbol_names(symbol) + local type = symbol.value + if type:suggest_nickname(symbol.name) then + if symbol.staticstorage and symbol.codename then + type:set_codename(symbol.codename) + else + local codename = self:choose_codename(symbol.name) + type:set_codename(codename) + end + type.symbol = symbol + end +end + function AnalyzerContext:traceback() local nodes = self.visiting_nodes local ss = sstream() diff --git a/nelua/attr.lua b/nelua/attr.lua index e2fd02d3..8cadbae1 100644 --- a/nelua/attr.lua +++ b/nelua/attr.lua @@ -1,49 +1,76 @@ +-- Attr +-- +-- The Attr class and in short 'attr' is used by the compiler to store many +-- attributes associated to a symbol or a AST node during compilation. 
+-- Usually the AST nodes are linked to an attr, multiple nodes can be linked +-- to the same attr, this happens for example with variable identifiers. +-- The compiler can promote an attr to a symbol in case it has a named +-- identifier or in case it needs to perform type resolution. + local class = require 'nelua.utils.class' local tabler = require 'nelua.utils.tabler' local Attr = class() +-- Used to check if this table is an attr. Attr._attr = true +-- Initialize an attr from a table of fields. function Attr:_init(attr) if attr then tabler.update(self, attr) end end +-- Clone the attr, shallow copying all fields. function Attr:clone() + -- getmetatable should be used here because this attr could be a promoted symbol + -- so we should copy its underlying metatable return setmetatable(tabler.copy(self), getmetatable(self)) end +-- Merge fields from another attr into this attr. +-- Mostly used when linking new nodes to the same attr. function Attr:merge(attr) for k,v in pairs(attr) do - if self[k] == nil then + if self[k] == nil then -- no collision self[k] = v - elseif k ~= 'attr' then - assert(self[k] == v, 'cannot combine different attributes') + else + -- when the field is already set + -- the merge is not permitted to overwrite it with a new value, otherwise + -- it would cause bugs in what has already been processed by the compiler + assert(self[k] == v, 'cannot combine different attrs') end end return self end -function Attr:is_static_vardecl() - if self.vardecl and self.staticstorage and not self.comptime then - if not self.type or self.type.size > 0 then - return true - end +-- Check if this attr is stored in the program static storage. +-- Used for example by the GC to know if a variable should be scanned. +function Attr:is_on_static_storage() + if self.vardecl and + self.staticstorage and + not self.comptime and + (not self.type or self.type.size > 0) + then + return true end + return false end +-- Check if this attr could be holding a negative arithmetic value. 
+-- Used by the C generator to optimize operations on non negatives values. function Attr:is_maybe_negative() local type = self.type - if type and type.is_arithmetic then - if type.is_unsigned then + if type and type.is_arithmetic then -- must be an arithmetic to proper check this + if type.is_unsigned then -- unsigned is never negative return false end - if self.comptime and self.value >= 0 then + if self.comptime and self.value >= 0 then -- comptime positive is never negative return false end end + -- could be negative if the type is unknown yet, or some union, any.. return true end diff --git a/nelua/ccontext.lua b/nelua/ccontext.lua index 4314c576..52731697 100644 --- a/nelua/ccontext.lua +++ b/nelua/ccontext.lua @@ -3,6 +3,7 @@ local class = require 'nelua.utils.class' local cdefs = require 'nelua.cdefs' local traits = require 'nelua.utils.traits' local cbuiltins = require 'nelua.cbuiltins' +local config = require 'nelua.configer'.get() local CContext = class(AnalyzerContext) @@ -71,6 +72,9 @@ function CContext:typename(type) until visitor if visitor then + if config.check_ast_shape then + assert(type:shape()) + end visitor(self, type) end return type.codename @@ -92,7 +96,7 @@ function CContext:runctype(type) end function CContext:funcretctype(functype) - if functype:has_enclosed_return() then + if functype:has_multiple_returns() then return functype.codename .. 
'_ret' else return self:ctype(functype:get_return_type(1)) diff --git a/nelua/cgenerator.lua b/nelua/cgenerator.lua index f701b2c3..eb99cea6 100644 --- a/nelua/cgenerator.lua +++ b/nelua/cgenerator.lua @@ -60,10 +60,10 @@ local function destroy_variable(context, emitter, vartype, varname) end local function destroy_callee_returns(context, emitter, retvalname, calleetype, ignoreindexes) - for i,returntype in ipairs(calleetype.returntypes) do + for i,returntype in ipairs(calleetype.rettypes) do if returntype:has_destroyable() and not ignoreindexes[i] then local retargname - if calleetype:has_enclosed_return() then + if calleetype:has_multiple_returns() then retargname = string.format('%s.r%d', retvalname, i) else retargname = retvalname @@ -571,7 +571,7 @@ local function visitor_Call(context, node, emitter, argnodes, callee, calleeobjn local handlereturns local retvalname local returnfirst - local enclosed = calleetype:has_enclosed_return() + local enclosed = calleetype:has_multiple_returns() local destroyable = calleetype:has_destroyable_return() if not attr.multirets and (enclosed or destroyable) then -- we are handling the returns @@ -739,8 +739,8 @@ function visitors.DotIndex(context, node, emitter) if objtype.is_type then objtype = node.indextype if objtype.is_enum then - local field = objtype:get_field(name) - emitter:add_numeric_literal(field) + local field = objtype.fields[name] + emitter:add_numeric_literal(field, objtype.subtype) elseif objtype.is_record then if attr.comptime then emitter:add_literal(attr) @@ -766,7 +766,7 @@ function visitors.ArrayIndex(context, node, emitter) local indexnode, objnode = node:args() local objtype = objnode.attr.type local pointer = false - if objtype.is_pointer and not objtype.is_genericpointer then + if objtype.is_pointer and not objtype.is_generic_pointer then -- indexing a pointer to an array objtype = objtype.subtype pointer = true @@ -843,7 +843,7 @@ function visitors.Return(context, node, emitter) else local functype 
= funcscope.functype local numfuncrets = functype:get_return_count() - if not functype:has_enclosed_return() then + if not functype:has_multiple_returns() then if numfuncrets == 0 then -- no returns assert(numretnodes == 0) @@ -871,7 +871,7 @@ function visitors.Return(context, node, emitter) retemitter:add('return (', funcretctype, '){') local ignoredestroyindexes = {} local usedlastcalletype - for i,funcrettype,retnode,rettype,lastcallindex,lastcalletype in izipargnodes(functype.returntypes, retnodes) do + for i,funcrettype,retnode,rettype,lastcallindex,lastcalletype in izipargnodes(functype.rettypes, retnodes) do if i>1 then retemitter:add(', ') end if lastcallindex == 1 then usedlastcalletype = lastcalletype @@ -1150,7 +1150,7 @@ function visitors.FuncDef(context, node, emitter) local decemitter, defemitter, implemitter = CEmitter(context), CEmitter(context), CEmitter(context) local retctype = context:funcretctype(type) - if type:has_enclosed_return() then + if type:has_multiple_returns() then node:assertraisef(declare, 'functions with multiple returns must be declared') local retemitter = CEmitter(context) diff --git a/nelua/scope.lua b/nelua/scope.lua index afbc9853..4e5f9789 100644 --- a/nelua/scope.lua +++ b/nelua/scope.lua @@ -50,8 +50,8 @@ function Scope:clear_symbols() else self.symbols = setmetatable({}, default_symbols_mt) end - self.possible_returntypes = {} - self.resolved_returntypes = {} + self.possible_rettypes = {} + self.resolved_rettypes = {} self.has_unknown_return = nil end @@ -104,8 +104,8 @@ end function Scope:make_checkpoint() local checkpoint = { symbols = tabler.copy(self.symbols), - possible_returntypes = tabler.copy(self.possible_returntypes), - resolved_returntypes = tabler.copy(self.resolved_returntypes), + possible_rettypes = tabler.copy(self.possible_rettypes), + resolved_rettypes = tabler.copy(self.resolved_rettypes), has_unknown_return = self.has_unknown_return } if self.parent and self.parent.kind ~= 'root' then @@ -116,11 
+116,11 @@ end function Scope:set_checkpoint(checkpoint) tabler.clear(self.symbols) - tabler.clear(self.possible_returntypes) - tabler.clear(self.resolved_returntypes) + tabler.clear(self.possible_rettypes) + tabler.clear(self.resolved_rettypes) tabler.update(self.symbols, checkpoint.symbols) - tabler.update(self.possible_returntypes, checkpoint.possible_returntypes) - tabler.update(self.resolved_returntypes, checkpoint.resolved_returntypes) + tabler.update(self.possible_rettypes, checkpoint.possible_rettypes) + tabler.update(self.resolved_rettypes, checkpoint.resolved_rettypes) self.has_unknown_return = checkpoint.has_unknown_return if checkpoint.parentcheck then self.parent:set_checkpoint(checkpoint.parentcheck) @@ -129,8 +129,8 @@ end function Scope:merge_checkpoint(checkpoint) tabler.update(self.symbols, checkpoint.symbols) - tabler.update(self.possible_returntypes, checkpoint.possible_returntypes) - tabler.update(self.resolved_returntypes, checkpoint.resolved_returntypes) + tabler.update(self.possible_rettypes, checkpoint.possible_rettypes) + tabler.update(self.resolved_rettypes, checkpoint.resolved_rettypes) self.has_unknown_return = checkpoint.has_unknown_return if checkpoint.parentcheck then self.parent:merge_checkpoint(checkpoint.parentcheck) @@ -239,21 +239,21 @@ function Scope:add_return_type(index, type) if not type then self.has_unknown_return = true end - local returntypes = self.possible_returntypes[index] - if not returntypes then - self.possible_returntypes[index] = {[1] = type} - elseif type and not tabler.ifind(returntypes, type) then - returntypes[#returntypes+1] = type + local rettypes = self.possible_rettypes[index] + if not rettypes then + self.possible_rettypes[index] = {[1] = type} + elseif type and not tabler.ifind(rettypes, type) then + rettypes[#rettypes+1] = type end end -function Scope:resolve_returntypes() +function Scope:resolve_rettypes() local count = 0 - if not next(self.possible_returntypes) then return count end - local 
resolved_returntypes = self.resolved_returntypes - resolved_returntypes.has_unknown = self.has_unknown_return - for i,returntypes in pairs(self.possible_returntypes) do - resolved_returntypes[i] = types.find_common_type(returntypes) or typedefs.primtypes.any + if not next(self.possible_rettypes) then return count end + local resolved_rettypes = self.resolved_rettypes + resolved_rettypes.has_unknown = self.has_unknown_return + for i,rettypes in pairs(self.possible_rettypes) do + resolved_rettypes[i] = types.find_common_type(rettypes) or typedefs.primtypes.any count = count + 1 end return count @@ -261,7 +261,7 @@ end function Scope:resolve() local count = self:resolve_symbols() - self:resolve_returntypes() + self:resolve_rettypes() if count > 0 and config.debug_scope_resolve then console.info(self.node:format_message('info', "scope resolved %d symbols", count)) end diff --git a/nelua/syntaxdefs.lua b/nelua/syntaxdefs.lua index 96502e8c..3eea6a02 100644 --- a/nelua/syntaxdefs.lua +++ b/nelua/syntaxdefs.lua @@ -20,7 +20,7 @@ local function get_parser() -- shebang, e.g. "#!/usr/bin/nelua" parser:set_peg("SHEBANG", "'#!' 
(!%LINEBREAK .)*") - -- multiline and single line comments + -- multi-line and single line comments parser:set_pegs([[ %LONGCOMMENT <- (open (contents close / %{UnclosedLongComment})) -> 0 contents <- (!close .)* @@ -202,7 +202,7 @@ local function get_parser() %TENUM <- 'enum' ]]) - --- capture varargs values + -- capture varargs values parser:set_token_pegs([[ %cVARARGS <- ({} %ELLIPSIS -> 'Varargs') -> to_astnode ]]) diff --git a/nelua/typedefs.lua b/nelua/typedefs.lua index 76f18ebc..9203b9ca 100644 --- a/nelua/typedefs.lua +++ b/nelua/typedefs.lua @@ -31,8 +31,8 @@ primtypes.uint8 = types.IntegralType('uint8', 1, true) primtypes.uint16 = types.IntegralType('uint16', 2, true) primtypes.uint32 = types.IntegralType('uint32', 4, true) primtypes.uint64 = types.IntegralType('uint64', 8, true) -primtypes.float32 = types.FloatType('float32', 4, 9, 7) -primtypes.float64 = types.FloatType('float64', 8, 17, 15) +primtypes.float32 = types.FloatType('float32', 4, 9) +primtypes.float64 = types.FloatType('float64', 8, 17) primtypes.boolean = types.BooleanType('boolean', 1) primtypes.varanys = types.VaranysType('varanys') primtypes.table = types.TableType('table') @@ -65,7 +65,7 @@ primtypes.cdouble = primtypes.float64 primtypes.cfloat = primtypes.float32 -- complex types -primtypes.stringview = types.StringViewType('stringview', cpusize*2) +primtypes.stringview = types.StringViewType('stringview') -- signed types typedefs.integral_signed_types = { diff --git a/nelua/types.lua b/nelua/types.lua index c9a18f21..af5a1e87 100644 --- a/nelua/types.lua +++ b/nelua/types.lua @@ -1,3 +1,10 @@ +-- Types module +-- +-- The types module define classes for all the primitive types in Nelua. +-- Also defines some utilities functions for working with types. +-- +-- This module is always available in the preprocessor in the `types` variable. 
+ local class = require 'nelua.utils.class' local tabler = require 'nelua.utils.tabler' local iters = require 'nelua.utils.iterators' @@ -8,6 +15,7 @@ local metamagic = require 'nelua.utils.metamagic' local config = require 'nelua.configer'.get() local bn = require 'nelua.utils.bn' local except = require 'nelua.utils.except' +local shaper = require 'nelua.utils.shaper' local typedefs, primtypes local types = {} @@ -17,6 +25,73 @@ local cpusize = config.cpu_bits // 8 local Type = class() types.Type = Type +-- Define the shape of all fields used in the type. +-- Use this as a reference to know all used fields in the Type class by the compiler. +Type.shape = shaper.shape { + -- Unique identifier for the type, used when needed for runtime type information. + id = shaper.integer, + -- Size of the type at runtime in bytes. + size = shaper.integer, + -- Size of the type at runtime in bits. + bitsize = shaper.integer, + -- Alignment for the type in bytes. + align = shaper.integer, + -- Short name of the type, e.g. 'int64', 'record', 'enum' ... + name = shaper.string, + -- First identifier name defined in the sources for the type, not applicable to primitive types. + -- It is used to generate a pretty name on code generation and to show name on type errors. + nickname = shaper.string:is_optional(), + -- The actual name of the type used in the code generator when emitting C code. + codename = shaper.string, + -- Symbol that defined the type, not applicable for primitive types. + symbol = shaper.symbol:is_optional(), + -- Node that defined the type. + node = shaper.astnode:is_optional(), + -- Compile time unary operators defined for the type. + unary_operators = shaper.table, + -- Compile time binary operators defined for the type. + binary_operators = shaper.table, + -- A generic type that the type can represent when used as generic. + generic = shaper.type:is_optional(), + -- Whether the code generator should omit the type declaration. 
+ nodecl = shaper.optional_boolean, + -- Whether the code generator is importing the type from C. + cimport = shaper.optional_boolean, + -- C header that the code generator should include when using the type. + cinclude = shaper.string:is_optional(), + -- The value passed in annotation, this will change the computed align. + aligned = shaper.integer:is_optional(), + -- Whether the type is a primitive type, true for non user defined types. + is_primitive = shaper.optional_boolean, + -- Whether the type can represent a string, true for stringview, string and cstring. + is_stringy = shaper.optional_boolean, + -- Whether the type represents a contiguous buffer. + -- True for arrays, span and vector defined in the lib. + -- This is used to allow casting to/from span. + is_contiguous = shaper.optional_boolean, + -- Booleans for checking the underlying type. + is_generic_pointer = shaper.optional_boolean, + is_cstring = shaper.optional_boolean, + is_float32 = shaper.optional_boolean, + is_float64 = shaper.optional_boolean, + is_float128 = shaper.optional_boolean, + -- Booleans for checking the underlying type. (lib types) + is_allocator = shaper.optional_boolean, + is_resourcepool = shaper.optional_boolean, + is_string = shaper.optional_boolean, + is_span = shaper.optional_boolean, + is_vector = shaper.optional_boolean, + is_sequence = shaper.optional_boolean, + is_filestream = shaper.optional_boolean, + + -- REMOVE: + is_copyable = shaper.optional_boolean, + is_destroyable = shaper.optional_boolean, + + -- TODO: rethink + key = shaper.string:is_optional(), +} + Type._type = true Type.unary_operators = {} Type.binary_operators = {} @@ -47,8 +122,10 @@ function Type:_init(name, size, node) self.name = name self.node = node self.size = size or 0 + self.bitsize = self.size * 8 + self.align = self.size if not self.codename then - self:set_codename(string.format('nl%s', self.name)) + self:set_codename('nl' .. 
self.name) end local mt = getmetatable(self) self.unary_operators = setmetatable({}, {__index = mt.unary_operators}) @@ -105,9 +182,7 @@ function Type:is_convertible_from_type(type) -- anything can be converted to and from `any` return self else - return false, stringer.pformat( - "no viable type conversion from `%s` to `%s`", - type, self) + return false, stringer.pformat("no viable type conversion from `%s` to `%s`", type, self) end end @@ -171,7 +246,7 @@ function Type:is_initializable_from_attr(attr) end end -function Type:auto_deref_type() +function Type:implict_deref_type() return self end @@ -225,9 +300,13 @@ end Type.unary_operators.ref = function(ltype, lattr) local lval = lattr.value if lval == nil then - return types.get_pointer_type(ltype) + if not ltype.is_unpointable then + return types.PointerType(ltype) + else + return nil, nil, stringer.pformat('cannot reference not addressable type "%s"', ltype) + end else - return nil, nil, 'cannot reference compile time value' + return nil, nil, stringer.pformat('cannot reference compile time value of type "%s"', ltype) end end @@ -344,6 +423,7 @@ TypeType.is_comptime = true TypeType.nodecl = true TypeType.is_unpointable = true TypeType.is_lazyable = true +TypeType.is_primitive = true function TypeType:_init(name) Type._init(self, name, 0) @@ -441,6 +521,7 @@ local VaranysType = typeclass(AnyType) types.VaranysType = VaranysType VaranysType.is_varanys = true VaranysType.is_nolvalue = true +VaranysType.is_primitive = true function VaranysType:_init(name, size) Type._init(self, name, size) @@ -454,8 +535,6 @@ ArithmeticType.is_primitive = true function ArithmeticType:_init(name, size) Type._init(self, name, size) - self.align = size - self.bitsize = size * 8 end ArithmeticType.is_convertible_from_type = Type.is_convertible_from_type @@ -534,29 +613,47 @@ ArithmeticType.binary_operators.gt = make_arithmetic_cmp_opfunc(function(a,b) end) 
-------------------------------------------------------------------------------- +-- Integral Type +-- +-- Integral type is used for unsigned and signed integer (whole numbers) types, +-- i.e. 'int64', 'uint64', ... +-- They have min and max values and cannot be fractional. + local IntegralType = typeclass(ArithmeticType) types.IntegralType = IntegralType IntegralType.is_integral = true -local function get_integral_range(bits, is_unsigned) - local min, max - if is_unsigned then - min = bn.zero() - max = (bn.one() << bits) - 1 - else -- signed - min = -(bn.one() << bits) // 2 - max = ((bn.one() << bits) // 2) - 1 - end - return min, max -end +IntegralType.shape = shaper.fork_shape(Type.shape, { + -- Minimum and maximum value that the integral type can store. + min = shaper.arithmetic, max = shaper.arithmetic, + -- Signedness of the integral type. + is_signed = shaper.optional_boolean, is_unsigned = shaper.optional_boolean, + -- Booleans to know the exact underlying integral type. + is_uint64 = shaper.optional_boolean, + is_uint32 = shaper.optional_boolean, + is_uint16 = shaper.optional_boolean, + is_uint8 = shaper.optional_boolean, + is_int64 = shaper.optional_boolean, + is_int32 = shaper.optional_boolean, + is_int16 = shaper.optional_boolean, + is_int8 = shaper.optional_boolean, +}) function IntegralType:_init(name, size, is_unsigned) ArithmeticType._init(self, name, size) - self.min, self.max = get_integral_range(self.bitsize, is_unsigned) - self.is_unsigned = is_unsigned - self.is_signed = not is_unsigned - local isname = (is_unsigned and 'is_uint' or 'is_int')..self.bitsize - self[isname] = true + + -- compute the min and max values + if is_unsigned then + self.is_unsigned = true + self['is_uint'..self.bitsize] = true + self.min = bn.zero() + self.max = (bn.one() << self.bitsize) - 1 + else -- signed + self.is_signed = true + self['is_int'..self.bitsize] = true + self.min = -(bn.one() << self.bitsize) // 2 + self.max = ((bn.one() << self.bitsize) // 2) - 1 + end
end function IntegralType:is_convertible_from_type(type, explicit) @@ -783,10 +880,18 @@ types.FloatType = FloatType FloatType.is_float = true FloatType.is_signed = true -function FloatType:_init(name, size, maxdigits, fmtdigits) +FloatType.shape = shaper.fork_shape(Type.shape, { + -- Max decimal digits that this float can represent. + maxdigits = shaper.integer, + -- Boolean to know the exactly underlying float type. + is_float32 = shaper.optional_boolean, + is_float64 = shaper.optional_boolean, + is_float128 = shaper.optional_boolean, +}) + +function FloatType:_init(name, size, maxdigits) ArithmeticType._init(self, name, size) self.maxdigits = maxdigits - self.fmtdigits = fmtdigits self['is_float'..self.bitsize] = true end @@ -866,6 +971,7 @@ end) -------------------------------------------------------------------------------- local TableType = typeclass() types.TableType = TableType +TableType.is_primitive = true TableType.is_table = true function TableType:_init(name) @@ -888,19 +994,24 @@ types.ArrayType = ArrayType ArrayType.is_array = true ArrayType.is_contiguous = true +ArrayType.shape = shaper.fork_shape(Type.shape, { + -- Fixed length for the array. + length = shaper.integer, + -- The sub type for the array. 
+ subtype = shaper.type, +}) + function ArrayType:_init(subtype, length) local size = subtype.size * length self:set_codename(string.format('%s_arr%d', subtype.codename, length)) Type._init(self, 'array', size) self.subtype = subtype self.length = length - self.align = subtype.align or subtype.size + self.align = subtype.align end function ArrayType:is_equal(type) - return self.subtype == type.subtype and - self.length == type.length and - type.is_array + return self.subtype == type.subtype and self.length == type.length and type.is_array end function ArrayType:typedesc() @@ -909,7 +1020,7 @@ end function ArrayType:is_convertible_from_type(type, explicit) if not explicit and type:is_pointer_of(self) then - -- automatic deref + -- implicit automatic dereference return self end return Type.is_convertible_from_type(self, type, explicit) @@ -931,22 +1042,38 @@ end local EnumType = typeclass(IntegralType) types.EnumType = EnumType EnumType.is_enum = true -EnumType.is_primitive = false +EnumType.is_primitive = false -- to allow using custom nicknames + +EnumType.shape = shaper.fork_shape(IntegralType.shape, { + -- Fields of the enum. + fields = shaper.array_of(shaper.shape{ + -- Name of the field. + name = shaper.string, + -- Index of the field in the enum, the first index is always 1 not 0. + index = shaper.integer, + -- The field value. + value = shaper.integral, + }), + -- The integral sub type for the enum. + subtype = shaper.type, +}) function EnumType:_init(subtype, fields) self:set_codename(gencodename(self, 'enum')) IntegralType._init(self, 'enum', subtype.size, subtype.is_unsigned) self.subtype = subtype + self.fields = fields + self:update_fields() +end + +-- Update fields internal values when they are changed.
+function EnumType:update_fields() + local fields = self.fields for i=1,#fields do local field = fields[i] field.index = i fields[field.name] = field end - self.fields = fields -end - -function EnumType:get_field(name) - return self.fields[name] end function EnumType:typedesc() @@ -965,7 +1092,19 @@ types.FunctionType = FunctionType FunctionType.is_function = true FunctionType.is_procedure = true -function FunctionType:_init(argattrs, returntypes, node) +FunctionType.shape = shaper.fork_shape(Type.shape, { + -- List of arguments attrs, they contain the type with annotations. + argattrs = shaper.array_of(shaper.attr), + -- List of arguments types. + argtypes = shaper.array_of(shaper.type), + -- List of return types. + rettypes = shaper.array_of(shaper.type), + -- Whether this functions trigger side effects. + -- A function trigger side effects when it throw errors or operate on global variables. + sideeffect = shaper.optional_boolean, +}) + +function FunctionType:_init(argattrs, rettypes, node) self:set_codename(gencodename(self, 'function', node)) Type._init(self, 'function', cpusize, node) self.argattrs = argattrs or {} @@ -974,41 +1113,41 @@ function FunctionType:_init(argattrs, returntypes, node) argtypes[i] = argattrs[i].type end self.argtypes = argtypes - if returntypes then - if #returntypes == 1 and returntypes[1].is_void then + if rettypes then + if #rettypes == 1 and rettypes[1].is_void then -- single void type means no returns - self.returntypes = {} + self.rettypes = {} else - self.returntypes = returntypes - local lastindex = #returntypes - local lastret = returntypes[lastindex] + self.rettypes = rettypes + local lastindex = #rettypes + local lastret = rettypes[lastindex] self.returnvaranys = lastret and lastret.is_varanys end else - self.returntypes = {} + self.rettypes = {} end end function FunctionType:is_equal(type) return type.is_function and tabler.deepcompare(type.argtypes, self.argtypes) and - tabler.deepcompare(type.returntypes, 
self.returntypes) + tabler.deepcompare(type.rettypes, self.rettypes) end function FunctionType:has_destroyable_return() - for i=1,#self.returntypes do - if self.returntypes[i]:has_destroyable() then + for i=1,#self.rettypes do + if self.rettypes[i]:has_destroyable() then return true end end end function FunctionType:get_return_type(index) - local returntypes = self.returntypes - if self.returnvaranys and index > #returntypes then + local rettypes = self.rettypes + if self.returnvaranys and index > #rettypes then return primtypes.any end - local rettype = returntypes[index] + local rettype = rettypes[index] if rettype then return rettype elseif index == 1 then @@ -1017,22 +1156,18 @@ function FunctionType:get_return_type(index) end function FunctionType:has_multiple_returns() - return #self.returntypes > 1 -end - -function FunctionType:has_enclosed_return() - return self:has_multiple_returns() + return #self.rettypes > 1 end function FunctionType:get_return_count() - return #self.returntypes + return #self.rettypes end function FunctionType:is_convertible_from_type(type, explicit) if type.is_nilptr then return self end - if explicit and (type.is_genericpointer or type.is_function) then + if explicit and (type.is_generic_pointer or type.is_function) then return self end return Type.is_convertible_from_type(self, type, explicit) @@ -1040,12 +1175,12 @@ end function FunctionType:typedesc() local ss = sstream(self.name, '(', self.argtypes, ')') - if self.returntypes and #self.returntypes > 0 then + if self.rettypes and #self.rettypes > 0 then ss:add(': ') - if #self.returntypes > 1 then - ss:add('(', self.returntypes, ')') + if #self.rettypes > 1 then + ss:add('(', self.rettypes, ')') else - ss:add(self.returntypes) + ss:add(self.rettypes) end end return ss:tostring() @@ -1057,7 +1192,26 @@ types.LazyFunctionType = LazyFunctionType LazyFunctionType.is_procedure = true LazyFunctionType.is_lazyfunction = true -function LazyFunctionType:_init(args, returntypes, node) 
+LazyFunctionType.shape = shaper.fork_shape(Type.shape, { + -- List of arguments attrs, they contain the type with annotations. + args = shaper.array_of(shaper.attr), + -- List of arguments types. + argtypes = shaper.array_of(shaper.type), + -- List of return types. + rettypes = shaper.array_of(shaper.type), + -- List of functions evaluated by different argument types. + evals = shaper.array_of(shaper.shape{ + -- List of arguments attrs for the evaluation. + args = shaper.array_of(shaper.type), + -- Node defining the evaluated function. + node = shaper.astnode, + }), + -- Whether this functions trigger side effects. + -- A function trigger side effects when it throw errors or operate on global variables. + sideeffect = shaper.optional_boolean, +}) + +function LazyFunctionType:_init(args, rettypes, node) self:set_codename(gencodename(self, 'lazyfunction', node)) Type._init(self, 'lazyfunction', 0, node) self.args = args or {} @@ -1066,7 +1220,7 @@ function LazyFunctionType:_init(args, returntypes, node) argtypes[i] = args[i].type end self.argtypes = argtypes - self.returntypes = returntypes or {} + self.rettypes = rettypes or {} self.evals = {} end @@ -1108,100 +1262,125 @@ LazyFunctionType.is_equal = FunctionType.is_equal LazyFunctionType.typedesc = FunctionType.typedesc -------------------------------------------------------------------------------- +-- Record Type +-- +-- Record type is defined by a structure of fields, it really is the 'struct' under C. 
+ local RecordType = typeclass() types.RecordType = RecordType RecordType.is_record = true -local function compute_pad(size, align) - if align <= 1 or size == 0 then return 0 end - if size % align == 0 then return 0 end - return align - (size % align) -end - -local function compute_record_size(fields, packed, aligned) - local nfields = #fields - local size = 0 - local align = 0 - if nfields == 0 then - return size, align - end - for i=1,#fields do - local ftype = fields[i].type - local fsize = ftype.size - local falign = ftype.align or fsize - align = math.max(align, falign) - if not packed then - size = size + compute_pad(size, falign) - end - size = size + fsize - end - if not packed then - size = size + compute_pad(size, align) - end - if aligned then - size = size + compute_pad(size, aligned) - align = math.max(aligned, align) - end - return size, align -end +RecordType.shape = shaper.fork_shape(Type.shape, { + -- Field in the record. + fields = shaper.array_of(shaper.shape{ + -- Name of the field. + name = shaper.string, + -- Index of the field in the record, the first index is always 1 not 0. + index = shaper.integer, + -- Offset of the field in the record in bytes, always properly aligned. + offset = shaper.integer, + -- Type of the field. + type = shaper.type, + }), + + -- Meta fields in the record (methods and global variables declared for it). + metafields = shaper.map_of(shaper.string, shaper.symbol), + + -- Function to determine which type to interpret when initializing the record from braces '{}'. + -- This is used to allow initialization of custom vectors from braces. + -- By default records interpret braces as fields initialization, + -- but it can be changed to an array for example then it's handled in the __convert metamethod. + choose_braces_type = shaper.func:is_optional(), + + -- Whether to pack the record. + packed = shaper.optional_boolean, + + -- Use in the lib in generics like 'span', 'vector' to represent the subtype. 
+ subtype = shaper.type:is_optional(), +}) function RecordType:_init(fields, node) - fields = fields or {} - for i=1,#fields do - local field = fields[i] - field.index = i - fields[field.name] = field - end - local size, align = compute_record_size(fields) if not self.codename then self:set_codename(gencodename(self, 'record', node)) end - Type._init(self, 'record', size, node) - self.fields = fields + Type._init(self, 'record', 0, node) + + -- compute this record size and align according to the fields + self.fields = fields or {} self.metafields = {} - self.align = align + self:update_fields() +end + +-- Forward an offset to have a specified alignment. +local function align_forward(offset, align) + if align <= 1 or offset == 0 then return offset end + if offset % align == 0 then return offset end + return offset + (align - (offset % align)) end -function RecordType:_update_sizealign() - self.size, self.align = compute_record_size(self.fields, self.packed, self.aligned) +-- Update the record size, alignment and field offsets. +-- Called when changing any field at compile time. +function RecordType:update_fields() + local fields = self.fields + local offset, align = 0, 0 + if #fields > 0 then + local packed, aligned = self.packed, self.aligned + for i=1,#fields do + local field = fields[i] + local fieldtype = field.type + local fieldsize = fieldtype.size + local fieldalign = fieldtype.align + align = math.max(align, fieldalign) + if not packed then + offset = align_forward(offset, fieldalign) + end + field.offset = offset + field.index = i + fields[field.name] = field + offset = offset + fieldsize + end + if not packed then + offset = align_forward(offset, align) + end + if aligned then + offset = align_forward(offset, aligned) + align = math.max(aligned, align) + end + end + self.size = offset + self.bitsize = offset * 8 + self.align = align end +-- Add a field to the record. 
function RecordType:add_field(name, type, index) local fields = self.fields local field = {name = name, type = type} - if not index then + if not index then -- append a new field index = #fields + 1 fields[index] = field - else + else -- insert a new field at index table.insert(fields, index, field) end - field.index = index - self.fields[field.name] = field - self:_update_sizealign() + self:update_fields() end -function RecordType:get_field(name) +-- Get a field from the record. (deprecated, use 'fields' directly) +function RecordType:get_field(name) --luacov:disable return self.fields[name] -end +end --luacov:enable +-- Check if this type equals to another type. function RecordType:is_equal(type) return type.name == self.name and type.key == self.key end -function RecordType:typedesc() - local ss = sstream('record{') - for i,field in ipairs(self.fields) do - if i > 1 then ss:add(', ') end - ss:add(field.name, ':', field.type) - end - ss:add('}') - return ss:tostring() -end - +-- Get the symbol of a meta field for this record type. function RecordType:get_metafield(name) return self.metafields[name] end +-- Set a meta field for this record type to a symbol of a function or variable. function RecordType:set_metafield(name, symbol) if name == '__destroy' then self.is_destroyable = true @@ -1211,14 +1390,16 @@ function RecordType:set_metafield(name, symbol) self.metafields[name] = symbol end +-- Check if this type is convertible from another type. function RecordType:is_convertible_from_type(type, explicit) if not explicit and type:is_pointer_of(self) then - -- automatic deref + -- perform implicit automatic dereference on a pointer to this record return self end return Type.is_convertible_from_type(self, type, explicit) end +-- Check if this type can hold pointers, used by the garbage collector. 
function RecordType:has_pointer() local fields = self.fields for i=1,#fields do @@ -1245,16 +1426,32 @@ function RecordType:has_copyable() return false end +-- Return description of this type as a string. +function RecordType:typedesc() + local ss = sstream('record{') + for i,field in ipairs(self.fields) do + if i > 1 then ss:add(', ') end + ss:add(field.name, ':', field.type) + end + ss:add('}') + return ss:tostring() +end + -------------------------------------------------------------------------------- local PointerType = typeclass() types.PointerType = PointerType PointerType.is_pointer = true +PointerType.shape = shaper.fork_shape(Type.shape, { + -- The type the pointer is pointing to. + subtype = shaper.type, +}) + function PointerType:_init(subtype) self.subtype = subtype if subtype.is_void then self.nodecl = true - self.is_genericpointer = true + self.is_generic_pointer = true self.is_primitive = true elseif subtype.name == 'cchar' then self.nodecl = true @@ -1270,28 +1467,34 @@ function PointerType:_init(subtype) self.unary_operators['deref'] = subtype end +-- Check if this type is convertible from an attr. function PointerType:is_convertible_from_attr(attr, explicit) local type = attr.type if not explicit and self.subtype == type and (type.is_record or type.is_array) then - -- automatic ref - if not attr.lvalue then + -- implicit automatic reference for records and arrays + if not attr.lvalue then -- can only reference l-values return false, stringer.pformat( 'cannot automatic reference rvalue of type "%s" to pointer type "%s"', type, self) end + -- inform the code generation that the attr does an automatic reference attr.autoref = true return self end return Type.is_convertible_from_attr(self, attr, explicit) end +-- Check if this type is convertible from another type.
function PointerType:is_convertible_from_type(type, explicit) if type == self then + -- early check for the same type (optimization) return self elseif type.is_pointer then if explicit then + -- explicit casting to any other pointer type return self - elseif self.is_genericpointer then + elseif self.is_generic_pointer then + -- implicit casting to a generic pointer return self elseif type.subtype:is_array_of(self.subtype) and type.subtype.length == 0 then -- implicit casting from unbounded arrays pointers to pointers @@ -1306,58 +1509,63 @@ function PointerType:is_convertible_from_type(type, explicit) return self elseif (self.is_cstring and type.subtype == primtypes.byte) or (type.is_cstring and self.subtype == primtypes.byte) then + -- implicit casting between cstring and pointer to byte return self end - elseif type.is_function and self.is_genericpointer and explicit then - return self - end - if type.is_stringview and (self.is_cstring or self:is_pointer_of(primtypes.byte)) then + elseif type.is_stringview and (self.is_cstring or self:is_pointer_of(primtypes.byte)) then + -- implicit casting a stringview to a cstring or pointer to a byte return self elseif type.is_nilptr then + -- implicit casting nilptr to a pointer return self - elseif explicit and type.is_integral and type.size == cpusize then - -- conversion from pointer to integral - return self + elseif explicit then + if type.is_function and self.is_generic_pointer then + -- explicit casting a function to a generic pointer + return self + elseif type.is_integral and type.size >= cpusize then + -- explicit casting an integral that can fit a pointer to a pointer + return self + end end return Type.is_convertible_from_type(self, type, explicit) end +-- Promote this type with another type, promoting with nilptr keeps the pointer type. function PointerType:promote_type(type) - if type.is_nilptr then return self end + if type.is_nilptr then + return self + end return Type.promote_type(self, type) end +-- Check if this type equals to another type.
function PointerType:is_equal(type) return type.subtype == self.subtype and type.is_pointer end +-- Check if this type is pointing to another type. function PointerType:is_pointer_of(subtype) return self.subtype == subtype end -function PointerType:auto_deref_type() +-- Give the underlying type when implicit dereferencing the pointer. +function PointerType:implict_deref_type() + -- implicit dereference is only allowed for records and arrays subtypes if self.subtype and self.subtype.is_record or self.subtype.is_array then return self.subtype end return self end -function PointerType:typedesc() - if not self.subtype.is_void then - return sstream(self.name, '(', self.subtype, ')'):tostring() - else - return self.name - end -end - +-- Check if this type can hold pointers, used by the garbage collector. function PointerType.has_pointer() return true end +-- Support for compile time length operator on cstring (pointer to cchar). PointerType.unary_operators.len = function(_, lattr) if lattr.type.is_cstring then - local lval = lattr.value - local reval + local lval, reval = lattr.value, nil if lval then reval = bn.new(#lval) end @@ -1365,50 +1573,73 @@ PointerType.unary_operators.len = function(_, lattr) end end +-- Return description of this type as a string. +function PointerType:typedesc() + if not self.subtype.is_void then + return sstream(self.name, '(', self.subtype, ')'):tostring() + else + return self.name + end +end + -------------------------------------------------------------------------------- +-- String View Type +-- +-- String views are used to store and process immutable strings at compile time +-- and also to store string references at runtime. Internally it just holds a pointer +-- to a buffer and a size. Its buffer is always null terminated ('\0') by default +-- to have more compatibility with C.
+ local StringViewType = typeclass(RecordType) types.StringViewType = StringViewType StringViewType.is_stringview = true StringViewType.is_stringy = true StringViewType.is_primitive = true -StringViewType.align = cpusize -function StringViewType:_init(name, size) - local fields = { +function StringViewType:_init(name) + self:set_codename('nlstringview') + self.nickname = name + RecordType._init(self, { {name = 'data', type = types.PointerType(types.ArrayType(primtypes.byte, 0)) }, {name = 'size', type = primtypes.usize} - } - self:set_codename('nlstringview') - RecordType._init(self, fields) - self.name = 'stringview' - self.nickname = 'stringview' - self.metafields = {} - Type._init(self, name, size) + }) + self.name = name end +-- Check if this type is convertible from another type. function StringViewType:is_convertible_from_type(type, explicit) - if type.is_cstring then - -- implicit cast cstring to stringview + if type.is_cstring then -- implicit cast cstring to stringview return self end return Type.is_convertible_from_type(self, type, explicit) end +-- Compile time string view length. StringViewType.unary_operators.len = function(_, lattr) - local lval = lattr.value - local reval + local lval, reval = lattr.value, nil if lval then reval = bn.new(#lval) end return primtypes.isize, reval end +-- Compile time string view concatenation. +StringViewType.binary_operators.concat = function(ltype, rtype, lattr, rattr) + if ltype.is_stringview and rtype.is_stringview then + local lval, rval, reval = lattr.value, rattr.value, nil + if lval and rval then -- both are compile time strings + reval = lval .. rval + end + return ltype, reval + end +end + +-- Utility to create the string view comparison functions at compile time. 
local function make_string_cmp_opfunc(cmpfunc) - return function(_, rtype, lattr, rattr) - if rtype.is_stringview then - local reval - local lval, rval = lattr.value, rattr.value - if lval and rval then + return function(ltype, rtype, lattr, rattr) + if ltype.is_stringview and rtype.is_stringview then -- comparing string views? + local lval, rval, reval = lattr.value, rattr.value, nil + if lval and rval then -- both are compile time strings reval = cmpfunc(lval, rval) end return primtypes.boolean, reval @@ -1416,30 +1647,17 @@ local function make_string_cmp_opfunc(cmpfunc) end end -StringViewType.binary_operators.le = make_string_cmp_opfunc(function(a,b) - return a<=b -end) -StringViewType.binary_operators.ge = make_string_cmp_opfunc(function(a,b) - return a>=b -end) -StringViewType.binary_operators.lt = make_string_cmp_opfunc(function(a,b) - return a<b -end) -StringViewType.binary_operators.gt = make_string_cmp_opfunc(function(a,b) - return a>b -end) -StringViewType.binary_operators.concat = function(ltype, rtype, lattr, rattr) - if rtype.is_stringview then - local reval - local lval, rval = lattr.value, rattr.value - if lval and rval then - reval = lval .. rval - end - return ltype, reval - end -end +-- Implement all the string view comparison functions. +StringViewType.binary_operators.le = make_string_cmp_opfunc(function(a,b) return a<=b end) +StringViewType.binary_operators.ge = make_string_cmp_opfunc(function(a,b) return a>=b end) +StringViewType.binary_operators.lt = make_string_cmp_opfunc(function(a,b) return a<b end) +StringViewType.binary_operators.gt = make_string_cmp_opfunc(function(a,b) return a>b end) -------------------------------------------------------------------------------- +-- Concept Type +-- +-- Concept type is used to choose or match incoming types to function arguments at compile time. + local ConceptType = typeclass() types.ConceptType = ConceptType ConceptType.nodecl = true @@ -1450,39 +1668,47 @@ ConceptType.is_lazyable = true ConceptType.is_nilable = true ConceptType.is_concept = true +-- Create a concept from a lua function defined in the preprocessor.
function ConceptType:_init(func) Type._init(self, 'concept', 0) self.func = func end +-- Check if an attr can match a concept. function ConceptType:is_convertible_from_attr(attr, _, argattrs) local type, err = self.func(attr, argattrs) - if type == true then + if type == true then -- concept returned true, use the incoming type assert(attr.type) type = attr.type - elseif traits.is_symbol(type) then + elseif traits.is_symbol(type) then -- concept returned a symbol if type.type == primtypes.type and traits.is_type(type.value) then type = type.value - else + else -- the symbol is not holding a type type = nil err = stringer.pformat("invalid return for concept '%s': cannot be non type symbol", self) end - elseif traits.is_type(type) then - if type.is_comptime then - type = nil - err = stringer.pformat("invalid return for concept '%s': cannot be of the type '%s'", self, type) - end - elseif not type and not err then + elseif not type and not err then -- concept returned nothing type = nil err = stringer.pformat("type '%s' could not match concept '%s'", attr.type, self) - elseif not (type == false or type == nil) then + elseif not (type == false or type == nil or traits.is_type(type)) then + -- concept returned an invalid value type = nil err = stringer.pformat("invalid return for concept '%s': must be a boolean or a type", self) end + if type then + if type.is_comptime then -- concept cannot return compile time types + type = nil + err = stringer.pformat("invalid return for concept '%s': cannot be of the type '%s'", self, type) + end + end return type, err end -------------------------------------------------------------------------------- +-- Generic Type +-- +-- Generic type is used to create another type at compile time using the preprocessor. 
+ local GenericType = typeclass() types.GenericType = GenericType GenericType.nodecl = true @@ -1496,76 +1722,66 @@ function GenericType:_init(func) self.func = func end +-- Evaluate a generic to a type by calling it's function defined in the preprocessor. function GenericType:eval_type(params) local ok, ret = except.trycall(self.func, table.unpack(params)) if not ok then + -- the generic creation failed due to a lua error in preprocessor function return nil, ret end local err - if traits.is_symbol(ret) then - if not ret.type or not ret.type.is_type then + if traits.is_symbol(ret) then -- generic returned a symbol + if ret.type == primtypes.type then -- the symbol is holding a type + ret = ret.value + else -- invalid symbol ret = nil err = stringer.pformat("expected a symbol holding a type in generic return, but got something else") - else - ret = ret.value end - elseif not traits.is_type(ret) then + elseif not traits.is_type(ret) then -- generic did not return a type ret = nil err = stringer.pformat("expected a type or symbol in generic return, but got '%s'", type(ret)) end return ret, err end +-- Permits evaluating generics by directly calling it's symbol in the preprocessor. function GenericType:__call(params) return self:eval_type({params}) end -------------------------------------------------------------------------------- -function types.set_typedefs(t) - typedefs = t - primtypes = t.primtypes -end - -function types.get_pointer_type(subtype) - if subtype == primtypes.cchar then - return primtypes.cstring - elseif not subtype.is_unpointable then - return types.PointerType(subtype) - end -end +-- Utilities +-- Promote all types from a list to a single common type. +-- Used on type resolution. 
function types.find_common_type(possibletypes) if not possibletypes then return end local commontype = possibletypes[1] for i=2,#possibletypes do commontype = commontype:promote_type(possibletypes[i]) - if not commontype then - break + if not commontype then -- no common type found + return nil end end - return commontype + return commontype -- found the common type end ---TODO: refactor to use this function ---luacov:disable -function types.are_types_convertible(largs, rargs) - for i,atype,btype in iters.izip(largs, rargs) do - if atype and btype then - local ok, err = btype:is_convertible_from(atype) - if not ok then - return nil, stringer.pformat("at index %d: %s", i, err) - end - elseif not atype then - if not btype.is_nilable then - return nil, stringer.format("at index %d: parameter of type '%s' is missing", i, atype) - end - else - assert(not btype and atype) - return nil, stringer.format("at index %d: extra parameter of type '%s':", i, atype) - end +-- Convert a list of nodes holding a type to a list of the holding types. +function types.typenodes_to_types(nodes) + local typelist = {} + for i=1,#nodes do + local nodeattr = nodes[i].attr + assert(nodeattr.type._type) + typelist[i] = nodes[i].attr.value end - return true + return typelist +end + +-- Used internally, set the typedefs and primtypes locals. +-- This exists because typedefs and types modules have recursive dependency on each other. +function types.set_typedefs(t) + typedefs = t + primtypes = t.primtypes end ---luacov:enable return types diff --git a/nelua/utils/shaper.lua b/nelua/utils/shaper.lua new file mode 100644 index 00000000..212b7a03 --- /dev/null +++ b/nelua/utils/shaper.lua @@ -0,0 +1,39 @@ +local shaper = require 'nelua.thirdparty.tableshape'.types +local tabler = require 'nelua.utils.tabler' +local traits = require 'nelua.utils.traits' + +-- Additional shape check functions. 
+shaper.arithmetic = shaper.custom(function(v) + return traits.is_arithmetic(v), 'expected an arithmetic' +end) + +shaper.integral = shaper.custom(function(v) + return traits.is_integral(v), 'expected an integral' +end) + +shaper.symbol = shaper.custom(function(v) + return traits.is_symbol(v), 'expected a symbol' +end) + +shaper.astnode = shaper.custom(function(v) + return traits.is_astnode(v), 'expected a node' +end) + +shaper.attr = shaper.custom(function(v) + return traits.is_attr(v), 'expected an attr' +end) + +shaper.type = shaper.custom(function(v) + return traits.is_type(v), 'expected a type' +end) + +shaper.optional_boolean = shaper.boolean:is_optional() + +-- Fork the shape definition from another shape definition. +function shaper.fork_shape(baseshape, desc) + local shape = shaper.shape(desc) + tabler.update(shape.shape, baseshape.shape) + return shape +end + +return shaper diff --git a/nelua/utils/tabler.lua b/nelua/utils/tabler.lua index ea53fcbd..368b950c 100644 --- a/nelua/utils/tabler.lua +++ b/nelua/utils/tabler.lua @@ -2,7 +2,7 @@ local metamagic = require 'nelua.utils.metamagic' local tabler = {} ---- copy a table into another, in-place. +-- Copy a table into another, in-place. function tabler.update(t, src) for k,v in next,src do t[k] = v @@ -10,7 +10,7 @@ function tabler.update(t, src) return t end --- find a value inside an array table +-- Find a value inside an array table. function tabler.ifind(t, val, idx) for i=idx or 1,#t do if t[i] == val then @@ -20,7 +20,7 @@ function tabler.ifind(t, val, idx) return nil end --- insert values +-- Insert values in a list table. 
function tabler.insertvalues(t, pos, st) if not st then st = pos diff --git a/nelua/utils/traits.lua b/nelua/utils/traits.lua index df7afcad..64ec57ef 100644 --- a/nelua/utils/traits.lua +++ b/nelua/utils/traits.lua @@ -36,4 +36,12 @@ function traits.is_type(v) return type(v) == 'table' and v._type end +function traits.is_arithmetic(v) + return type(v) == 'number' or (type(v) == 'table' and v._bn) +end + +function traits.is_integral(v) + return math.type(v) == 'integer' or (type(v) == 'table' and v._bn) +end + return traits diff --git a/rockspecs/nelua-dev-1.rockspec b/rockspecs/nelua-dev-1.rockspec index 2476a7bd..0a6dd0d6 100644 --- a/rockspecs/nelua-dev-1.rockspec +++ b/rockspecs/nelua-dev-1.rockspec @@ -86,6 +86,7 @@ build = { ['nelua.utils.metamagic'] = 'nelua/utils/metamagic.lua', ['nelua.utils.pegger'] = 'nelua/utils/pegger.lua', ['nelua.utils.platform'] = 'nelua/utils/platform.lua', + ['nelua.utils.shaper'] = 'nelua/utils/shaper.lua', ['nelua.utils.sstream'] = 'nelua/utils/sstream.lua', ['nelua.utils.stringer'] = 'nelua/utils/stringer.lua', ['nelua.utils.tabler'] = 'nelua/utils/tabler.lua', diff --git a/spec/03-typechecker_spec.lua b/spec/03-typechecker_spec.lua index 44c0c979..c34c75b3 100644 --- a/spec/03-typechecker_spec.lua +++ b/spec/03-typechecker_spec.lua @@ -245,6 +245,7 @@ it("unary operators", function() assert.ast_type_equals("local a = -1", "local a: integer = -1") assert.ast_type_equals("local a = -1.0", "local a: number = -1.0") assert.analyze_error("local x = &1", "cannot reference compile time value") + assert.analyze_error("local x: niltype; local b = &x", "cannot reference not addressable type") assert.analyze_error("local a = -'s'", "invalid operation") assert.ast_type_equals([[ local x = 1_usize * #@integer @@ -415,6 +416,13 @@ it("late deduction", function() b = 2 c = a + b ]]) + assert.ast_type_equals([[ + local a = 1 + a = true + ]],[[ + local a: any = 1 + a = true + ]]) assert.ast_type_equals([[ local a = 1_integer local b = a + 1 @@ 
-1239,8 +1247,8 @@ it("pointers", function() a = b ]], "no viable type conversion") assert.analyze_error([[local a: integer*, b: number*; b = a]], "no viable type conversion") - assert.analyze_error("local a: auto*", "is invalid for 'pointer' type") - assert.analyze_error("local a: type*", "is invalid for 'pointer' type") + assert.analyze_error("local a: auto*", "is not addressable thus cannot have a pointer") + assert.analyze_error("local a: type*", "is not addressable thus cannot have a pointer") end) it("dereferencing and referencing", function() diff --git a/spec/05-cgenerator_spec.lua b/spec/05-cgenerator_spec.lua index 63a37894..f695e89f 100644 --- a/spec/05-cgenerator_spec.lua +++ b/spec/05-cgenerator_spec.lua @@ -1666,9 +1666,8 @@ it("record operator overloading", function() assert((-r).x == 15) local vec2 = @record{x: number, y: number} - ## vec2.value.is_vec2 = true local is_vec2_or_arithmetic = #[concept(function(b) - return b.type.is_vec2 or b.type.is_arithmetic + return b.type.nickname == 'vec2' or b.type.is_arithmetic end)]# function vec2.__mul(a: is_vec2_or_arithmetic, b: is_vec2_or_arithmetic): vec2 ## if b.type.is_arithmetic then @@ -2374,7 +2373,7 @@ it("GC requirements", function() for i=1,#scope.symbols do local sym = scope.symbols[i] local symtype = sym.type or primtypes.any - if sym:is_static_vardecl() and symtype:has_pointer() then + if sym:is_on_static_storage() and symtype:has_pointer() then emit_mark_static(sym, symtype) end end diff --git a/spec/tools/assert.lua b/spec/tools/assert.lua index 01818417..2cc75fb0 100644 --- a/spec/tools/assert.lua +++ b/spec/tools/assert.lua @@ -251,6 +251,7 @@ local function filter_ast_for_check(t) for k,v in pairs(t) do if type(k) == 'number' then if traits.is_astnode(v) and v.attr.type and v.attr.type.is_type then + assert(v.attr.value:shape()) -- remove type nodes because they are optional t[k] = nil elseif type(v) == 'table' then