Skip to content

Commit

Permalink
Make arguments visible when preprocessing function returns
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Mar 26, 2021
1 parent 96becdc commit 5394c97
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 11 deletions.
29 changes: 27 additions & 2 deletions nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2458,10 +2458,35 @@ end
local function visitor_function_returns(context, node, retnodes, ispolyparent)
local funcscope = context.scope
context:push_state{intypeexpr = true}
context:traverse_nodes(retnodes)
local polyret = false
for i=1,#retnodes do
local retnode = retnodes[i]
if retnode.tag == 'PreprocessExpr' then -- must preprocess the return type
if ispolyparent then
-- skip parsing nodes that need preprocess in polymorphic function parent
retnode = nil
polyret = true
else
local ok, err = except.trycall(retnode.preprocess, retnodes, i)
if not ok then
if except.isexception(err) then
except.reraise(err)
else
retnode:raisef('error while preprocessing function return node: %s', err)
end
end
retnode = retnodes[i] -- the node was overwritten
end
end
if retnode then
context:traverse_node(retnode)
end
end
context:pop_state()
if not funcscope.rettypes then
if #retnodes > 0 then -- return types is fixed by the user
if polyret then
funcscope.rettypes = {}
elseif #retnodes > 0 then -- return types is fixed by the user
funcscope.rettypes = types.typenodes_to_types(retnodes)
elseif ispolyparent or node.attr.cimport then
funcscope.rettypes = {}
Expand Down
16 changes: 10 additions & 6 deletions nelua/ppcontext.lua
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,16 @@ function PPContext.toname(_, val, orignode)
return val
end

function PPContext:inject_value(val, srcnode, dest, destpos)
function PPContext:tonode(val, orignode)
local aster = self.context.parser.astbuilder.aster
local node = aster.value(val, orignode)
if not node then
orignode:raisef('unable to convert preprocess value of lua type "%s" to a compile time value', type(val))
end
return node
end

function PPContext:inject_value(val, orignode, dest, destpos)
if type(val) == 'table' and val._varargs then
while #dest > destpos do -- clean old varargs
dest[#dest] = nil
Expand All @@ -60,11 +68,7 @@ function PPContext:inject_value(val, srcnode, dest, destpos)
dest[destpos+i-1] = val[i]
end
else
local node = aster.value(val, srcnode)
if not node then
srcnode:raisef('unable to convert preprocess value of lua type "%s" to a compile time value', type(val))
end
dest[destpos] = node
dest[destpos] = self:tonode(val, orignode)
end
end

Expand Down
35 changes: 32 additions & 3 deletions nelua/preprocessor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,45 @@ function visitors.PreprocessName(ppcontext, node, emitter, parent, parentindex)
local luacode = node[1]
local pindex, nindex = ppcontext:getregistryindex(parent), ppcontext:getregistryindex(node)
emitter:add_indent_ln('ppregistry[', pindex, '][', parentindex, ']',
'=ppcontext:toname(', luacode, ',ppregistry[', nindex, '])')
'=ppcontext:toname(', luacode, ',ppregistry[', nindex, '])')
end

function visitors.PreprocessExpr(ppcontext, node, emitter, parent, parentindex)
local luacode = node[1]
local pindex, nindex = ppcontext:getregistryindex(parent), ppcontext:getregistryindex(node)
emitter:add_indent_ln('ppcontext:inject_value(',luacode,
',ppregistry[', nindex, '],ppregistry[', pindex,'],',parentindex,')')
',ppregistry[', nindex, '],ppregistry[', pindex,'],', parentindex,')')
end

local function make_expr_node_preprocess(ppcontext, node, emitter)
local luacode = node[1]
local nindex = ppcontext:getregistryindex(node)
emitter:add_indent_ln('ppregistry[', nindex, '].preprocess=function(parent, pindex)')
emitter:inc_indent()
emitter:add_indent_ln('ppcontext:inject_value(', luacode, ', ppregistry[', nindex, '], parent, pindex)')
emitter:dec_indent()
emitter:add_indent_ln('end')
end

function visitors.FuncDef(ppcontext, node, emitter)
local namenode, argnodes, retnodes, annotnodes, blocknode = node[2], node[3], node[4], node[5], node[6]
ppcontext:traverse_node(namenode, emitter, node, 2)
ppcontext:traverse_nodes(argnodes, emitter, node, 3)
for i=1,#retnodes do
local retnode = retnodes[i]
if retnode.tag == 'PreprocessExpr' then
make_expr_node_preprocess(ppcontext, retnode, emitter)
else
ppcontext:traverse_node(retnode, emitter, retnodes, i)
end
end
if annotnodes then
ppcontext:traverse_nodes(annotnodes, emitter, node, 5)
end
ppcontext:traverse_node(blocknode, emitter, node, 6)
end


function visitors.Preprocess(_, node, emitter)
local luacode = node[1]
emitter:add_ln(luacode)
Expand Down Expand Up @@ -85,8 +114,8 @@ local function mark_preprocessing_nodes(ast)
PreprocessExpr = true
}
for _, parents in ast:walk_trace_nodes(preprocess_tags) do
-- mark nearest parent block above
needprocess = true
-- mark nearest parent block above
for i=#parents-1,1,-1 do
local pnode = parents[i]
if pnode.tag == 'Block' then
Expand Down
4 changes: 4 additions & 0 deletions spec/preprocessor_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,10 @@ it("report errors", function()
expect.analyze_error("##[[ invalid() ]]", "attempt to call")
expect.analyze_error("##[[ for ]]", "expected near")
expect.analyze_error("##[[ ast:raisef('ast error') ]]", "ast error")
expect.analyze_error('local function f(x: auto): #[assert(false)]# return x end f(1)',
"while preprocessing function return node")
expect.analyze_error('local function f(x: auto): #[static_assert(false)]# return x end f(1)',
"static assertion")
end)

it("preprocessor replacement", function()
Expand Down
12 changes: 12 additions & 0 deletions spec/typechecker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,18 @@ it("poly function definition", function()
end
local x = R.foo(2)
]])
expect.analyze_ast([[
local function f(x: auto): #[x.type]#
return x
end
local z: integer = f(1)
]])
expect.analyze_error([[
local function f(x: auto): #[x.type]#
return false
end
f(1)
]], "no viable type conversion from")
end)

it("function return", function()
Expand Down

0 comments on commit 5394c97

Please sign in to comment.