diff --git a/lib/gc.nelua b/lib/gc.nelua index 3fe12d68..9e2f264c 100644 --- a/lib/gc.nelua +++ b/lib/gc.nelua @@ -39,6 +39,8 @@ local GC = @record{ nitems: usize, nslots: usize, mitems: usize, nfrees: usize } +global gc = GC{} + function GC:_hash(ptr: pointer): usize return (@usize)(ptr) >> 3 end @@ -528,15 +530,35 @@ function GC:get_size(ptr: pointer): usize return 0 end -global gc = GC{} +function GC:_mark_statics() + ## local emit_mark_static = hygienize(function(sym, symtype) + gc:add(&#[sym]#, #[symtype.size]#, GCFlags.ROOT, nilptr) + ## end) -local function nelua_main(): cint end + ##[[ + afteranalyze(function() + local function search_scope(scope) + for i=1,#scope.symbols do + local sym = scope.symbols[i] + local symtype = sym.type or primtypes.any + if sym:is_static_vardecl() and symtype:has_pointer() and sym ~= gc then + emit_mark_static(sym, symtype) + end + end + end + search_scope(context.rootscope) + for _,childscope in ipairs(context.rootscope.children) do + search_scope(childscope) + end + end) + ]] +end -local static_top: pointer -local static_bottom: pointer +local function nelua_main(): cint end local function main(argc: cint, argv: cchar**): cint gc:start(&argc) + gc:_mark_statics() local ret: cint = nelua_main() gc:stop() return ret diff --git a/lib/myarraytable.nelua b/lib/myarraytable.nelua index 24bb1db5..059f372c 100644 --- a/lib/myarraytable.nelua +++ b/lib/myarraytable.nelua @@ -15,7 +15,7 @@ local allocator = @gc_allocator size: usize, data: span(T) } - global ArrayTableT = @record{ + local ArrayTableT = @record{ impl: ArrayTableImplT* } diff --git a/nelua/analyzer.lua b/nelua/analyzer.lua index 562e92d6..5af9fefa 100644 --- a/nelua/analyzer.lua +++ b/nelua/analyzer.lua @@ -299,6 +299,11 @@ end function visitors.Id(context, node) local name = node[1] + if node.attr.foreignsymbol then + local symbol = node.attr.foreignsymbol + symbol:link_node(node) + return symbol + end local symbol = context.scope:get_symbol(name) if not symbol then if context.pragmas.strict then @@ -1260,6 +1265,7 @@ function visitors.VarDecl(context, node) end for _,varnode,valnode,valtype in izipargnodes(varnodes, valnodes) do assert(varnode.tag == 'IdDecl') + varnode.attr.vardecl = true if varscope == 'global' then if not context.scope:is_topscope() then varnode:raisef("global variables can only be declared in top scope") @@ -1809,6 +1815,15 @@ function analyzer.analyze(ast, parser, context) local resolutions_count = context.rootscope:resolve() until resolutions_count == 0 + for _,cb in ipairs(context.afteranalyze) do + local ok, err = except.trycall(function() + cb.f() + end) + if not ok then + cb.node:raisef('error while executing after analyze: %s', err) + end + end + -- phase 3 traverse: infer unset types to 'any' type local state = context:push_state() state.anyphase = true diff --git a/nelua/analyzercontext.lua b/nelua/analyzercontext.lua index 54eacb40..c30802a3 100644 --- a/nelua/analyzercontext.lua +++ b/nelua/analyzercontext.lua @@ -18,6 +18,7 @@ function AnalyzerContext:_init(visitors, parser) self.pragmas = {} self.pragmastack = {} self.usedcodenames = {} + self.afteranalyze = {} end function AnalyzerContext:push_pragmas() @@ -43,11 +44,12 @@ function AnalyzerContext:push_forked_scope(kind, node) scope = node.scope assert(scope.kind == kind) assert(scope.parent == self.scope) + assert(scope.node == node) -- symbols will be repopulated again scope:clear_symbols() else - scope = self.scope:fork(kind) + scope = self.scope:fork(kind, node) node.scope = scope end self:push_scope(scope) diff --git a/nelua/attr.lua b/nelua/attr.lua index fd085730..9a120246 100644 --- a/nelua/attr.lua +++ b/nelua/attr.lua @@ -30,4 +30,12 @@ function Attr:is_empty() return next(self) == nil end +function Attr:is_static_vardecl() + if self.vardecl and self.staticstorage and not self.comptime then + if not self.type or self.type.size > 0 then + return true + end + end +end + return Attr diff --git a/nelua/ppcontext.lua b/nelua/ppcontext.lua index 72de45f1..428869ee 100644 --- a/nelua/ppcontext.lua +++ b/nelua/ppcontext.lua @@ -73,7 +73,11 @@ function PPContext:tovalue(val, orignode) node = aster.String{val} elseif traits.is_symbol(val) then node = aster.Id{val.name} - node.pattr = val + local pattr = Attr({ + foreignsymbol = val + }) + node.attr:merge(pattr) + node.pattr = pattr elseif traits.is_number(val) or traits.is_bignumber(val) then local num = bn.new(val) local neg = false diff --git a/nelua/preprocessor.lua b/nelua/preprocessor.lua index 1fd1ee93..e49ab906 100644 --- a/nelua/preprocessor.lua +++ b/nelua/preprocessor.lua @@ -159,6 +159,12 @@ function preprocessor.preprocess(context, ast) return tabler.unpack(rets) end end, + afteranalyze = function(f) + if not traits.is_function(f) then + raise_preprocess_error("invalid arguments for preprocess function") + end + table.insert(context.afteranalyze, { f=f, node = context:get_current_node() }) + end, afterinfer = function(f) if not traits.is_function(f) then raise_preprocess_error("invalid arguments for preprocess function") diff --git a/nelua/scope.lua b/nelua/scope.lua index 3b458e68..6b665e6b 100644 --- a/nelua/scope.lua +++ b/nelua/scope.lua @@ -8,20 +8,23 @@ local stringer = require 'nelua.utils.stringer' local Scope = class() -function Scope:_init(parent, kind) +function Scope:_init(parent, kind, node) self.kind = kind + self.node = node if kind == 'root' then self.context = parent else self.parent = parent self.context = parent.context + table.insert(parent.children, self) end + self.children = {} self.checkpointstack = {} self:clear_symbols() end -function Scope:fork(kind) - return Scope(self, kind) +function Scope:fork(kind, node) + return Scope(self, kind, node) end function Scope:is_topscope() @@ -153,6 +156,7 @@ function Scope:add_symbol(symbol, annon) end end self.symbols[key] = symbol + table.insert(self.symbols, symbol) return true end @@ -164,7 +168,8 @@ function Scope:resolve_symbols() local count = 0 local unknownlist = {} -- first resolve any symbol with known possible types - for _,symbol in pairs(self.symbols) do + for i=1,#self.symbols do + local symbol = self.symbols[i] if symbol:resolve_type() then count = count + 1 elseif count == 0 and symbol.type == nil then @@ -175,7 +180,8 @@ function Scope:resolve_symbols() if count == 0 and #unknownlist > 0 and not self.context.rootscope.delay then -- [disabled] try to infer the type only for the first unknown symbol --table.sort(unknownlist, function(a,b) return a.node.pos < b.node.pos end) - for _,symbol in ipairs(unknownlist) do + for i=1,#unknownlist do + local symbol = unknownlist[i] if symbol:resolve_type(true) then count = count + 1 elseif self.context.state.anyphase then diff --git a/nelua/types.lua b/nelua/types.lua index 04d21c22..0a4e9333 100644 --- a/nelua/types.lua +++ b/nelua/types.lua @@ -189,6 +189,7 @@ function Type:is_unsigned() return self.unsigned end function Type:is_signed() return self.arithmetic and not self.unsigned end function Type:is_generic_pointer() return self.genericpointer end function Type.is_pointer_of() return false end +function Type.has_pointer() return false end function Type:__tostring() return self.name @@ -489,6 +490,8 @@ function AnyType.is_convertible_from_type() return true end +function AnyType.has_pointer() return true end + -------------------------------------------------------------------------------- local VaranysType = typeclass(AnyType) types.VaranysType = VaranysType @@ -929,6 +932,8 @@ function ArrayTableType:__tostring() return sstream(self.name, '(', self.subtype, ')'):tostring() end +function ArrayTableType.has_pointer() return true end + ArrayTableType.unary_operators.len = 'integer' -------------------------------------------------------------------------------- @@ -963,6 +968,10 @@ function ArrayType:is_convertible_from_type(type, explicit) return Type.is_convertible_from_type(self, type, explicit) end +function ArrayType:has_pointer() + return self.subtype:has_pointer() +end + ArrayType.unary_operators.len = function(ltype) return primtypes.integer, bn.new(ltype.length) end @@ -1222,6 +1231,12 @@ function RecordType:is_convertible_from_type(type, explicit) return Type.is_convertible_from_type(self, type, explicit) end +function RecordType:has_pointer() + return tabler.ifindif(self.fields, function(f) + return f.type:has_pointer() + end) ~= nil +end + -------------------------------------------------------------------------------- local PointerType = typeclass() types.PointerType = PointerType @@ -1307,6 +1322,10 @@ function PointerType:__tostring() end end +function PointerType.has_pointer() + return true +end + -------------------------------------------------------------------------------- local SpanType = typeclass(RecordType) types.SpanType = SpanType @@ -1335,6 +1354,10 @@ function SpanType:__tostring() return sstream(self.name, '(', self.subtype, ')'):tostring() end +function SpanType.has_pointer() + return true +end + -------------------------------------------------------------------------------- local RangeType = typeclass(RecordType) types.RangeType = RangeType diff --git a/nelua/visitorcontext.lua b/nelua/visitorcontext.lua index 95c199b9..6abccbc1 100644 --- a/nelua/visitorcontext.lua +++ b/nelua/visitorcontext.lua @@ -53,6 +53,10 @@ function VisitorContext:get_parent_node() return self.visiting_nodes[#self.visiting_nodes - 1] end +function VisitorContext:get_current_node() + return self.visiting_nodes[#self.visiting_nodes] +end + function VisitorContext:push_state() table.insert(self.statestack, self.state) local newstate = tabler.copy(self.state) diff --git a/spec/03-typechecker_spec.lua b/spec/03-typechecker_spec.lua index e6ce00f5..9cd6a11f 100644 --- a/spec/03-typechecker_spec.lua +++ b/spec/03-typechecker_spec.lua @@ -13,7 +13,7 @@ it("analyzed ast transform", function() n.VarDecl{'local', { n.IdDecl{ assign=true, - attr = {codename='a', name='a', staticstorage=true, type='int64', lvalue=true}, + attr = {codename='a', name='a', staticstorage=true, type='int64', vardecl=true, lvalue=true}, 'a' }}, { n.Number{ attr = { diff --git a/spec/05-cgenerator_spec.lua b/spec/05-cgenerator_spec.lua index 2081ac0c..a4fbfd2d 100644 --- a/spec/05-cgenerator_spec.lua +++ b/spec/05-cgenerator_spec.lua @@ -1807,4 +1807,57 @@ it("top scope variables prefix", function() assert.config.srcname = nil end) +it("GC requirements", function() + assert.generate_c([=[ + global gp: pointer + global gs: span(integer) + global gr: record{x: pointer} + global ga: integer*[4] + global g + local p: pointer + local s: span(integer) + local r: record{x: pointer} + local a: integer*[4] + local l + + local function markp(what: pointer) + end + + local function mark() + ## emit_mark_static = hygienize(function(sym) + markp(&#[sym]#) + ## end) + + ##[[ + afteranalyze(function() + local function search_scope(scope) + for i=1,#scope.symbols do + local sym = scope.symbols[i] + local symtype = sym.type or primtypes.any + if sym:is_static_vardecl() and symtype:has_pointer() then + emit_mark_static(sym, symtype) + end + end + end + search_scope(context.rootscope) + for _,childscope in ipairs(context.rootscope.children) do + search_scope(childscope) + end + end) + ]] + end + ]=], [[void mark() { + markp((void*)(&gp)); + markp((void*)(&gs)); + markp((void*)(&gr)); + markp((void*)(&ga)); + markp((void*)(&g)); + markp((void*)(&p)); + markp((void*)(&s)); + markp((void*)(&r)); + markp((void*)(&a)); + markp((void*)(&l)); +}]]) +end) + end) diff --git a/spec/06-preprocessor_spec.lua b/spec/06-preprocessor_spec.lua index d6fe2be5..28808fae 100644 --- a/spec/06-preprocessor_spec.lua +++ b/spec/06-preprocessor_spec.lua @@ -342,6 +342,12 @@ it("call codes after inference", function() assert.analyze_error("## afterinfer(false)", "invalid arguments for preprocess") end) +it("call codes after analyze pass", function() + assert.analyze_ast("## afteranalyze(function() end)") + assert.analyze_error("## afteranalyze(function() error 'errmsg' end)", "errmsg") + assert.analyze_error("## afteranalyze(false)", "invalid arguments for preprocess") +end) + it("inject nodes", function() assert.ast_type_equals([=[ ## ppcontext:add_statnode(aster.Call{{aster.String{"hello"}}, aster.Id{'print'}, true})