Skip to content

Commit

Permalink
Introduce afteranalyze macro (used in GC)
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Feb 10, 2020
1 parent 26d7eef commit a6e8bca
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 13 deletions.
30 changes: 26 additions & 4 deletions lib/gc.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ local GC = @record{
nitems: usize, nslots: usize, mitems: usize, nfrees: usize
}

global gc = GC{}

function GC:_hash(ptr: pointer): usize
return (@usize)(ptr) >> 3
end
Expand Down Expand Up @@ -528,15 +530,35 @@ function GC:get_size(ptr: pointer): usize
return 0
end

global gc = GC{}
function GC:_mark_statics()
## local emit_mark_static = hygienize(function(sym, symtype)
gc:add(&#[sym]#, #[symtype.size]#, GCFlags.ROOT, nilptr)
## end)

local function nelua_main(): cint <cimport,nodecl> end
##[[
afteranalyze(function()
local function search_scope(scope)
for i=1,#scope.symbols do
local sym = scope.symbols[i]
local symtype = sym.type or primtypes.any
if sym:is_static_vardecl() and symtype:has_pointer() and sym ~= gc then
emit_mark_static(sym, symtype)
end
end
end
search_scope(context.rootscope)
for _,childscope in ipairs(context.rootscope.children) do
search_scope(childscope)
end
end)
]]
end

local static_top: pointer <volatile>
local static_bottom: pointer <volatile>
local function nelua_main(): cint <cimport,nodecl> end

local function main(argc: cint, argv: cchar**): cint <entrypoint>
gc:start(&argc)
gc:_mark_statics()
local ret: cint = nelua_main()
gc:stop()
return ret
Expand Down
2 changes: 1 addition & 1 deletion lib/myarraytable.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ local allocator = @gc_allocator
size: usize,
data: span(T)
}
global ArrayTableT <codename #['nelua_ArrayTable_'..T.name]#> = @record{
local ArrayTableT <codename #['nelua_ArrayTable_'..T.name]#> = @record{
impl: ArrayTableImplT*
}

Expand Down
15 changes: 15 additions & 0 deletions nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ end

function visitors.Id(context, node)
local name = node[1]
if node.attr.foreignsymbol then
local symbol = node.attr.foreignsymbol
symbol:link_node(node)
return symbol
end
local symbol = context.scope:get_symbol(name)
if not symbol then
if context.pragmas.strict then
Expand Down Expand Up @@ -1260,6 +1265,7 @@ function visitors.VarDecl(context, node)
end
for _,varnode,valnode,valtype in izipargnodes(varnodes, valnodes) do
assert(varnode.tag == 'IdDecl')
varnode.attr.vardecl = true
if varscope == 'global' then
if not context.scope:is_topscope() then
varnode:raisef("global variables can only be declared in top scope")
Expand Down Expand Up @@ -1809,6 +1815,15 @@ function analyzer.analyze(ast, parser, context)
local resolutions_count = context.rootscope:resolve()
until resolutions_count == 0

for _,cb in ipairs(context.afteranalyze) do
local ok, err = except.trycall(function()
cb.f()
end)
if not ok then
cb.node:raisef('error while executing after analyze: %s', err)
end
end

-- phase 3 traverse: infer unset types to 'any' type
local state = context:push_state()
state.anyphase = true
Expand Down
4 changes: 3 additions & 1 deletion nelua/analyzercontext.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ function AnalyzerContext:_init(visitors, parser)
self.pragmas = {}
self.pragmastack = {}
self.usedcodenames = {}
self.afteranalyze = {}
end

function AnalyzerContext:push_pragmas()
Expand All @@ -43,11 +44,12 @@ function AnalyzerContext:push_forked_scope(kind, node)
scope = node.scope
assert(scope.kind == kind)
assert(scope.parent == self.scope)
assert(scope.node == node)

-- symbols will be repopulated again
scope:clear_symbols()
else
scope = self.scope:fork(kind)
scope = self.scope:fork(kind, node)
node.scope = scope
end
self:push_scope(scope)
Expand Down
8 changes: 8 additions & 0 deletions nelua/attr.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,12 @@ function Attr:is_empty()
return next(self) == nil
end

function Attr:is_static_vardecl()
if self.vardecl and self.staticstorage and not self.comptime then
if not self.type or self.type.size > 0 then
return true
end
end
end

return Attr
6 changes: 5 additions & 1 deletion nelua/ppcontext.lua
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ function PPContext:tovalue(val, orignode)
node = aster.String{val}
elseif traits.is_symbol(val) then
node = aster.Id{val.name}
node.pattr = val
local pattr = Attr({
foreignsymbol = val
})
node.attr:merge(pattr)
node.pattr = pattr
elseif traits.is_number(val) or traits.is_bignumber(val) then
local num = bn.new(val)
local neg = false
Expand Down
6 changes: 6 additions & 0 deletions nelua/preprocessor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ function preprocessor.preprocess(context, ast)
return tabler.unpack(rets)
end
end,
afteranalyze = function(f)
if not traits.is_function(f) then
raise_preprocess_error("invalid arguments for preprocess function")
end
table.insert(context.afteranalyze, { f=f, node = context:get_current_node() })
end,
afterinfer = function(f)
if not traits.is_function(f) then
raise_preprocess_error("invalid arguments for preprocess function")
Expand Down
16 changes: 11 additions & 5 deletions nelua/scope.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@ local stringer = require 'nelua.utils.stringer'

local Scope = class()

function Scope:_init(parent, kind)
function Scope:_init(parent, kind, node)
self.kind = kind
self.node = node
if kind == 'root' then
self.context = parent
else
self.parent = parent
self.context = parent.context
table.insert(parent.children, self)
end
self.children = {}
self.checkpointstack = {}
self:clear_symbols()
end

function Scope:fork(kind)
return Scope(self, kind)
function Scope:fork(kind, node)
return Scope(self, kind, node)
end

function Scope:is_topscope()
Expand Down Expand Up @@ -153,6 +156,7 @@ function Scope:add_symbol(symbol, annon)
end
end
self.symbols[key] = symbol
table.insert(self.symbols, symbol)
return true
end

Expand All @@ -164,7 +168,8 @@ function Scope:resolve_symbols()
local count = 0
local unknownlist = {}
-- first resolve any symbol with known possible types
for _,symbol in pairs(self.symbols) do
for i=1,#self.symbols do
local symbol = self.symbols[i]
if symbol:resolve_type() then
count = count + 1
elseif count == 0 and symbol.type == nil then
Expand All @@ -175,7 +180,8 @@ function Scope:resolve_symbols()
if count == 0 and #unknownlist > 0 and not self.context.rootscope.delay then
-- [disabled] try to infer the type only for the first unknown symbol
--table.sort(unknownlist, function(a,b) return a.node.pos < b.node.pos end)
for _,symbol in ipairs(unknownlist) do
for i=1,#unknownlist do
local symbol = unknownlist[i]
if symbol:resolve_type(true) then
count = count + 1
elseif self.context.state.anyphase then
Expand Down
23 changes: 23 additions & 0 deletions nelua/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ function Type:is_unsigned() return self.unsigned end
function Type:is_signed() return self.arithmetic and not self.unsigned end
function Type:is_generic_pointer() return self.genericpointer end
function Type.is_pointer_of() return false end
function Type.has_pointer() return false end

function Type:__tostring()
return self.name
Expand Down Expand Up @@ -489,6 +490,8 @@ function AnyType.is_convertible_from_type()
return true
end

function AnyType.has_pointer() return true end

--------------------------------------------------------------------------------
local VaranysType = typeclass(AnyType)
types.VaranysType = VaranysType
Expand Down Expand Up @@ -929,6 +932,8 @@ function ArrayTableType:__tostring()
return sstream(self.name, '(', self.subtype, ')'):tostring()
end

function ArrayTableType.has_pointer() return true end

ArrayTableType.unary_operators.len = 'integer'

--------------------------------------------------------------------------------
Expand Down Expand Up @@ -963,6 +968,10 @@ function ArrayType:is_convertible_from_type(type, explicit)
return Type.is_convertible_from_type(self, type, explicit)
end

function ArrayType:has_pointer()
return self.subtype:has_pointer()
end

ArrayType.unary_operators.len = function(ltype)
return primtypes.integer, bn.new(ltype.length)
end
Expand Down Expand Up @@ -1222,6 +1231,12 @@ function RecordType:is_convertible_from_type(type, explicit)
return Type.is_convertible_from_type(self, type, explicit)
end

function RecordType:has_pointer()
return tabler.ifindif(self.fields, function(f)
return f.type:has_pointer()
end) ~= nil
end

--------------------------------------------------------------------------------
local PointerType = typeclass()
types.PointerType = PointerType
Expand Down Expand Up @@ -1307,6 +1322,10 @@ function PointerType:__tostring()
end
end

function PointerType.has_pointer()
return true
end

--------------------------------------------------------------------------------
local SpanType = typeclass(RecordType)
types.SpanType = SpanType
Expand Down Expand Up @@ -1335,6 +1354,10 @@ function SpanType:__tostring()
return sstream(self.name, '(', self.subtype, ')'):tostring()
end

function SpanType.has_pointer()
return true
end

--------------------------------------------------------------------------------
local RangeType = typeclass(RecordType)
types.RangeType = RangeType
Expand Down
4 changes: 4 additions & 0 deletions nelua/visitorcontext.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ function VisitorContext:get_parent_node()
return self.visiting_nodes[#self.visiting_nodes - 1]
end

function VisitorContext:get_current_node()
return self.visiting_nodes[#self.visiting_nodes]
end

function VisitorContext:push_state()
table.insert(self.statestack, self.state)
local newstate = tabler.copy(self.state)
Expand Down
2 changes: 1 addition & 1 deletion spec/03-typechecker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ it("analyzed ast transform", function()
n.VarDecl{'local',
{ n.IdDecl{
assign=true,
attr = {codename='a', name='a', staticstorage=true, type='int64', lvalue=true},
attr = {codename='a', name='a', staticstorage=true, type='int64', vardecl=true, lvalue=true},
'a' }},
{ n.Number{
attr = {
Expand Down
53 changes: 53 additions & 0 deletions spec/05-cgenerator_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1807,4 +1807,57 @@ it("top scope variables prefix", function()
assert.config.srcname = nil
end)

it("GC requirements", function()
assert.generate_c([=[
global gp: pointer
global gs: span(integer)
global gr: record{x: pointer}
global ga: integer*[4]
global g
local p: pointer
local s: span(integer)
local r: record{x: pointer}
local a: integer*[4]
local l
local function markp(what: pointer)
end
local function mark()
## emit_mark_static = hygienize(function(sym)
markp(&#[sym]#)
## end)
##[[
afteranalyze(function()
local function search_scope(scope)
for i=1,#scope.symbols do
local sym = scope.symbols[i]
local symtype = sym.type or primtypes.any
if sym:is_static_vardecl() and symtype:has_pointer() then
emit_mark_static(sym, symtype)
end
end
end
search_scope(context.rootscope)
for _,childscope in ipairs(context.rootscope.children) do
search_scope(childscope)
end
end)
]]
end
]=], [[void mark() {
markp((void*)(&gp));
markp((void*)(&gs));
markp((void*)(&gr));
markp((void*)(&ga));
markp((void*)(&g));
markp((void*)(&p));
markp((void*)(&s));
markp((void*)(&r));
markp((void*)(&a));
markp((void*)(&l));
}]])
end)

end)
6 changes: 6 additions & 0 deletions spec/06-preprocessor_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ it("call codes after inference", function()
assert.analyze_error("## afterinfer(false)", "invalid arguments for preprocess")
end)

it("call codes after analyze pass", function()
assert.analyze_ast("## afteranalyze(function() end)")
assert.analyze_error("## afteranalyze(function() error 'errmsg' end)", "errmsg")
assert.analyze_error("## afteranalyze(false)", "invalid arguments for preprocess")
end)

it("inject nodes", function()
assert.ast_type_equals([=[
## ppcontext:add_statnode(aster.Call{{aster.String{"hello"}}, aster.Id{'print'}, true})
Expand Down

0 comments on commit a6e8bca

Please sign in to comment.