Skip to content

Commit

Permalink
Allow type argument son lazy functions
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Jan 27, 2020
1 parent a1ea213 commit 7a86315
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 83 deletions.
7 changes: 3 additions & 4 deletions nelua/cbuiltins.lua
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ end

-- nil
function builtins.nelua_nilable(context)
define_builtin(context, 'nelua_nilable', "typedef struct nelua_nilable {uint8_t dummy;} nelua_nilable;\n")
define_builtin(context, 'nelua_nilable', "typedef void* nelua_nilable;\n")
end

function builtins.NELUA_NIL(context)
context:ensure_runtime_builtin('nelua_nilable')
define_builtin(context, 'NELUA_NIL', "#define NELUA_NIL (nelua_nilable){0}\n")
function builtins.nelua_unusedvar(context)
define_builtin(context, 'nelua_unusedvar', "typedef void* nelua_unusedvar;\n")
end

-- panic
Expand Down
11 changes: 7 additions & 4 deletions nelua/cemitter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,10 @@ function CEmitter:zeroinit(type)
s = '0U'
elseif type:is_arithmetic() then
s = '0'
elseif type:is_pointer() then
elseif type:is_pointer() or type:is_nil() or type:is_comptime() then
s = 'NULL'
elseif type:is_boolean() then
s = 'false'
elseif type:is_nil() then
s = self:add_builtin('NELUA_NIL')
else
s = '{0}'
end
Expand Down Expand Up @@ -138,6 +136,11 @@ function CEmitter:add_cstring2string(val)
end

function CEmitter:add_val2type(type, val, valtype)
if type:is_comptime() then
self:add('NULL')
return
end

if not valtype and traits.is_astnode(val) then
valtype = val.attr.type
end
Expand Down Expand Up @@ -181,7 +184,7 @@ function CEmitter:add_val2type(type, val, valtype)
end

function CEmitter:add_nil_literal()
self:add_builtin('NELUA_NIL')
self:add('NULL')
end

function CEmitter:add_numeric_literal(valattr, valtype)
Expand Down
21 changes: 13 additions & 8 deletions nelua/cgenerator.lua
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ visitors.PointerType = visitors.Type
function visitors.IdDecl(context, node, emitter)
local attr = node.attr
local type = attr.type
assert(not (type:is_comptime() or attr.comptime))
assert(not attr.comptime)
if type:is_comptime() then
emitter:add(context:ensure_runtime_builtin('nelua_unusedvar'), ' ', context:declname(node))
return
end
if attr.funcdecl then
emitter:add(context:declname(node))
return
Expand Down Expand Up @@ -487,8 +491,8 @@ local function visitor_Call(context, node, emitter, argnodes, callee, isblockcal
if attr.pointercall then
emitter:add('(*')
end
if attr.lazysym then
emitter:add(context:declname(attr.lazysym))
if attr.lazyeval then
emitter:add(context:declname(attr.lazyeval.node.attr))
else
emitter:add(callee)
end
Expand Down Expand Up @@ -831,17 +835,18 @@ function visitors.Assign(context, node, emitter)
end

function visitors.FuncDef(context, node, emitter)
if node.lazys then
for _,lazysym in ipairs(node.lazys) do
emitter:add(lazysym.lazynode)
local attr = node.attr
local type = attr.type

if type:is_lazyfunction() then
for _,lazyeval in ipairs(type.evals) do
emitter:add(lazyeval.node)
end
return
end

local varscope, varnode, argnodes, retnodes, annotnodes, blocknode = node:args()

local attr = node.attr
local type = attr.type
local numrets = type:get_return_count()
local qualifier = ''
if not attr.entrypoint and not attr.nostatic and not attr.cexport then
Expand Down
77 changes: 40 additions & 37 deletions nelua/typechecker.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local iters = require 'nelua.utils.iterators'
local traits = require 'nelua.utils.traits'
local tabler = require 'nelua.utils.tabler'
local errorer = require 'nelua.utils.errorer'
local typedefs = require 'nelua.typedefs'
local Context = require 'nelua.context'
local Symbol = require 'nelua.symbol'
Expand Down Expand Up @@ -369,9 +370,10 @@ function visitors.Type(context, node)
local value = typedefs.primtypes[tyname]
if not value then
local symbol = context.scope:get_symbol(tyname)
if not (symbol and symbol.type == primtypes.type and symbol.value) then
if not (symbol and symbol.type == primtypes.type) then
node:raisef("symbol '%s' is an invalid type", tyname)
end
errorer.assertf(symbol.value, "symbol '%s' is a type with unknown value", tyname)
value = symbol.value
end
attr.type = primtypes.type
Expand Down Expand Up @@ -856,47 +858,48 @@ local function visitor_Call(context, node, argnodes, calleetype, methodcalleenod
node:raisef("in call of function '%s': expected at most %d arguments but got %d",
calleetype:prettyname(), #pseudoargtypes, #argnodes)
end
local argtypes = {}
local args = {}
local knownallargs = true
for i,funcargtype,argnode,argtype in izipargnodes(pseudoargtypes, argnodes) do
local arg
if argnode then
argnode.desiredtype = funcargtype
context:traverse(argnode)
argtype = argnode.attr.type
if argtype then
arg = argnode.attr
end
else
arg = argtype
end
if argtype and argtype:is_nil() and not funcargtype:is_nilable() then
node:raisef("in call of function '%s': expected an argument at index %d but got nothing",
calleetype:prettyname(), i)
end
if argtype then
if arg then
if funcargtype then
local ok, err = funcargtype:is_convertible_from(argnode or argtype)
local ok, err = funcargtype:is_convertible_from(arg)
if not ok then
node:raisef("in call of function '%s' at argument %d: %s",
calleetype:prettyname(), i, err)
end
end
argtypes[i] = argtype
args[i] = arg
else
knownallargs = false
end
end
if methodcalleenode then
tabler.insert(argtypes, funcargtypes[1])
tabler.insert(args, funcargtypes[1])
end
if calleetype.lazyfunction then
local lazycalleetype = calleetype
calleetype = nil
if knownallargs then
local lazysym, err = lazycalleetype:eval_lazy_for_argtypes(argtypes)
if err then --luacov:disable
--TODO: actually this error is impossible because of the previous check
node:raisef("in call of function '%s': %s", lazycalleetype:prettyname(), err)
end --luacov:enable

if traits.is_attr(lazysym) and lazysym.type then
calleetype = lazysym.type
attr.lazysym = lazysym
local lazyeval = lazycalleetype:eval_lazy_for_args(args)
if lazyeval and lazyeval.node and lazyeval.node.attr.type then
attr.lazyeval = lazyeval
calleetype = lazyeval.node.attr.type
else
lazycalleetype.node.attr.delayresolution = true
end
Expand Down Expand Up @@ -1332,18 +1335,20 @@ function visitors.Return(context, node)
end
end

local function resolve_function_argtypes(symbol, varnode, argnodes, scope)
local function resolve_function_argtypes(symbol, varnode, argnodes, scope, checklazy)
local islazyparent = false
local argattrs = {}
local argtypes = {}

for i,argnode in ipairs(argnodes) do
local argattr = argnode.attr
-- function arguments types must be known ahead, fallbacks to any if untyped
local argtype = argattr.type or primtypes.any
if argtype.lazyable or argattr.comptime then
if checklazy and (argtype.lazyable or argattr.comptime) then
islazyparent = true
end
argtypes[i] = argtype
argattrs[i] = argattr
end

if varnode.tag == 'ColonIndex' and symbol and symbol.metafunc then
Expand All @@ -1358,7 +1363,7 @@ local function resolve_function_argtypes(symbol, varnode, argnodes, scope)
end
end

return argtypes, islazyparent
return argattrs, argtypes, islazyparent
end

local function block_endswith_return(blocknode)
Expand Down Expand Up @@ -1456,11 +1461,11 @@ function visitors.FuncDef(context, node, lazysymbol)
local returntypes = visitor_FuncDef_returns(context, node.attr.type, retnodes)

-- repeat scope to resolve function variables and return types
local islazyparent, argtypes
local islazyparent, argtypes, argattrs
local funcscope = context:repeat_scope_until_resolution('function', function(scope)
scope.returntypes = returntypes
context:traverse(argnodes)
argtypes, islazyparent = resolve_function_argtypes(symbol, varnode, argnodes, scope)
argattrs, argtypes, islazyparent = resolve_function_argtypes(symbol, varnode, argnodes, scope, not lazysymbol)

if not islazyparent then
-- lazy functions never traverse the blocknode by itself
Expand All @@ -1477,7 +1482,7 @@ function visitors.FuncDef(context, node, lazysymbol)
if islazyparent then
assert(not lazysymbol)
if not type then
type = types.LazyFunctionType(node, argtypes, returntypes)
type = types.LazyFunctionType(node, argattrs, returntypes)
end
elseif not returntypes.has_unknown then
type = types.FunctionType(node, argtypes, returntypes)
Expand Down Expand Up @@ -1539,27 +1544,25 @@ function visitors.FuncDef(context, node, lazysymbol)

-- traverse lazy function nodes
if islazyparent then
for i,lazy in ipairs(node.lazys) do
local lazysym, lazyargtypes, lazynode
if traits.is_attr(lazy) then
lazysym = lazy
else
lazyargtypes = lazy
end
if not lazysym then
for _,lazyeval in ipairs(type.evals) do
local lazynode = lazyeval.node
if not lazynode then
lazynode = node:clone()
lazynode.attr.lazynode = lazynode
lazynode.attr.lazyargtypes = lazyargtypes
lazyeval.node = lazynode
local lazyargnodes = lazynode[3]
for j,lazyargtype in ipairs(lazyargtypes) do
lazyargnodes[j].attr.type = lazyargtype
for j,lazyarg in ipairs(lazyeval.args) do
local lazyargattr = lazyargnodes[j].attr
if traits.is_attr(lazyarg) then
lazyargattr.type = lazyarg.type
lazyargattr.value = lazyarg.value
else
lazyargattr.type = lazyarg
end
assert(traits.is_type(lazyargattr.type))
end
else
lazynode = lazysym.lazynode
end
context:traverse(lazynode, symbol)
assert(lazynode.attr._symbol)
node.lazys[i] = lazynode.attr
assert(traits.is_symbol(lazynode.attr))
end
end
end
Expand Down
54 changes: 28 additions & 26 deletions nelua/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ function Type:is_string() return self.string end
function Type:is_cstring() return self.cstring end
function Type:is_record() return self.record end
function Type:is_function() return self.Function end
function Type:is_lazyfunction() return self.lazyfunction end
function Type:is_boolean() return self.boolean end
function Type:is_table() return self.table end
function Type:is_array() return self.array end
Expand Down Expand Up @@ -1038,41 +1039,42 @@ types.LazyFunctionType = LazyFunctionType
LazyFunctionType.Function = true
LazyFunctionType.lazyfunction = true

function LazyFunctionType:_init(node, argtypes, returntypes)
function LazyFunctionType:_init(node, args, returntypes)
Type._init(self, 'lazyfunction', 0, node)
if not node.lazys then
node.lazys = {}
end
self.argtypes = argtypes or {}
self.args = args or {}
self.argtypes = tabler.imap(self.args, function(arg) return arg.type end)
self.returntypes = returntypes or {}
self.evals = {}
self.codename = gencodename(self)
end

function LazyFunctionType:get_lazy(argtypes)
for _,lazy in ipairs(self.node.lazys) do
local lazyargtypes
if traits.is_attr(lazy) then
lazyargtypes = lazy.lazyargtypes
else
lazyargtypes = lazy
local function lazy_args_matches(largs, rargs)
for _,larg,rarg in iters.izip(largs, rargs) do
--TODO: if traits.is_attr(larg) and traits.is_attr(rargs)
local ltype = traits.is_attr(larg) and larg.type or larg
local rtype = traits.is_attr(rarg) and rarg.type or rarg
if ltype ~= rtype then
return false
end
if tabler.deepcompare(lazyargtypes, argtypes) then
return lazy
end
return true
end

function LazyFunctionType:get_lazy_eval(args)
for _,lazyeval in ipairs(self.evals) do
if lazy_args_matches(lazyeval.args, args) then
return lazyeval
end
end
end

function LazyFunctionType:eval_lazy_for_argtypes(argtypes)
local lazy = self:get_lazy(argtypes)
if not lazy then
local ok, err = types.are_types_convertible(argtypes, self.argtypes)
if not ok then --luacov:disable
return nil, 'in lazy function evaluation: ' .. err
end --luacov:enable
lazy = argtypes
table.insert(self.node.lazys, lazy)
function LazyFunctionType:eval_lazy_for_args(args)
local lazyeval = self:get_lazy_eval(args)
if not lazyeval then
lazyeval = { args = args }
table.insert(self.evals, lazyeval)
end
return lazy
return lazyeval
end

LazyFunctionType.is_equal = FunctionType.is_equal
Expand Down Expand Up @@ -1347,8 +1349,8 @@ end

--TODO: refactor to use this function
--luacov:disable
function types.are_types_convertible(atypes, btypes)
for i,atype,btype in iters.izip(atypes, btypes) do
function types.are_types_convertible(largs, rargs)
for i,atype,btype in iters.izip(largs, rargs) do
if atype and btype then
local ok, err = btype:is_convertible_from(atype)
if not ok then
Expand Down
4 changes: 4 additions & 0 deletions nelua/utils/traits.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ function traits.is_attr(v)
return type(v) == 'table' and v._attr
end

function traits.is_symbol(v)
return type(v) == 'table' and v._symbol
end

function traits.is_type(v)
return type(v) == 'table' and v._type
end
Expand Down
13 changes: 12 additions & 1 deletion spec/03-typechecker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,18 @@ it("lazy function definition", function()
local a = 1
f(a)
]])

assert.analyze_ast([[
local function f(T: type): integer
return 1
end
f(@number)
]])
assert.analyze_ast([[
local function cast(T: type, value: auto)
return (@T)(value)
end
local a = cast(@number, 1)
]])
end)

it("function return", function()
Expand Down
Loading

0 comments on commit 7a86315

Please sign in to comment.