Skip to content

Commit

Permalink
remove vector from TensT c api
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Dec 2, 2018
1 parent 67158fc commit b258f5f
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 46 deletions.
1 change: 1 addition & 0 deletions age/agen.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def make_dir(fields, includes, includepath, gen_capi):
capi_source = 'c' + api_source
capi_hdr_path = os.path.join(includepath, capi_header)
capi_source_include = [
'<algorithm>',
'<unordered_map>',
'"' + api_hdr_path + '"',
'"' + capi_hdr_path + '"',
Expand Down
86 changes: 44 additions & 42 deletions age/templates/capi_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
import repr

_origtype = 'ade::TensptrT'
_origtypes = 'ade::TensT'
_repltype = 'int64_t'
_repltypes = 'std::vector<int64_t>'

def replace_all(arg):
return arg.replace(_origtype, _repltype)\
.replace(_origtypes, _repltypes)
_origarrtype = 'ade::TensT'

def affix_apis(apis):
names = [api['name'] for api in apis]
Expand All @@ -25,6 +21,10 @@ def affix_apis(apis):
affixes.append((api, affix))
return affixes

def typesplit(args):
args = [arg.split(' ') for arg in args]
return [(arg[0], arg[-1]) for arg in args]

# EXPORT
header = repr.FILE_REPR("""#ifndef _GENERATED_CAPI_HPP
#define _GENERATED_CAPI_HPP
Expand All @@ -44,11 +44,24 @@ def affix_apis(apis):
#endif // _GENERATED_CAPI_HPP
""")

_cfunc_sign_fmt = "int64_t age_{ifunc} ({params})"

def _decl_func(api, affix):
args = typesplit(api["args"])
params = []
for dtype, argname in args:
if dtype == _origtype:
params.append(_repltype + ' ' + argname)
elif dtype == _origarrtype:
params.append('int64_t* ' + argname)
params.append('uint64_t n_' + argname)
else:
params.append(dtype + ' ' + argname)
return _cfunc_sign_fmt.format(ifunc = api["name"] + affix,
params = ', '.join(params))

header.api_decls = ("apis", lambda apis: '\n\n'.join([\
"extern int64_t age_{func} ({args});".format(\
func = api["name"] + affix, args = ', '.join([\
replace_all(arg) for arg in api["args"]]))\
for api, affix in affix_apis(apis)]))
'extern ' + _decl_func(api, affix) + ';' for api, affix in affix_apis(apis)]))

# EXPORT
source = repr.FILE_REPR("""#ifdef _GENERATED_CAPI_HPP
Expand Down Expand Up @@ -95,49 +108,38 @@ def affix_apis(apis):
#endif
""")

_cfunc_fmt = """int64_t age_{ifunc} ({params})
{{
{arg_decls}auto ptr = age::{func}({retargs});
_cfunc_bloc_fmt = """
{{{arg_decls}
auto ptr = age::{func}({params});
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
}}"""

_carr_decl = """
ade::TensT {name}_tens({name}.size());
std::transform({name}.begin(), {name}.end(), {name}_tens.begin(),
[](int64_t id){{ return get_tens(id); }});
"""

def _defn_func(api, affix):
ifunc = api["name"] + affix
vars = [arg.split(' ') for arg in api["args"]]
typevars = [(var[0], var[-1]) for var in vars]
args = typesplit(api["args"])
decls = []
params = []
arg_decls = []
args = []
for typevar in typevars:
if typevar[0] == _origtype:
params.append('int64_t {}'.format(typevar[1]))
arg_decls.append('ade::TensptrT {name}_ptr = get_tens({name});'
.format(name=typevar[1]))
args.append(typevar[1] + '_ptr')
elif typevar[0] == _origtypes:
params.append('std::vector<int64_t> {}'.format(typevar[1]))
arg_decls.append(_carr_decl.format(name=typevar[1]))
args.append(typevar[1] + '_tens')
for dtype, argname in args:
if dtype == _origtype:
decls.append('ade::TensptrT {name}_ptr = get_tens({name});'
.format(name=argname))
params.append(argname + '_ptr')
elif dtype == _origarrtype:
decls.append('ade::TensT {name}_tens(n_{name});'.format(name=argname))
decls.append('std::transform({name}, {name} + n_{name}, {name}_tens.begin(),'.\
format(name=argname))
decls.append(' [](int64_t id){ return get_tens(id); });')
params.append(argname + '_tens')
else:
params.append(' '.join(typevar))
args.append(typevar[1])
arg_decls_str = '\n '.join(arg_decls)
params.append(argname)
arg_decls = '\n '.join(decls)
if len(arg_decls) > 0:
arg_decls_str = arg_decls_str + '\n '
return _cfunc_fmt.format(
ifunc = ifunc,
arg_decls = '\n ' + arg_decls
return _decl_func(api, affix) + _cfunc_bloc_fmt.format(
arg_decls = arg_decls,
func = api["name"],
params = ', '.join(params),
arg_decls = arg_decls_str,
retargs = ', '.join(args))
params = ', '.join(params))

source.apis = ("apis", lambda apis: '\n\n'.join([_defn_func(api, affix)\
for api, affix in affix_apis(apis)]))
32 changes: 29 additions & 3 deletions age/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
api_fields = {"apis": [
{"name": "func1", "args": [], "out": "bar1()"},
{"name": "func2", "args": ["ade::TensptrT arg", "Arg arg1"], "out": "bar2()"},
{"name": "func3", "args": ["ade::TensptrT arg", "Arg arg1", "ade::TensptrT arg2"], "out": "bar3()"}
{"name": "func3", "args": [
"ade::TensptrT arg", "Arg arg1", "ade::TensptrT arg2"], "out": "bar3()"},
{"name": "func1", "args": ["ade::TensT arg", "Arg arg1"], "out": "bar4()"}
]}

codes_fields = {
Expand Down Expand Up @@ -62,6 +64,8 @@
ade::TensptrT func3 (ade::TensptrT arg, Arg arg1, ade::TensptrT arg2);
ade::TensptrT func1 (ade::TensT arg, Arg arg1);
}
#endif // _GENERATED_API_HPP
Expand Down Expand Up @@ -99,6 +103,15 @@
return bar3();
}
ade::TensptrT func1 (ade::TensT arg, Arg arg1)
{
if (false)
{
logs::fatal("cannot func1 with a null argument");
}
return bar4();
}
}
#endif
Expand All @@ -117,12 +130,14 @@
extern void get_shape (int outshape[8], int64_t tens);
extern int64_t age_func1 ();
extern int64_t age_func1_1 ();
extern int64_t age_func2 (int64_t arg, Arg arg1);
extern int64_t age_func3 (int64_t arg, Arg arg1, int64_t arg2);
extern int64_t age_func1 (int64_t* arg, uint64_t n_arg, Arg arg1);
#endif // _GENERATED_CAPI_HPP
"""

Expand Down Expand Up @@ -165,7 +180,7 @@
std::copy(shape.begin(), shape.end(), outshape);
}
int64_t age_func1 ()
int64_t age_func1_1 ()
{
auto ptr = age::func1();
int64_t id = (int64_t) ptr.get();
Expand All @@ -192,6 +207,17 @@
return id;
}
int64_t age_func1 (int64_t* arg, uint64_t n_arg, Arg arg1)
{
ade::TensT arg_tens(n_arg);
std::transform(arg, arg + n_arg, arg_tens.begin(),
[](int64_t id){ return get_tens(id); });
auto ptr = age::func1(arg_tens, arg1);
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
}
#endif
"""

Expand Down
3 changes: 2 additions & 1 deletion age/test/test_capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ TEST(AGE, CApi)
EXPECT_EQ(31, vshape.n_elems());
EXPECT_EQ(31, vshape.at(0));

int64_t vegetable2 = age_vegeta(2, {var});
int64_t varr[1] = {var};
int64_t vegetable2 = age_vegeta(2, varr, 1);
MockTensor* planet2 = dynamic_cast<MockTensor*>(
get_tens(vegetable2).get());
EXPECT_NE(nullptr, planet2);
Expand Down

0 comments on commit b258f5f

Please sign in to comment.