Skip to content

Commit

Permalink
Introduce <autocast> annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Feb 9, 2020
1 parent 06eaf86 commit 0e240cd
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/linkedlist.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ local allocator = @generic_allocator
self.tail = nilptr
end

function List:len(): integer
local count = 0
function List:len(): isize
local count: isize = 0
local it = self.head
while it do
count = count + 1
Expand Down
76 changes: 45 additions & 31 deletions lib/myarraytable.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ local allocator = @gc_allocator
impl: ArrayTableImpl*
}

function ArrayTable:_create_impl() <inline>
function ArrayTable:_init() <inline>
if unlikely(not self.impl) then
self.impl = (@ArrayTableImpl*)(allocator.alloc0(#ArrayTableImpl))
end
end

function ArrayTable:_grow()
function ArrayTable:_grow() <noinline>
local cap: usize
if self.impl.data.size ~= 0 then
cap = self.impl.data.size * 2
Expand All @@ -35,16 +35,16 @@ local allocator = @gc_allocator
self.impl.data = allocator.spanrealloc(self.impl.data, cap)
end

function ArrayTable:reserve(n: usize)
self:_create_impl()
function ArrayTable:reserve(n: usize <autocast>) <noinline>
self:_init()
local cap = n + 1
if self.impl.data.size < cap then
self.impl.data = allocator.spanrealloc(self.impl.data, cap)
end
end

function ArrayTable:resize(n: usize, v: T)
self:_create_impl()
function ArrayTable:resize(n: usize <autocast>, v: T)
self:_init()
local addn = n - self.impl.size
if addn > 0 then
self:reserve(n)
Expand All @@ -55,24 +55,38 @@ local allocator = @gc_allocator
end
end

function ArrayTable:push(v: T)
self:_create_impl()
function ArrayTable:clear()
if likely(self.impl) then
self.impl.size = 0
end
end

function ArrayTable:reset()
if likely(self.impl) then
allocator.spandealloc(self.impl.data)
allocator.dealloc(self.impl)
self.impl = nilptr
end
end

function ArrayTable:push(v: T) <inline>
self:_init()
self.impl.size = self.impl.size + 1
if unlikely(self.impl.size + 1 >= self.impl.data.size) then
self:_grow()
end
self.impl.data[self.impl.size] = v
end

function ArrayTable:pop(): T
check(self.impl and self.impl.size > 0, 'arraytable: length is 0')
function ArrayTable:pop(): T <inline>
check(self.impl and self.impl.size > 0, 'arraytable.pop: length is 0')
local i = self.impl.size
self.impl.size = self.impl.size - 1
return self.impl.data[i]
end

function ArrayTable:at(i: usize): T*
self:_create_impl()
function ArrayTable:at(i: usize <autocast>): T* <inline>
self:_init()
if unlikely(i > self.impl.size) then
check(i == self.impl.size + 1, 'arraytable.at: index out of range')
self.impl.size = self.impl.size + 1
Expand All @@ -84,33 +98,33 @@ local allocator = @gc_allocator
return &self.impl.data[i]
end

function ArrayTable:get(i: usize): T*
check(self.impl and i <= self.impl.size, 'arraytable.get: index out of range')
if unlikely(i == 0 and self.impl.data.size == 0) then
self:_grow()
function ArrayTable:set(i: usize <autocast>, v: T) <inline>
self:_init()
if unlikely(i > self.impl.size) then
check(i == self.impl.size + 1, 'arraytable.at: index out of range')
self.impl.size = self.impl.size + 1
end
return &self.impl.data[i]
end

function ArrayTable:length(): usize
if unlikely(not self.impl) then
return 0
if unlikely(self.impl.size + 1 > self.impl.data.size) then
self:_grow()
end
return self.impl.size
self.impl.data[i] = v
end

function ArrayTable:clear()
if likely(self.impl) then
self.impl.size = 0
function ArrayTable:get(i: usize <autocast>): T <inline>
if unlikely(i == 0 and (not self.impl or self.impl.data.size == 0)) then
local v: T
return v
else
check(self.impl and i <= self.impl.size, 'arraytable.get: index out of range')
return self.impl.data[i]
end
end

function ArrayTable:reset()
if likely(self.impl) then
allocator.spandealloc(self.impl.data)
allocator.dealloc(self.impl)
self.impl = nilptr
function ArrayTable:len(): isize <inline>
if unlikely(not self.impl) then
return 0
end
return (@isize)(self.impl.size)
end

## return ArrayTable
Expand Down
35 changes: 25 additions & 10 deletions nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ function visitors.FuncType(context, node)
context:traverse_nodes(argnodes)
context:traverse_nodes(retnodes)
local type = types.FunctionType(node,
tabler.imap(argnodes, function(argnode) return argnode.attr.value end),
tabler.imap(argnodes, function(argnode) return Attr{type = argnode.attr.value} end),
tabler.imap(retnodes, function(retnode) return retnode.attr.value end))
attr.type = primtypes.type
attr.value = type
Expand Down Expand Up @@ -872,25 +872,34 @@ local function visitor_Call(context, node, argnodes, calleetype, calleesym, call
if calleetype:is_function() then
-- function call
local funcargtypes = calleetype.argtypes
local funcargattrs = calleetype.argattrs or calleetype.args
local pseudoargtypes = funcargtypes
local pseudoargattrs = funcargattrs
if calleeobjnode then
pseudoargtypes = tabler.copy(funcargtypes)
pseudoargattrs = tabler.copy(funcargattrs)
local ok, err = funcargtypes[1]:is_convertible_from(calleeobjnode)
if not ok then
node:raisef("in call of function '%s' at argument %d: %s",
calleetype:prettyname(), 1, err)
end
table.remove(pseudoargtypes, 1)
table.remove(pseudoargattrs, 1)
attr.pseudoargtypes = pseudoargtypes
attr.pseudoargattrs = pseudoargtypes
end
if #argnodes > #pseudoargtypes then
if #argnodes > #pseudoargattrs then
node:raisef("in call of function '%s': expected at most %d arguments but got %d",
calleetype:prettyname(), #pseudoargtypes, #argnodes)
calleetype:prettyname(), #pseudoargattrs, #argnodes)
end
local lazyargs = {}
local knownallargs = true
for i,funcargtype,argnode,argtype in izipargnodes(pseudoargtypes, argnodes) do
for i,funcarg,argnode,argtype in izipargnodes(pseudoargattrs, argnodes) do
local arg
local funcargtype
if traits.is_type(funcarg) then funcargtype = funcarg else
funcargtype = funcarg.type
end
if argnode then
argnode.desiredtype = funcargtype
context:traverse_node(argnode)
Expand All @@ -916,7 +925,7 @@ local function visitor_Call(context, node, argnodes, calleetype, calleesym, call
end
if arg then
if funcargtype then
local ok, err = funcargtype:is_convertible_from(arg)
local ok, err = funcargtype:is_convertible_from(arg, traits.is_attr(funcarg) and funcarg.autocast)
if not ok then
node:raisef("in call of function '%s' at argument %d: %s",
calleetype:prettyname(), i, err)
Expand Down Expand Up @@ -1335,7 +1344,7 @@ function visitors.VarDecl(context, node)
if valnode and vartype:is_initializable_from_attr(valnode.attr) then
valnode.attr.initializer = true
end
local ok, err = vartype:is_convertible_from(valnode or valtype)
local ok, err = vartype:is_convertible_from(valnode or valtype, varnode.attr.autocast)
if not ok then
varnode:raisef("in variable '%s' declaration: %s", symbol.name, err)
end
Expand All @@ -1355,8 +1364,9 @@ function visitors.Assign(context, node)
for i,varnode,valnode,valtype in izipargnodes(varnodes, valnodes) do
local symbol = context:traverse_node(varnode)
local vartype = varnode.attr.type
local varattr = varnode.attr
varnode.assign = true
if varnode.attr.const or varnode.attr.comptime then
if varattr.const or varattr.comptime then
varnode:raisef("cannot assign a constant variable")
end
if valnode then
Expand All @@ -1381,7 +1391,7 @@ function visitors.Assign(context, node)
varnode:raisef("variable assignment at index '%d' is assigning to nothing in the expression", i)
end
if vartype and valtype then
local ok, err = vartype:is_convertible_from(valnode or valtype)
local ok, err = vartype:is_convertible_from(valnode or valtype, varattr.autocast)
if not ok then
varnode:raisef("in variable assignment: %s", err)
end
Expand Down Expand Up @@ -1425,15 +1435,20 @@ local function resolve_function_argtypes(symbol, varnode, argnodes, scope, check

for i,argnode in ipairs(argnodes) do
local argattr = argnode.attr
local argtype = argattr.type
if not argtype then
-- function arguments types must be known ahead, fallbacks to any if untyped
local argtype = argattr.type or primtypes.any
argtype = primtypes.any
argattr.type = argtype
end
if checklazy and argtype.lazyable then
islazyparent = true
end
argtypes[i] = argtype
argattrs[i] = argattr
end


if varnode.tag == 'ColonIndex' and symbol and symbol.metafunc then
-- inject 'self' type as first argument
local selfsym = symbol.selfsym
Expand Down Expand Up @@ -1583,7 +1598,7 @@ function visitors.FuncDef(context, node, lazysymbol)
type = types.LazyFunctionType(node, argattrs, returntypes)
end
elseif not returntypes.has_unknown then
type = types.FunctionType(node, argtypes, returntypes)
type = types.FunctionType(node, argattrs, returntypes)
end

if symbol then -- symbol may be nil in case of array/dot index
Expand Down
4 changes: 4 additions & 0 deletions nelua/cemitter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ function CEmitter:add_numeric_literal(valattr, valtype)
valtype = valtype or valattr.type
local val, base = valattr.value, valattr.base

if val:isneg() and valtype:is_unsigned() then
val = valtype:normalize_value(val)
end

local minusone = false
if valtype:is_integral() and valtype:is_signed() and val == valtype.min then
-- workaround C warning `integer constant is so large that it is unsigned`
Expand Down
6 changes: 4 additions & 2 deletions nelua/symdefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ local typedefs = require 'nelua.typedefs'
local tabler = require 'nelua.utils.tabler'
local types = require 'nelua.types'
local Symbol = require 'nelua.symbol'
local Attr = require 'nelua.attr'
local primtypes = typedefs.primtypes

local symdefs = {}

local function define_function(name, args, rets, props)
local type = types.FunctionType(nil, args, rets)
local function define_function(name, argtypes, rettypes, props)
local args = tabler.imap(argtypes, function(argtype) return Attr{type = argtype} end)
local type = types.FunctionType(nil, args, rettypes)
type:suggest_nick(name)
type.sideeffect = false
local symbol = Symbol{
Expand Down
3 changes: 2 additions & 1 deletion nelua/typedefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ typedefs.variable_annots = {
noinit = true,
cexport = true,
comptime = true,
const = true
const = true,
autocast = true,
}

typedefs.type_annots = {
Expand Down
5 changes: 3 additions & 2 deletions nelua/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1002,9 +1002,10 @@ local FunctionType = typeclass()
types.FunctionType = FunctionType
FunctionType.Function = true

function FunctionType:_init(node, argtypes, returntypes)
function FunctionType:_init(node, argattrs, returntypes)
Type._init(self, 'function', cpusize, node)
self.argtypes = argtypes or {}
self.argattrs = argattrs or {}
self.argtypes = tabler.imap(self.argattrs, function(arg) return arg.type end)
self.returntypes = returntypes or {}
self.codename = gencodename(self)
end
Expand Down
36 changes: 36 additions & 0 deletions spec/03-typechecker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,42 @@ it("narrow casting", function()
local i32 = (@int32) (-0x8000000000000000)
local i64 = (@int64) (-0x8000000000000000)
]])

assert.analyze_ast([[
local i: integer <autocast>
local u: uinteger
i = u
]])
assert.analyze_error([[
local i: integer
local u: uinteger
i = u
]], "no viable type conversion")

assert.analyze_ast([[
local u: uinteger
local i: integer <autocast> = u
]])
assert.analyze_error([[
local u: uinteger
local i: integer = u
]], "no viable type conversion")

assert.analyze_ast([[
local function f(u: uinteger <autocast>)
return u
end
local i: integer
f(i)
]])
assert.analyze_error([[
local function f(u: uinteger)
return u
end
local i: integer
f(i)
]], "no viable type conversion")

end)

it("numeric ranges", function()
Expand Down
33 changes: 33 additions & 0 deletions spec/05-cgenerator_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,39 @@ it("automatic dereference", function()
]])
end)

it("automatic casting", function()
assert.generate_c([[
local a = (@uint8)(-1)
local b: uint8 <autocast> = -1
]], {"a = 255U", "b = 255U"})
assert.run_c([[
do
local i8: int8 <autocast>
local u8: uint8 = 255
i8 = u8
assert(i8 == -1)
end
do
local i8: int8 = -1
local u8: uint8 <autocast>
u8 = i8
assert(u8 == 255)
end
local function f(x: uint8 <autocast>)
return x
end
local function g(x: int8 <autocast>)
return x
end
local i: int8 = -1
local u: uint8 = 255
assert(f(i) == 255)
assert(g(u) == -1)
]])
end)

it("nilptr", function()
assert.generate_c("local p: pointer = nilptr", "void* p = NULL")
assert.run_c([[
Expand Down
Loading

0 comments on commit 0e240cd

Please sign in to comment.