diff --git a/lib/myarraytable.nelua b/lib/myarraytable.nelua index b197dd01..85525d95 100644 --- a/lib/myarraytable.nelua +++ b/lib/myarraytable.nelua @@ -3,7 +3,8 @@ ## strict = true ## unitname = 'nelua' -## myarraytable = hygienize(memoize(function(T, allocator) +## local myarraytable = hygienize(memoize(function(T, allocator) + ## staticassert(traits.is_type(T), "invalid type '%s'", T) ## if allocator then local allocator: type = #[allocator]# ## else @@ -139,3 +140,5 @@ ## return ArrayTableT ## end)) + +global myarraytable = #[generic(myarraytable)]# diff --git a/nelua/analyzer.lua b/nelua/analyzer.lua index b6e96eee..7c24f605 100644 --- a/nelua/analyzer.lua +++ b/nelua/analyzer.lua @@ -408,7 +408,6 @@ function visitors.Type(context, node) attr.value = value end - function visitors.TypeInstance(context, node, symbol) local typenode = node[1] if node.attr.type then return end @@ -594,6 +593,42 @@ function visitors.PointerType(context, node) attr.type = primtypes.type end +function visitors.GenericType(context, node) + local attr = node.attr + local name, argnodes = node[1], node[2] + if attr.type then return end + local symbol = context.scope:get_symbol(name) + if not symbol or not symbol.type or not symbol.type:is_type() or symbol.type:is_generic() then + node:raisef("symbol '%s' doesn't hold a generic type", name) + end + local params = {} + for i=1,#argnodes do + local argnode = argnodes[i] + context:traverse_node(argnode) + local argattr = argnode.attr + if not (argattr.comptime or argattr.type:is_comptime()) then + node:raisef("in generic '%s': argument #%d isn't a compile time value", name, i) + end + local value = argattr.value + if traits.is_bignumber(value) then + value = value:tonumber() + elseif not (traits.is_type(value) or + traits.is_string(value) or + traits.is_boolean(value) or + traits.is_bignumber(value)) then + node:raisef("in generic '%s': argument #%d of type '%s' is invalid for generics", + name, i, argattr.type:prettyname()) + end + params[i] = value + end + local type, err = symbol.value:eval_type(params) + if err then + node:raisef(err) + end + attr.type = primtypes.type + attr.value = symbol.value:eval_type(params) +end + local function iargnodes(argnodes) local i = 0 local lastargindex = #argnodes @@ -694,8 +729,11 @@ end local function visitor_Call_typeassertion(context, node, argnodes, type) local attr = node.attr assert(type) + if type:is_generic() then + node:raisef("assertion to generic '%s': cannot do assertion on generics", type:prettyname()) + end if #argnodes ~= 1 then - node:raisef("assertion to type '%s' expected one argument, but got %d", + node:raisef("assertion to type '%s': expected one argument, but got %d", type:prettyname(), #argnodes) end local argnode = argnodes[1] @@ -1405,6 +1443,8 @@ function visitors.VarDecl(context, node) assert(valnode and valnode.attr.value) assignvaltype = vartype ~= valtype symbol.value = valnode.attr.value + symbol.value:suggest_nick(symbol.name, symbol.staticstorage and symbol.codename) + symbol.value.symbol = symbol end if vartype and vartype:is_auto() then diff --git a/nelua/astdefs.lua b/nelua/astdefs.lua index 050576f5..a1ab4f2a 100644 --- a/nelua/astdefs.lua +++ b/nelua/astdefs.lua @@ -123,6 +123,10 @@ astbuilder:register('SpanType', { astbuilder:register('RangeType', { ntypes.Node, -- subtype typexpr }) +astbuilder:register('GenericType', { + stypes.string + ntypes.PreprocessName, -- generic name + stypes.array_of(ntypes.Node), -- list of typexpr or param expr +}) -- function astbuilder:register('Function', { diff --git a/nelua/preprocessor.lua b/nelua/preprocessor.lua index 18d9372e..d6bb2b42 100644 --- a/nelua/preprocessor.lua +++ b/nelua/preprocessor.lua @@ -140,6 +140,7 @@ function preprocessor.preprocess(context, ast) aster = aster, config = config, types = types, + traits = traits, primtypes = primtypes } tabler.update(ppenv, { @@ -200,6 +201,11 @@ function preprocessor.preprocess(context, ast) type.node = context:get_current_node() return type end, + generic = function(f) + local type = types.GenericType(f) + type.node = context:get_current_node() + return type + end, memoize = memoize }) setmetatable(ppenv, { __index = function(_, key) diff --git a/nelua/syntaxdefs.lua b/nelua/syntaxdefs.lua index fa3716c5..d880c604 100644 --- a/nelua/syntaxdefs.lua +++ b/nelua/syntaxdefs.lua @@ -443,6 +443,15 @@ local function get_parser(std) expr_list <- (expr (%COMMA expr)*)? eexpr_list <- eexpr (%COMMA expr)* + type_or_param_expr <- typexpr / expr + etype_or_param_expr_list <- etype_or_param_expr (%COMMA type_or_param_expr)* + + type_param_expr <- + %cNUMBER / + id / + ppexpr / + (%LPAREN eexpr eRPAREN) + annot_list <- %LANGLE {| (eannot_expr (%COMMA eannot_expr)*) |} eRANGLE eannot_expr <- @@ -468,12 +477,13 @@ local function get_parser(std) span_type / range_type / pointer_type / + generic_type / primtype / ppexpr unary_typexpr_op <- {| {} %MUL -> 'PointerType' |} / - {| {} %LBRACKET -> 'ArrayType' cnil typexpr_param_expr eRBRACKET |} + {| {} %LBRACKET -> 'ArrayType' cnil etype_param_expr eRBRACKET |} func_type <- ( {} '' -> 'FuncType' @@ -484,21 +494,14 @@ local function get_parser(std) {| (%COLON (%LPAREN etypexpr_list eRPAREN / etypexpr))? |} ) -> to_astnode - typexpr_param_expr <- - %cNUMBER / - id / - ppexpr / - (%LPAREN eexpr eRPAREN) / - %{ExpectedExpression} - - record_type <- ({} %TRECORD -> 'RecordType' eLCURLY + record_type <- ({} %TRECORD -> 'RecordType' %LCURLY {| (record_field (%SEPARATOR record_field)* %SEPARATOR?)? |} eRCURLY) -> to_astnode record_field <- ({} '' -> 'RecordFieldType' name eCOLON etypexpr ) -> to_astnode enum_type <- ({} %TENUM -> 'EnumType' - ((%LPAREN eprimtype eRPAREN) / cnil) eLCURLY + ((%LPAREN eprimtype eRPAREN) / cnil) %LCURLY {| eenumfield (%SEPARATOR enumfield)* %SEPARATOR? |} eRCURLY) -> to_astnode enumfield <- ({} '' -> 'EnumFieldType' @@ -506,24 +509,29 @@ local function get_parser(std) ) -> to_astnode arraytable_type <- ( {} 'arraytable' -> 'ArrayTableType' - eLPAREN etypexpr eRPAREN + %LPAREN etypexpr eRPAREN ) -> to_astnode span_type <- ( {} 'span' -> 'SpanType' - eLPAREN etypexpr eRPAREN + %LPAREN etypexpr eRPAREN ) -> to_astnode range_type <- ( {} 'range' -> 'RangeType' - eLPAREN etypexpr eRPAREN + %LPAREN etypexpr eRPAREN ) -> to_astnode array_type <- ( {} 'array' -> 'ArrayType' - eLPAREN etypexpr eCOMMA typexpr_param_expr eRPAREN + %LPAREN etypexpr eCOMMA etype_param_expr eRPAREN ) -> to_astnode pointer_type <- ( {} 'pointer' -> 'PointerType' ((%LPAREN etypexpr eRPAREN) / %SKIP) ) -> to_astnode + generic_type <- ( + {} '' -> 'GenericType' + name %LPAREN {| etype_or_param_expr_list |} eRPAREN + ) -> to_astnode + primtype <- ({} '' -> 'Type' name) -> to_astnode ppexpr <- ({} %LPPEXPR -> 'PreprocessExpr' {expr -> 0} eRPPEXPR) -> to_astnode @@ -574,28 +582,30 @@ local function get_parser(std) -- syntax expected captures with errors grammar:set_pegs([[ - eRPAREN <- %RPAREN / %{UnclosedParenthesis} - eRBRACKET <- %RBRACKET / %{UnclosedBracket} - eRPPEXPR <- %RPPEXPR / %{UnclosedBracket} - eRPPNAME <- %RPPNAME / %{UnclosedParenthesis} - eRCURLY <- %RCURLY / %{UnclosedCurly} - eRANGLE <- %RANGLE / %{UnclosedAngle} - eLPAREN <- %LPAREN / %{ExpectedParenthesis} - eLCURLY <- %LCURLY / %{ExpectedCurly} - eLANGLE <- %LANGLE / %{ExpectedAngle} - eLBRACKET <- %LBRACKET / %{ExpectedBracket} - eCOLON <- %COLON / %{ExpectedColon} - eCOMMA <- %COMMA / %{ExpectedComma} - eEND <- %END / %{ExpectedEnd} - eTHEN <- %THEN / %{ExpectedThen} - eUNTIL <- %UNTIL / %{ExpectedUntil} - eDO <- %DO / %{ExpectedDo} - ename <- name / %{ExpectedName} - eexpr <- expr / %{ExpectedExpression} - etypexpr <- typexpr / %{ExpectedTypeExpression} - ecallargs <- callargs / %{ExpectedCall} - eenumfield <- enumfield / %{ExpectedEnumFieldType} - eprimtype <- primtype / %{ExpectedPrimitiveTypeExpression} + eRPAREN <- %RPAREN / %{UnclosedParenthesis} + eRBRACKET <- %RBRACKET / %{UnclosedBracket} + eRPPEXPR <- %RPPEXPR / %{UnclosedBracket} + eRPPNAME <- %RPPNAME / %{UnclosedParenthesis} + eRCURLY <- %RCURLY / %{UnclosedCurly} + eRANGLE <- %RANGLE / %{UnclosedAngle} + eLPAREN <- %LPAREN / %{ExpectedParenthesis} + eLCURLY <- %LCURLY / %{ExpectedCurly} + eLANGLE <- %LANGLE / %{ExpectedAngle} + eLBRACKET <- %LBRACKET / %{ExpectedBracket} + eCOLON <- %COLON / %{ExpectedColon} + eCOMMA <- %COMMA / %{ExpectedComma} + eEND <- %END / %{ExpectedEnd} + eTHEN <- %THEN / %{ExpectedThen} + eUNTIL <- %UNTIL / %{ExpectedUntil} + eDO <- %DO / %{ExpectedDo} + ename <- name / %{ExpectedName} + eexpr <- expr / %{ExpectedExpression} + etypexpr <- typexpr / %{ExpectedTypeExpression} + etype_or_param_expr <- type_or_param_expr / %{ExpectedExpression} + etype_param_expr <- type_param_expr / %{ExpectedExpression} + ecallargs <- callargs / %{ExpectedCall} + eenumfield <- enumfield / %{ExpectedEnumFieldType} + eprimtype <- primtype / %{ExpectedPrimitiveTypeExpression} ]]) -- compile whole grammar diff --git a/nelua/types.lua b/nelua/types.lua index 1c5fb022..403e2e9a 100644 --- a/nelua/types.lua +++ b/nelua/types.lua @@ -7,6 +7,7 @@ local sstream = require 'nelua.utils.sstream' local metamagic = require 'nelua.utils.metamagic' local config = require 'nelua.configer'.get() local bn = require 'nelua.utils.bn' +local except = require 'nelua.utils.except' local typedefs, primtypes local types = {} @@ -188,6 +189,7 @@ function Type:is_unsigned() return self.unsigned end function Type:is_signed() return self.arithmetic and not self.unsigned end function Type:is_generic_pointer() return self.genericpointer end function Type:is_concept() return self.concept end +function Type:is_generic() return self.generic end function Type.is_pointer_of() return false end function Type.is_array_of() return false end function Type.has_pointer() return false end @@ -1427,6 +1429,42 @@ function ConceptType:is_convertible_from_attr(attr, explicit) return type, err end +-------------------------------------------------------------------------------- +local GenericType = typeclass() +types.GenericType = GenericType +GenericType.nodecl = true +GenericType.nolvalue = true +GenericType.comptime = true +GenericType.unpointable = true +GenericType.generic = true + +function GenericType:_init(func) + Type._init(self, 'generic', 0) + self.func = func +end + +function GenericType:eval_type(params) + local ret + local ok, err = except.trycall(function() + ret = self.func(tabler.unpack(params)) + end) + if err then + return nil, err + end + if traits.is_symbol(ret) then + if not ret.type or not ret.type:is_type() then + ret = nil + err = stringer.pformat("expected a symbol holding a type in generic return, but got something else") + else + ret = ret.value + end + elseif not traits.is_type(ret) then + ret = nil + err = stringer.pformat("expected a type or symbol in generic return, but got '%s'", type(ret)) + end + return ret, err +end + -------------------------------------------------------------------------------- function types.set_typedefs(t) typedefs = t diff --git a/nelua/utils/bn.lua b/nelua/utils/bn.lua index d8769130..d37907ec 100644 --- a/nelua/utils/bn.lua +++ b/nelua/utils/bn.lua @@ -192,6 +192,15 @@ function bn.tointeger(v) end end +local orig_tonumber = bn.tonumber +function bn.tonumber(v) + local vint = v:trunc() + if v == vint then + return tonumber(tostring(vint)) + end + return orig_tonumber(v) +end + -------------------------------------------------------------------------------- -- Utilities diff --git a/spec/02-syntaxdefs_spec.lua b/spec/02-syntaxdefs_spec.lua index 2e6ee505..782408e1 100644 --- a/spec/02-syntaxdefs_spec.lua +++ b/spec/02-syntaxdefs_spec.lua @@ -1417,6 +1417,21 @@ describe("type expression", function() { n.IdDecl{'r', n.RangeType{n.Type{'integer'}}}} }}}) end) + it("generic type", function() + assert.parse_ast(nelua_parser, "local r: somegeneric(integer, 4)", + n.Block{{ + n.VarDecl{'local', { + n.IdDecl{'r', n.GenericType{"somegeneric", { + n.Type{'integer'}, n.Number{"dec", "4"}}}}} + }}}) + assert.parse_ast(nelua_parser, "local r: somegeneric(span(integer), integer*)", + n.Block{{ + n.VarDecl{'local', { + n.IdDecl{'r', n.GenericType{"somegeneric", { + n.SpanType{n.Type{'integer'}}, + n.PointerType{n.Type{"integer"}} + }}}}}}}) + end) it("complex types", function() assert.parse_ast(nelua_parser, "local p: integer*[10]*[10]", n.Block{{ diff --git a/spec/03-typechecker_spec.lua b/spec/03-typechecker_spec.lua index 2340ad65..35708fa8 100644 --- a/spec/03-typechecker_spec.lua +++ b/spec/03-typechecker_spec.lua @@ -1379,4 +1379,52 @@ it("concepts", function() ]], "could not match concept") end) +it("generics", function() + assert.analyze_ast([[ + local myarray = #[generic(function(T, N) return types.ArrayType(nil, T, N) end)]# + local M: integer = 4 + local x = @myarray(integer, (M)) + ]]) + assert.analyze_ast([[ + local int = @integer + local proxy = #[generic(function(T) return int end)]# + local x = @proxy(integer) + ]]) + assert.analyze_error([[ + local proxy = #[generic(function(T) staticerror('my fail') end)]# + local x = @proxy(integer) + ]], 'my fail') + assert.analyze_error([[ + local myarray = #[generic(function(T, N) return types.ArrayType(nil, T, N) end)]# + local M: integer = 4 + local x = @myarray(integer, (M)) + ]], "isn't a compile time value") + assert.analyze_error([[ + local myarray = #[generic(function(T, N) return types.ArrayType(nil, T, N) end)]# + local M: span(integer) = {} + local x = @myarray(integer, (M)) + ]], "is invalid for generics") + assert.analyze_error([[ + local x = @integer(integer) + ]], "doesn't hold a generic type") + assert.analyze_error([[ + local myarray = #[generic(function() end)]# + local i = 1 + local x = @integer(i) + ]], "doesn't hold a generic type") + assert.analyze_error([[ + local myarray = #[generic(function() end)]# + local x = myarray(integer) + ]], "cannot do assertion on generics") + assert.analyze_error([[ + local myarray = #[generic(function() end)]# + local x = @myarray(integer) + ]], "expected a type or symbol in generic return") + assert.analyze_error([[ + local X = 1 + local myarray = #[generic(function() return X end)]# + local x = @myarray(integer) + ]], "expected a symbol holding a type in generic return") +end) + end) diff --git a/spec/05-cgenerator_spec.lua b/spec/05-cgenerator_spec.lua index f49cf38d..09485eed 100644 --- a/spec/05-cgenerator_spec.lua +++ b/spec/05-cgenerator_spec.lua @@ -1959,4 +1959,25 @@ it("concepts", function() ]=]) end) +it("generics", function() + assert.run_c([=[ + local arrayproxy = #[generic(hygienize(memoize(function(T, size) + return types.ArrayType(nil, T, size) + end)))]# + + local intarray = @arrayproxy(integer, 4) + local j: arrayproxy(integer, 4) = {1,2,3,4} + assert(j[0] == 1) + local i = (@arrayproxy(integer, 4)){1,2,3,4} + assert(i[0] == 1) + + local function f(x: arrayproxy(integer, 4)) + assert(x[0] == 1) + end + + f(i) + f(j) + ]=]) end) + +end) \ No newline at end of file