Skip to content

Commit

Permalink
Metamethod __convert for records
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Feb 13, 2020
1 parent a3e3ca2 commit 450f1c1
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 34 deletions.
49 changes: 28 additions & 21 deletions lib/myarraytable.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,14 @@
## strict = true
## unitname = 'nelua'

## local memoize = require 'nelua.utils.memoize'
## myarraytable = hygienize(memoize(function(T, allocator)
## if allocator then
local allocator = #[allocator]#
local allocator: type = #[allocator]#
## else
require 'allocators.gc_allocator'
local allocator = @gc_allocator
local allocator: type = @gc_allocator
## end

local an_array_of_T = #[concept(function(x)
if x.type:is_table() then
return types.ArrayType(nil, T, #x.node[1])
elseif x.type:is_array_of(T) then
return true
end
end)]#

local T = @#[T]#
local ArrayTableImplT <codename #['nelua_ArrayTableImpl_'..T.name]#> = @record {
size: usize,
Expand All @@ -28,6 +19,17 @@
local ArrayTableT <codename #['nelua_ArrayTable_'..T.name]#> = @record{
impl: ArrayTableImplT*
}
##[[
ArrayTableT.value.choose_braces_type = function(node)
return types.ArrayType(nil, T, #node[1])
end
]]

local convertible_concept: type = #[concept(function(x)
if x.type:is_array_of(T) then
return true
end
end)]#

function ArrayTableT:init() <inline>
if unlikely(not self.impl) then
Expand Down Expand Up @@ -61,16 +63,15 @@

function ArrayTableT:reserve(n: usize <autocast>)
self:init()
local cap = n + 1
local cap: usize = n + 1
if self.impl.data.size < cap then
self.impl.data = allocator.spanrealloc(self.impl.data, cap)
end
end

function ArrayTableT:resize(n: usize <autocast>, v: T)
self:init()
local addn = n - self.impl.size
if addn > 0 then
if n > self.impl.size then
self:reserve(n)
for i=self.impl.size+1,<n do
self.impl.data[i+1] = v
Expand Down Expand Up @@ -99,17 +100,11 @@

function ArrayTableT:pop(): T <inline>
check(self.impl and self.impl.size > 0, 'arraytable.pop: length is 0')
local i = self.impl.size
local i: usize = self.impl.size
self.impl.size = self.impl.size - 1
return self.impl.data[i]
end

function ArrayTableT:__assign(values: an_array_of_T)
self:reserve(#values)
memory.zero(&self.impl.data[0], #T)
memory.copy(&self.impl.data[1], &values[0], #values * #T)
end

function ArrayTableT:__atindex(i: usize <autocast>): T* <inline>
self:init()
if unlikely(i > self.impl.size) then
Expand All @@ -130,5 +125,17 @@
return (@isize)(self.impl.size)
end

function ArrayTableT.__convert(values: convertible_concept): ArrayTableT <inline>
local self: ArrayTableT
local len: usize = (@usize)(#values)
self:reserve(len)
self.impl.size = len
memory.zero(&self.impl.data[0], #T)
for i:usize=1,len do
self.impl.data[i] = values[i-1]
end
return self
end

## return ArrayTableT
## end))
49 changes: 44 additions & 5 deletions nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ function visitors.Number(context, node)
attr.nofloatsuffix = true
end
attr.base = base
attr.literal = true
attr.comptime = true
end

Expand All @@ -59,6 +60,7 @@ function visitors.String(_, node)
end
attr.type = primtypes.string
attr.value = value
attr.literal = true
attr.comptime = true
end

Expand All @@ -69,13 +71,15 @@ function visitors.Boolean(_, node)
attr.value = value
attr.type = primtypes.boolean
attr.comptime = true
attr.literal = true
end

function visitors.Nil(_, node)
local attr = node.attr
if attr.type then return end
attr.type = primtypes.nilable
attr.comptime = true
attr.literal = true
end

local function visitor_ArrayTable_literal(context, node, littype)
Expand Down Expand Up @@ -192,6 +196,10 @@ end

function visitors.Table(context, node)
local desiredtype = node.desiredtype
node.attr.literal = true
if desiredtype and desiredtype:is_record() and desiredtype.choose_braces_type then
desiredtype = desiredtype.choose_braces_type(node)
end
if not desiredtype or (desiredtype:is_table() or desiredtype.lazyable) then
visitor_Table_literal(context, node)
elseif desiredtype:is_arraytable() then
Expand All @@ -201,7 +209,6 @@ function visitors.Table(context, node)
elseif desiredtype:is_record() then
visitor_Record_literal(context, node, desiredtype)
else
-- concept will be traversed again later
node:raisef("type '%s' cannot be initialized using a table literal", desiredtype:prettyname())
end
end
Expand Down Expand Up @@ -411,6 +418,7 @@ function visitors.TypeInstance(context, node, symbol)
node.attr = attr
if symbol then
attr.value:suggest_nick(symbol.name, symbol.staticstorage and symbol.codename)
attr.value.symbol = symbol
end
end

Expand Down Expand Up @@ -449,6 +457,7 @@ function visitors.RecordType(context, node, symbol)
symbol.type = primtypes.type
symbol.value = recordtype
recordtype:suggest_nick(symbol.name, symbol.staticstorage and symbol.codename)
recordtype.symbol = symbol
end
local fieldnodes = node[1]
context:traverse_nodes(fieldnodes, recordtype)
Expand Down Expand Up @@ -711,6 +720,31 @@ local function visitor_Call_typeassertion(context, node, argnodes, type)
attr.calleetype = primtypes.type
end

local function visitor_convert(context, parent, parentindex, vartype, valnode, valtype)
if not (valtype and vartype and vartype:is_user_record() and vartype ~= valtype) then
-- convert cannot be overridden
return valnode, valtype
end
if valtype:is_pointer_of(vartype) or vartype:is_pointer_of(valtype) then
-- ignore automatic deref/ref
return valnode, valtype
end
local mtsym = vartype:get_metafield('__convert')
if not mtsym then
return valnode, valtype
end
local n = context.parser.astbuilder.aster
assert(vartype.symbol)
local idnode = n.Id{vartype.symbol.name}
local pattr = Attr{foreignsymbol=vartype.symbol}
idnode.attr:merge(pattr)
idnode.pattr = pattr
local newvalnode = n.Call{{valnode}, n.DotIndex{'__convert', idnode}}
parent[parentindex] = newvalnode
context:traverse_node(newvalnode)
return newvalnode, newvalnode.attr.type
end

local function visitor_Call(context, node, argnodes, calleetype, calleesym, calleeobjnode)
local attr = node.attr
if calleetype then
Expand Down Expand Up @@ -749,6 +783,7 @@ local function visitor_Call(context, node, argnodes, calleetype, calleesym, call
argnode.desiredtype = argnode.desiredtype or funcargtype
context:traverse_node(argnode)
argtype = argnode.attr.type
argnode, argtype = visitor_convert(context, argnodes, i, funcargtype, argnode, argtype)
if argtype then
arg = argnode.attr
end
Expand Down Expand Up @@ -1301,7 +1336,7 @@ function visitors.VarDecl(context, node)
node:raisef("extra expressions in declaration, expected at most %d but got %d",
#varnodes, #valnodes)
end
for _,varnode,valnode,valtype in izipargnodes(varnodes, valnodes) do
for i,varnode,valnode,valtype in izipargnodes(varnodes, valnodes) do
assert(varnode.tag == 'IdDecl')
varnode.attr.vardecl = true
if varscope == 'global' then
Expand All @@ -1328,9 +1363,11 @@ function visitors.VarDecl(context, node)
varnode:raisef("const variables must have an initial value")
end
if valnode then
valnode.desiredtype = vartype
valnode.desiredtype = valnode.desiredtype or vartype
context:traverse_node(valnode, symbol)
valtype = valnode.attr.type
valnode, valtype = visitor_convert(context, valnodes, i, vartype, valnode, valtype)

if valtype then
if valtype:is_varanys() then
-- varanys are always stored as any in variables
Expand All @@ -1343,7 +1380,7 @@ function visitors.VarDecl(context, node)
if varnode.attr.comptime and not (valnode.attr.comptime and valtype) then
varnode:raisef("compile time variables can only assign to compile time expressions")
elseif vartype and not valtype and vartype:is_auto() then
valnode:raisef("auto variables must be assigned to expressions where type is known ahead")
varnode:raisef("auto variables must be assigned to expressions where type is known ahead")
elseif varnode.attr.cimport and not
(vartype == primtypes.type or (vartype == nil and valtype == primtypes.type)) then
varnode:raisef("cannot assign imported variables, only imported types can be assigned")
Expand Down Expand Up @@ -1417,6 +1454,7 @@ function visitors.Assign(context, node)
valnode.desiredtype = vartype
context:traverse_node(valnode)
valtype = valnode.attr.type
valnode, valtype = visitor_convert(context, valnodes, i, vartype, valnode, valtype)
end
if valtype then
if valtype:is_void() then
Expand All @@ -1435,7 +1473,8 @@ function visitors.Assign(context, node)
varnode:raisef("variable assignment at index '%d' is assigning to nothing in the expression", i)
end
if vartype and valtype then
local ok, err = vartype:is_convertible_from(valnode or valtype, varattr.autocast)
local from = valnode or valtype
local ok, err = vartype:is_convertible_from(from, varattr.autocast)
if not ok then
varnode:raisef("in variable assignment: %s", err)
end
Expand Down
4 changes: 3 additions & 1 deletion nelua/preprocessor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ local except = require 'nelua.utils.except'
local errorer = require 'nelua.utils.errorer'
local config = require 'nelua.configer'.get()
local stringer = require 'nelua.utils.stringer'
local memoize = require 'nelua.utils.memoize'

local traverse_node = VisitorContext.traverse_node
local function pp_default_visitor(self, node, emitter, ...)
Expand Down Expand Up @@ -198,7 +199,8 @@ function preprocessor.preprocess(context, ast)
local type = types.ConceptType(f)
type.node = context:get_current_node()
return type
end
end,
memoize = memoize
})
setmetatable(ppenv, { __index = function(_, key)
local v = rawget(ppcontext.context.env, key)
Expand Down
18 changes: 15 additions & 3 deletions spec/03-typechecker_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,24 @@ it("analyzed ast transform", function()
n.VarDecl{'local',
{ n.IdDecl{
assign=true,
attr = {codename='a', name='a', staticstorage=true, type='int64', vardecl=true, lvalue=true},
attr = {
codename='a',
name='a',
staticstorage=true,
type='int64',
vardecl=true,
lvalue=true
},
'a' }},
{ n.Number{
attr = {
comptime=true, initializer=true,
base='dec', type='int64', untyped=true, value=bn.fromdec('1')
comptime=true,
initializer=true,
literal=true,
base='dec',
type='int64',
untyped=true,
value=bn.fromdec('1')
},'dec', '1'
}}
}
Expand Down
48 changes: 44 additions & 4 deletions spec/05-cgenerator_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1257,14 +1257,54 @@ it("record metametods", function()
function intarray:__len(): isize <inline>
return #self.data
end
local a: intarray
assert(a[0] == 0 and a[1] == 0)
a[0] = 1 a[1] = 2
assert(a:__atindex(0) == &a.data[0])
assert(a[0] == 1 and a[1] == 2)
assert(#a == 100)
assert(a:__len() == 100)
local R = @record {
x: integer
}
function R.__convert(x: integer): R
local self: R
self.x = x
return self
end
local r: R = 1
assert(r.x == 1)
r = R.__convert(2)
assert(r.x == 2)
r = 3
assert(r.x == 3)
local function f()
local r: R = 1
assert(r.x == 1)
r = R.__convert(2)
assert(r.x == 2)
r = 3
assert(r.x == 3)
end
f()
local function g(r: R)
return r.x
end
assert(g(r) == 3)
assert(g(4) == 4)
local R = @record {
x: integer[2]
}
## R.value.choose_braces_type = function() return types.ArrayType(nil, integer, 2) end
function R.__convert(x: auto): R
local self: R
self.x = x
return self
end
local r: R = {1,2}
assert(r.x[0] == 1 and r.x[1] == 2)
]])
end)

Expand Down Expand Up @@ -1908,13 +1948,13 @@ it("concepts", function()
local R = @record {
x: integer
}
function R:__assign(x: an_arithmetic)
function R:__convert(x: an_arithmetic)
self.x = x
end
local r: R
R.__assign(&r, 1)
R.__convert(&r, 1)
assert(r.x == 1)
r:__assign(2)
r:__convert(2)
assert(r.x == 2)
]=])
end)
Expand Down

0 comments on commit 450f1c1

Please sign in to comment.