Skip to content

Commit

Permalink
let age config define C type conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Dec 3, 2018
1 parent b258f5f commit 476c240
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 30 deletions.
12 changes: 11 additions & 1 deletion age/README_AGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,17 @@ The configuration script must be a json file in the following format:
"apis": [
{
"name": "< function name >",
"args": ["< arg type > < arg name >", ...],
"args": [{
"dtype": "< arg type >",
"name": "< arg name >",
"c": { // this is optional
"args": [{
"dtype": "< c arg type >",
"name": "< c arg name >"
}, ...],
"convert": "< combine args to c++ arg >"
}
}, ...],
"out": "< output signature >"
},
...
Expand Down
12 changes: 6 additions & 6 deletions age/templates/api_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
""")

header.api_decls = ("apis", lambda apis: '\n\n'.join(["ade::TensptrT {api} ({args});".format(\
api = api["name"], args = ', '.join(api["args"])) for api in apis]))
api = api["name"], args = ', '.join([arg['dtype'] + ' ' + arg['name']\
for arg in api["args"]])) for api in apis]))

# EXPORT
source = repr.FILE_REPR("""#ifdef _GENERATED_API_HPP
Expand All @@ -33,12 +34,10 @@
""")

def _nullcheck(args):
vars = [arg.split(' ') for arg in args]
tens = list(filter(lambda vpair: vpair[0] == 'ade::TensptrT',\
[(var[0], var[-1]) for var in vars]))
tens = list(filter(lambda arg: arg['dtype'] == 'ade::TensptrT', args))
if len(tens) == 0:
return "false"
varnames = [ten[-1] for ten in tens]
varnames = [ten['name'] for ten in tens]
return " || ".join([varname + " == nullptr" for varname in varnames])

source.apis = ("apis", lambda apis: '\n\n'.join(["""ade::TensptrT {api} ({args})
Expand All @@ -50,6 +49,7 @@ def _nullcheck(args):
return {retval};
}}""".format(
api = api["name"],
args = ', '.join(api["args"]),
args = ', '.join([arg['dtype'] + ' ' + arg['name']\
for arg in api["args"]]),
null_check = _nullcheck(api["args"]),
retval = api["out"]) for api in apis]))
28 changes: 15 additions & 13 deletions age/templates/capi_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ def affix_apis(apis):
affixes.append((api, affix))
return affixes

def typesplit(args):
args = [arg.split(' ') for arg in args]
return [(arg[0], arg[-1]) for arg in args]

# EXPORT
header = repr.FILE_REPR("""#ifndef _GENERATED_CAPI_HPP
#define _GENERATED_CAPI_HPP
Expand All @@ -47,10 +43,14 @@ def typesplit(args):
_cfunc_sign_fmt = "int64_t age_{ifunc} ({params})"

def _decl_func(api, affix):
args = typesplit(api["args"])
params = []
for dtype, argname in args:
if dtype == _origtype:
for arg in api["args"]:
dtype = arg["dtype"]
argname = arg["name"]
if 'c' in arg:
for cv in arg['c']['args']:
params.append(cv['dtype'] + ' ' + cv['name'])
elif dtype == _origtype:
params.append(_repltype + ' ' + argname)
elif dtype == _origarrtype:
params.append('int64_t* ' + argname)
Expand Down Expand Up @@ -108,20 +108,22 @@ def _decl_func(api, affix):
#endif
""")

_cfunc_bloc_fmt = """
{{{arg_decls}
_cfunc_bloc_fmt = """{{{arg_decls}
auto ptr = age::{func}({params});
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
}}"""

def _defn_func(api, affix):
args = typesplit(api["args"])
decls = []
params = []
for dtype, argname in args:
if dtype == _origtype:
for arg in api["args"]:
dtype = arg["dtype"]
argname = arg["name"]
if 'c' in arg:
params.append(arg['c']['convert'])
elif dtype == _origtype:
decls.append('ade::TensptrT {name}_ptr = get_tens({name});'
.format(name=argname))
params.append(argname + '_ptr')
Expand All @@ -136,7 +138,7 @@ def _defn_func(api, affix):
arg_decls = '\n '.join(decls)
if len(arg_decls) > 0:
arg_decls = '\n ' + arg_decls
return _decl_func(api, affix) + _cfunc_bloc_fmt.format(
return _decl_func(api, affix) + '\n' + _cfunc_bloc_fmt.format(
arg_decls = arg_decls,
func = api["name"],
params = ', '.join(params))
Expand Down
44 changes: 37 additions & 7 deletions age/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,40 @@

api_fields = {"apis": [
{"name": "func1", "args": [], "out": "bar1()"},
{"name": "func2", "args": ["ade::TensptrT arg", "Arg arg1"], "out": "bar2()"},
{"name": "func3", "args": [
"ade::TensptrT arg", "Arg arg1", "ade::TensptrT arg2"], "out": "bar3()"},
{"name": "func1", "args": ["ade::TensT arg", "Arg arg1"], "out": "bar4()"}
{"name": "func2", "args": [{
"dtype": "ade::TensptrT",
"name": "arg"
}, {
"dtype": "Arg",
"name": "arg1",
"c": {
"args": [{
"dtype": "int",
"name": "n_arg1"
}, {
"dtype": "float",
"name": "arg1_f"
}],
"convert": "Arg(arg1_f, n_arg1)"
}
}], "out": "bar2()"},
{"name": "func3", "args": [{
"dtype": "ade::TensptrT",
"name": "arg"
}, {
"dtype": "Arg",
"name": "arg1"
}, {
"dtype": "ade::TensptrT",
"name": "arg2"
}], "out": "bar3()"},
{"name": "func1", "args": [{
"dtype": "ade::TensT",
"name": "arg"
}, {
"dtype": "Arg",
"name": "arg1"
}], "out": "bar4()"}
]}

codes_fields = {
Expand Down Expand Up @@ -132,7 +162,7 @@
extern int64_t age_func1_1 ();
extern int64_t age_func2 (int64_t arg, Arg arg1);
extern int64_t age_func2 (int64_t arg, int n_arg1, float arg1_f);
extern int64_t age_func3 (int64_t arg, Arg arg1, int64_t arg2);
Expand Down Expand Up @@ -188,10 +218,10 @@
return id;
}
int64_t age_func2 (int64_t arg, Arg arg1)
int64_t age_func2 (int64_t arg, int n_arg1, float arg1_f)
{
ade::TensptrT arg_ptr = get_tens(arg);
auto ptr = age::func2(arg_ptr, arg1);
auto ptr = age::func2(arg_ptr, Arg(arg1_f, n_arg1));
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
Expand Down
21 changes: 18 additions & 3 deletions age/test/mock.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,32 @@
"apis": [
{
"name": "goku",
"args": ["size_t bardock"],
"args": [{
"dtype": "size_t",
"name": "bardock"
}],
"out": "cooler(bardock)"
},
{
"name": "vegeta",
"args": ["ade::TensptrT arg1", "uint8_t bardock"],
"args": [{
"dtype": "ade::TensptrT",
"name": "arg1"
}, {
"dtype": "uint8_t",
"name": "bardock"
}],
"out": "freeza(arg1, bardock)"
},
{
"name": "vegeta",
"args": ["uint8_t bardock", "ade::TensT arrs"],
"args": [{
"dtype": "uint8_t",
"name": "bardock"
}, {
"dtype": "ade::TensT",
"name": "arrs"
}],
"out": "freeza(arrs[0], bardock)"
}
]
Expand Down

0 comments on commit 476c240

Please sign in to comment.