Skip to content

Commit

Permalink
Generics system!
Browse files Browse the repository at this point in the history
  • Loading branch information
edubart committed Feb 13, 2020
1 parent 450f1c1 commit 5504266
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 39 deletions.
5 changes: 4 additions & 1 deletion lib/myarraytable.nelua
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
## strict = true
## unitname = 'nelua'

## myarraytable = hygienize(memoize(function(T, allocator)
## local myarraytable = hygienize(memoize(function(T, allocator)
## staticassert(traits.is_type(T), "invalid type '%s'", T)
## if allocator then
local allocator: type = #[allocator]#
## else
Expand Down Expand Up @@ -139,3 +140,5 @@

## return ArrayTableT
## end))

global myarraytable = #[generic(myarraytable)]#
44 changes: 42 additions & 2 deletions nelua/analyzer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,6 @@ function visitors.Type(context, node)
attr.value = value
end


function visitors.TypeInstance(context, node, symbol)
local typenode = node[1]
if node.attr.type then return end
Expand Down Expand Up @@ -594,6 +593,42 @@ function visitors.PointerType(context, node)
attr.type = primtypes.type
end

function visitors.GenericType(context, node)
local attr = node.attr
local name, argnodes = node[1], node[2]
if attr.type then return end
local symbol = context.scope:get_symbol(name)
if not symbol or not symbol.type or not symbol.type:is_type() or symbol.type:is_generic() then
node:raisef("symbol '%s' doesn't hold a generic type", name)
end
local params = {}
for i=1,#argnodes do
local argnode = argnodes[i]
context:traverse_node(argnode)
local argattr = argnode.attr
if not (argattr.comptime or argattr.type:is_comptime()) then
node:raisef("in generic '%s': argument #%d isn't a compile time value", name, i)
end
local value = argattr.value
if traits.is_bignumber(value) then
value = value:tonumber()
elseif not (traits.is_type(value) or
traits.is_string(value) or
traits.is_boolean(value) or
traits.is_bignumber(value)) then
node:raisef("in generic '%s': argument #%d of type '%s' is invalid for generics",
name, i, argattr.type:prettyname())
end
params[i] = value
end
local type, err = symbol.value:eval_type(params)
if err then
node:raisef(err)
end
attr.type = primtypes.type
attr.value = symbol.value:eval_type(params)
end

local function iargnodes(argnodes)
local i = 0
local lastargindex = #argnodes
Expand Down Expand Up @@ -694,8 +729,11 @@ end
local function visitor_Call_typeassertion(context, node, argnodes, type)
local attr = node.attr
assert(type)
if type:is_generic() then
node:raisef("assertion to generic '%s': cannot do assertion on generics", type:prettyname())
end
if #argnodes ~= 1 then
node:raisef("assertion to type '%s' expected one argument, but got %d",
node:raisef("assertion to type '%s': expected one argument, but got %d",
type:prettyname(), #argnodes)
end
local argnode = argnodes[1]
Expand Down Expand Up @@ -1405,6 +1443,8 @@ function visitors.VarDecl(context, node)
assert(valnode and valnode.attr.value)
assignvaltype = vartype ~= valtype
symbol.value = valnode.attr.value
symbol.value:suggest_nick(symbol.name, symbol.staticstorage and symbol.codename)
symbol.value.symbol = symbol
end

if vartype and vartype:is_auto() then
Expand Down
4 changes: 4 additions & 0 deletions nelua/astdefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ astbuilder:register('SpanType', {
astbuilder:register('RangeType', {
ntypes.Node, -- subtype typexpr
})
astbuilder:register('GenericType', {
stypes.string + ntypes.PreprocessName, -- generic name
stypes.array_of(ntypes.Node), -- list of typexpr or param expr
})

-- function
astbuilder:register('Function', {
Expand Down
6 changes: 6 additions & 0 deletions nelua/preprocessor.lua
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ function preprocessor.preprocess(context, ast)
aster = aster,
config = config,
types = types,
traits = traits,
primtypes = primtypes
}
tabler.update(ppenv, {
Expand Down Expand Up @@ -200,6 +201,11 @@ function preprocessor.preprocess(context, ast)
type.node = context:get_current_node()
return type
end,
generic = function(f)
local type = types.GenericType(f)
type.node = context:get_current_node()
return type
end,
memoize = memoize
})
setmetatable(ppenv, { __index = function(_, key)
Expand Down
82 changes: 46 additions & 36 deletions nelua/syntaxdefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,15 @@ local function get_parser(std)
expr_list <- (expr (%COMMA expr)*)?
eexpr_list <- eexpr (%COMMA expr)*
type_or_param_expr <- typexpr / expr
etype_or_param_expr_list <- etype_or_param_expr (%COMMA type_or_param_expr)*
type_param_expr <-
%cNUMBER /
id /
ppexpr /
(%LPAREN eexpr eRPAREN)
annot_list <- %LANGLE {| (eannot_expr (%COMMA eannot_expr)*) |} eRANGLE
eannot_expr <-
Expand All @@ -468,12 +477,13 @@ local function get_parser(std)
span_type /
range_type /
pointer_type /
generic_type /
primtype /
ppexpr
unary_typexpr_op <-
{| {} %MUL -> 'PointerType' |} /
{| {} %LBRACKET -> 'ArrayType' cnil typexpr_param_expr eRBRACKET |}
{| {} %LBRACKET -> 'ArrayType' cnil etype_param_expr eRBRACKET |}
func_type <- (
{} '' -> 'FuncType'
Expand All @@ -484,46 +494,44 @@ local function get_parser(std)
{| (%COLON (%LPAREN etypexpr_list eRPAREN / etypexpr))? |}
) -> to_astnode
typexpr_param_expr <-
%cNUMBER /
id /
ppexpr /
(%LPAREN eexpr eRPAREN) /
%{ExpectedExpression}
record_type <- ({} %TRECORD -> 'RecordType' eLCURLY
record_type <- ({} %TRECORD -> 'RecordType' %LCURLY
{| (record_field (%SEPARATOR record_field)* %SEPARATOR?)? |}
eRCURLY) -> to_astnode
record_field <- ({} '' -> 'RecordFieldType'
name eCOLON etypexpr
) -> to_astnode
enum_type <- ({} %TENUM -> 'EnumType'
((%LPAREN eprimtype eRPAREN) / cnil) eLCURLY
((%LPAREN eprimtype eRPAREN) / cnil) %LCURLY
{| eenumfield (%SEPARATOR enumfield)* %SEPARATOR? |}
eRCURLY) -> to_astnode
enumfield <- ({} '' -> 'EnumFieldType'
name (%ASSIGN eexpr)?
) -> to_astnode
arraytable_type <- (
{} 'arraytable' -> 'ArrayTableType'
eLPAREN etypexpr eRPAREN
%LPAREN etypexpr eRPAREN
) -> to_astnode
span_type <- (
{} 'span' -> 'SpanType'
eLPAREN etypexpr eRPAREN
%LPAREN etypexpr eRPAREN
) -> to_astnode
range_type <- (
{} 'range' -> 'RangeType'
eLPAREN etypexpr eRPAREN
%LPAREN etypexpr eRPAREN
) -> to_astnode
array_type <- (
{} 'array' -> 'ArrayType'
eLPAREN etypexpr eCOMMA typexpr_param_expr eRPAREN
%LPAREN etypexpr eCOMMA etype_param_expr eRPAREN
) -> to_astnode
pointer_type <- (
{} 'pointer' -> 'PointerType'
((%LPAREN etypexpr eRPAREN) / %SKIP)
) -> to_astnode
generic_type <- (
{} '' -> 'GenericType'
name %LPAREN {| etype_or_param_expr_list |} eRPAREN
) -> to_astnode
primtype <- ({} '' -> 'Type' name) -> to_astnode
ppexpr <- ({} %LPPEXPR -> 'PreprocessExpr' {expr -> 0} eRPPEXPR) -> to_astnode
Expand Down Expand Up @@ -574,28 +582,30 @@ local function get_parser(std)

-- syntax expected captures with errors
grammar:set_pegs([[
eRPAREN <- %RPAREN / %{UnclosedParenthesis}
eRBRACKET <- %RBRACKET / %{UnclosedBracket}
eRPPEXPR <- %RPPEXPR / %{UnclosedBracket}
eRPPNAME <- %RPPNAME / %{UnclosedParenthesis}
eRCURLY <- %RCURLY / %{UnclosedCurly}
eRANGLE <- %RANGLE / %{UnclosedAngle}
eLPAREN <- %LPAREN / %{ExpectedParenthesis}
eLCURLY <- %LCURLY / %{ExpectedCurly}
eLANGLE <- %LANGLE / %{ExpectedAngle}
eLBRACKET <- %LBRACKET / %{ExpectedBracket}
eCOLON <- %COLON / %{ExpectedColon}
eCOMMA <- %COMMA / %{ExpectedComma}
eEND <- %END / %{ExpectedEnd}
eTHEN <- %THEN / %{ExpectedThen}
eUNTIL <- %UNTIL / %{ExpectedUntil}
eDO <- %DO / %{ExpectedDo}
ename <- name / %{ExpectedName}
eexpr <- expr / %{ExpectedExpression}
etypexpr <- typexpr / %{ExpectedTypeExpression}
ecallargs <- callargs / %{ExpectedCall}
eenumfield <- enumfield / %{ExpectedEnumFieldType}
eprimtype <- primtype / %{ExpectedPrimitiveTypeExpression}
eRPAREN <- %RPAREN / %{UnclosedParenthesis}
eRBRACKET <- %RBRACKET / %{UnclosedBracket}
eRPPEXPR <- %RPPEXPR / %{UnclosedBracket}
eRPPNAME <- %RPPNAME / %{UnclosedParenthesis}
eRCURLY <- %RCURLY / %{UnclosedCurly}
eRANGLE <- %RANGLE / %{UnclosedAngle}
eLPAREN <- %LPAREN / %{ExpectedParenthesis}
eLCURLY <- %LCURLY / %{ExpectedCurly}
eLANGLE <- %LANGLE / %{ExpectedAngle}
eLBRACKET <- %LBRACKET / %{ExpectedBracket}
eCOLON <- %COLON / %{ExpectedColon}
eCOMMA <- %COMMA / %{ExpectedComma}
eEND <- %END / %{ExpectedEnd}
eTHEN <- %THEN / %{ExpectedThen}
eUNTIL <- %UNTIL / %{ExpectedUntil}
eDO <- %DO / %{ExpectedDo}
ename <- name / %{ExpectedName}
eexpr <- expr / %{ExpectedExpression}
etypexpr <- typexpr / %{ExpectedTypeExpression}
etype_or_param_expr <- type_or_param_expr / %{ExpectedExpression}
etype_param_expr <- type_param_expr / %{ExpectedExpression}
ecallargs <- callargs / %{ExpectedCall}
eenumfield <- enumfield / %{ExpectedEnumFieldType}
eprimtype <- primtype / %{ExpectedPrimitiveTypeExpression}
]])

-- compile whole grammar
Expand Down
38 changes: 38 additions & 0 deletions nelua/types.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ local sstream = require 'nelua.utils.sstream'
local metamagic = require 'nelua.utils.metamagic'
local config = require 'nelua.configer'.get()
local bn = require 'nelua.utils.bn'
local except = require 'nelua.utils.except'
local typedefs, primtypes

local types = {}
Expand Down Expand Up @@ -188,6 +189,7 @@ function Type:is_unsigned() return self.unsigned end
function Type:is_signed() return self.arithmetic and not self.unsigned end
function Type:is_generic_pointer() return self.genericpointer end
function Type:is_concept() return self.concept end
function Type:is_generic() return self.generic end
function Type.is_pointer_of() return false end
function Type.is_array_of() return false end
function Type.has_pointer() return false end
Expand Down Expand Up @@ -1427,6 +1429,42 @@ function ConceptType:is_convertible_from_attr(attr, explicit)
return type, err
end

--------------------------------------------------------------------------------
local GenericType = typeclass()
types.GenericType = GenericType
GenericType.nodecl = true
GenericType.nolvalue = true
GenericType.comptime = true
GenericType.unpointable = true
GenericType.generic = true

function GenericType:_init(func)
Type._init(self, 'generic', 0)
self.func = func
end

function GenericType:eval_type(params)
local ret
local ok, err = except.trycall(function()
ret = self.func(tabler.unpack(params))
end)
if err then
return nil, err
end
if traits.is_symbol(ret) then
if not ret.type or not ret.type:is_type() then
ret = nil
err = stringer.pformat("expected a symbol holding a type in generic return, but got something else")
else
ret = ret.value
end
elseif not traits.is_type(ret) then
ret = nil
err = stringer.pformat("expected a type or symbol in generic return, but got '%s'", type(ret))
end
return ret, err
end

--------------------------------------------------------------------------------
function types.set_typedefs(t)
typedefs = t
Expand Down
9 changes: 9 additions & 0 deletions nelua/utils/bn.lua
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,15 @@ function bn.tointeger(v)
end
end

local orig_tonumber = bn.tonumber
function bn.tonumber(v)
local vint = v:trunc()
if v == vint then
return tonumber(tostring(vint))
end
return orig_tonumber(v)
end

--------------------------------------------------------------------------------
-- Utilities

Expand Down
15 changes: 15 additions & 0 deletions spec/02-syntaxdefs_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,21 @@ describe("type expression", function()
{ n.IdDecl{'r', n.RangeType{n.Type{'integer'}}}}
}}})
end)
it("generic type", function()
assert.parse_ast(nelua_parser, "local r: somegeneric(integer, 4)",
n.Block{{
n.VarDecl{'local', {
n.IdDecl{'r', n.GenericType{"somegeneric", {
n.Type{'integer'}, n.Number{"dec", "4"}}}}}
}}})
assert.parse_ast(nelua_parser, "local r: somegeneric(span(integer), integer*)",
n.Block{{
n.VarDecl{'local', {
n.IdDecl{'r', n.GenericType{"somegeneric", {
n.SpanType{n.Type{'integer'}},
n.PointerType{n.Type{"integer"}}
}}}}}}})
end)
it("complex types", function()
assert.parse_ast(nelua_parser, "local p: integer*[10]*[10]",
n.Block{{
Expand Down
Loading

0 comments on commit 5504266

Please sign in to comment.