Skip to content

Commit

Permalink
expose capi tens map
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Dec 1, 2018
1 parent 6f5720b commit 5bf8773
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 31 deletions.
32 changes: 18 additions & 14 deletions age/templates/capi_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
header = repr.FILE_REPR("""#ifndef _GENERATED_CAPI_HPP
#define _GENERATED_CAPI_HPP
int64_t malloc_tens (void* ptr);
int64_t register_tens (ade::iTensor* ptr);
void* get_ptr (int64_t id);
int64_t register_tens (ade::TensptrT& ptr);
ade::TensptrT get_tens (int64_t id);
extern void free_tens (int64_t id);
Expand All @@ -33,26 +35,28 @@
static std::unordered_map<int64_t,ade::TensptrT> tens;
inline ade::TensptrT get_tens (int64_t id)
int64_t register_tens (ade::iTensor* ptr)
{{
auto it = tens.find(id);
if (tens.end() == it)
{{
return ade::TensptrT(nullptr);
}}
return it->second;
int64_t id = (int64_t) ptr;
tens.emplace(id, ade::TensptrT(ptr));
return id;
}}
int64_t malloc_tens (void* ptr)
int64_t register_tens (ade::TensptrT& ptr)
{{
int64_t id = (int64_t) ptr;
tens.emplace(id, ade::TensptrT(static_cast<ade::iTensor*>(ptr)));
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
}}
void* get_ptr (int64_t id)
ade::TensptrT get_tens (int64_t id)
{{
return get_tens(id).get();
auto it = tens.find(id);
if (tens.end() == it)
{{
return ade::TensptrT(nullptr);
}}
return it->second;
}}
void free_tens (int64_t id)
Expand Down
32 changes: 18 additions & 14 deletions age/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@
capi_header = """#ifndef _GENERATED_CAPI_HPP
#define _GENERATED_CAPI_HPP
int64_t malloc_tens (void* ptr);
int64_t register_tens (ade::iTensor* ptr);
void* get_ptr (int64_t id);
int64_t register_tens (ade::TensptrT& ptr);
ade::TensptrT get_tens (int64_t id);
extern void free_tens (int64_t id);
Expand All @@ -128,26 +130,28 @@
static std::unordered_map<int64_t,ade::TensptrT> tens;
inline ade::TensptrT get_tens (int64_t id)
int64_t register_tens (ade::iTensor* ptr)
{
auto it = tens.find(id);
if (tens.end() == it)
{
return ade::TensptrT(nullptr);
}
return it->second;
int64_t id = (int64_t) ptr;
tens.emplace(id, ade::TensptrT(ptr));
return id;
}
int64_t malloc_tens (void* ptr)
int64_t register_tens (ade::TensptrT& ptr)
{
int64_t id = (int64_t) ptr;
tens.emplace(id, ade::TensptrT(static_cast<ade::iTensor*>(ptr)));
int64_t id = (int64_t) ptr.get();
tens.emplace(id, ptr);
return id;
}
void* get_ptr (int64_t id)
ade::TensptrT get_tens (int64_t id)
{
return get_tens(id).get();
auto it = tens.find(id);
if (tens.end() == it)
{
return ade::TensptrT(nullptr);
}
return it->second;
}
void free_tens (int64_t id)
Expand Down
6 changes: 3 additions & 3 deletions age/test/test_capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ TEST(AGE, CApi)
// except inputs and output types are different
int64_t carrot = age_goku(16);
MockTensor* kakarot = dynamic_cast<MockTensor*>(
static_cast<ade::iTensor*>(get_ptr(carrot)));
get_tens(carrot).get());
EXPECT_NE(nullptr, kakarot);
ade::Shape shape = kakarot->shape();
EXPECT_EQ(16, kakarot->scalar_);
EXPECT_EQ(16, shape.n_elems());
EXPECT_EQ(16, shape.at(0));

int64_t var = malloc_tens(new MockTensor(1, ade::Shape({1, 1, 31})));
int64_t var = register_tens(new MockTensor(1, ade::Shape({1, 1, 31})));
int64_t vegetable = age_vegeta(var, 2);
MockTensor* planet = dynamic_cast<MockTensor*>(
static_cast<ade::iTensor*>(get_ptr(vegetable)));
get_tens(vegetable).get());
EXPECT_NE(nullptr, planet);
ade::Shape vshape = planet->shape();
EXPECT_EQ(2, planet->scalar_);
Expand Down

0 comments on commit 5bf8773

Please sign in to comment.