Skip to content

Commit

Permalink
account for vector of tensptrs
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Dec 2, 2018
1 parent 31dd37b commit 67158fc
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 21 deletions.
8 changes: 8 additions & 0 deletions ade/ade.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@
#include "ade/functor.hpp"
#include "ade/ileaf.hpp"
#include "ade/traveler.hpp"

namespace ade
{

/// Vector representation of ade tensor pointers
using TensT = std::vector<TensptrT>;

}
20 changes: 18 additions & 2 deletions age/templates/capi_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import repr

_origtype = 'ade::TensptrT'
_origtypes = 'ade::TensT'
_repltype = 'int64_t'
_repltypes = 'std::vector<int64_t>'

def replace_all(arg):
return arg.replace(_origtype, _repltype)\
.replace(_origtypes, _repltypes)

def affix_apis(apis):
names = [api['name'] for api in apis]
Expand Down Expand Up @@ -41,8 +47,8 @@ def affix_apis(apis):
header.api_decls = ("apis", lambda apis: '\n\n'.join([\
"extern int64_t age_{func} ({args});".format(\
func = api["name"] + affix, args = ', '.join([\
arg.replace(_origtype, _repltype)\
for arg in api["args"]])) for api, affix in affix_apis(apis)]))
replace_all(arg) for arg in api["args"]]))\
for api, affix in affix_apis(apis)]))

# EXPORT
source = repr.FILE_REPR("""#ifdef _GENERATED_CAPI_HPP
Expand Down Expand Up @@ -97,6 +103,12 @@ def affix_apis(apis):
return id;
}}"""

_carr_decl = """
ade::TensT {name}_tens({name}.size());
std::transform({name}.begin(), {name}.end(), {name}_tens.begin(),
[](int64_t id){{ return get_tens(id); }});
"""

def _defn_func(api, affix):
ifunc = api["name"] + affix
vars = [arg.split(' ') for arg in api["args"]]
Expand All @@ -110,6 +122,10 @@ def _defn_func(api, affix):
arg_decls.append('ade::TensptrT {name}_ptr = get_tens({name});'
.format(name=typevar[1]))
args.append(typevar[1] + '_ptr')
elif typevar[0] == _origtypes:
params.append('std::vector<int64_t> {}'.format(typevar[1]))
arg_decls.append(_carr_decl.format(name=typevar[1]))
args.append(typevar[1] + '_tens')
else:
params.append(' '.join(typevar))
args.append(typevar[1])
Expand Down
4 changes: 2 additions & 2 deletions age/templates/grader_tmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
return ade::Opcode{{"{prod}", {prod}}};
}}
ade::TensptrT grad_rule (size_t code, TensT args, size_t idx) override;
ade::TensptrT grad_rule (size_t code, ade::TensT args, size_t idx) override;
}};
}}
Expand All @@ -52,7 +52,7 @@
namespace age
{{
ade::TensptrT RuleSet::grad_rule (size_t code,TensT args,size_t idx)
ade::TensptrT RuleSet::grad_rule (size_t code, ade::TensT args, size_t idx)
{{
switch (code)
{{
Expand Down
4 changes: 2 additions & 2 deletions age/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@
return ade::Opcode{"MULTIPLICATION", MULTIPLICATION};
}
ade::TensptrT grad_rule (size_t code, TensT args, size_t idx) override;
ade::TensptrT grad_rule (size_t code, ade::TensT args, size_t idx) override;
};
}
Expand All @@ -421,7 +421,7 @@
namespace age
{
ade::TensptrT RuleSet::grad_rule (size_t code,TensT args,size_t idx)
ade::TensptrT RuleSet::grad_rule (size_t code, ade::TensT args, size_t idx)
{
switch (code)
{
Expand Down
4 changes: 2 additions & 2 deletions age/test/grader_dep.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#include <cassert>
#include "age/test/grader_dep.hpp"

ade::TensptrT arms_heavy (size_t idx, age::TensT args)
ade::TensptrT arms_heavy (size_t idx, ade::TensT args)
{
assert(args.size() > 0);
static_cast<MockTensor*>(args[0].get())->scalar_ = idx;
return args[0];
}

ade::TensptrT dj_grad (age::TensT args, size_t idx)
ade::TensptrT dj_grad (ade::TensT args, size_t idx)
{
assert(args.size() > 0);
static_cast<MockTensor*>(args[0].get())->scalar_ = idx + khaled_constant;
Expand Down
4 changes: 2 additions & 2 deletions age/test/grader_dep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct MockTensor : public ade::iLeaf
ade::Shape shape_;
};

ade::TensptrT arms_heavy (size_t idx, age::TensT args);
ade::TensptrT arms_heavy (size_t idx, ade::TensT args);

ade::TensptrT dj_grad (age::TensT args, size_t idx);
ade::TensptrT dj_grad (ade::TensT args, size_t idx);

#endif // MOCK_GRADER_DEP_HPP
5 changes: 5 additions & 0 deletions age/test/mock.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
"name": "vegeta",
"args": ["ade::TensptrT arg1", "uint8_t bardock"],
"out": "freeza(arg1, bardock)"
},
{
"name": "vegeta",
"args": ["uint8_t bardock", "ade::TensT arrs"],
"out": "freeza(arrs[0], bardock)"
}
]
}
10 changes: 10 additions & 0 deletions age/test/test_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ TEST(AGE, Api)
EXPECT_EQ(2, planet->scalar_);
EXPECT_EQ(31, vshape.n_elems());
EXPECT_EQ(31, vshape.at(0));

ade::TensptrT vegetable2(
age::vegeta(2, {ade::TensptrT(
new MockTensor(1, ade::Shape({1, 1, 31})))}));
MockTensor* planet2 = dynamic_cast<MockTensor*>(vegetable2.get());
EXPECT_NE(nullptr, planet2);
ade::Shape vshape2 = planet2->shape();
EXPECT_EQ(2, planet2->scalar_);
EXPECT_EQ(31, vshape2.n_elems());
EXPECT_EQ(31, vshape2.at(0));
}


Expand Down
11 changes: 10 additions & 1 deletion age/test/test_capi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,23 @@ TEST(AGE, CApi)
EXPECT_EQ(16, shape.at(0));

int64_t var = register_tens(new MockTensor(1, ade::Shape({1, 1, 31})));
int64_t vegetable = age_vegeta(var, 2);
int64_t vegetable = age_vegeta_1(var, 2);
MockTensor* planet = dynamic_cast<MockTensor*>(
get_tens(vegetable).get());
EXPECT_NE(nullptr, planet);
ade::Shape vshape = planet->shape();
EXPECT_EQ(2, planet->scalar_);
EXPECT_EQ(31, vshape.n_elems());
EXPECT_EQ(31, vshape.at(0));

int64_t vegetable2 = age_vegeta(2, {var});
MockTensor* planet2 = dynamic_cast<MockTensor*>(
get_tens(vegetable2).get());
EXPECT_NE(nullptr, planet2);
ade::Shape vshape2 = planet2->shape();
EXPECT_EQ(2, planet2->scalar_);
EXPECT_EQ(31, vshape2.n_elems());
EXPECT_EQ(31, vshape2.at(0));
}


Expand Down
7 changes: 2 additions & 5 deletions bwd/grader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
namespace age
{

/// Vector representation of tensor pointers
using TensT = std::vector<ade::TensptrT>;

/// Ruleset used by a Grader traveler to derive equations
struct iRuleSet
{
Expand All @@ -35,7 +32,7 @@ struct iRuleSet

/// Return chain rule of operation with respect to argument at idx
/// specified by code given args
virtual ade::TensptrT grad_rule (size_t code, TensT args, size_t idx) = 0;
virtual ade::TensptrT grad_rule (size_t code, ade::TensT args, size_t idx) = 0;
};

/// Traveler to obtain derivative of accepted node with respect to target
Expand Down Expand Up @@ -84,7 +81,7 @@ struct Grader final : public ade::iTraveler
};

/// Return ArgsT with each tensor in TensT attached to identity mapper
ade::ArgsT to_args (TensT tens);
ade::ArgsT to_args (ade::TensT tens);

}

Expand Down
8 changes: 4 additions & 4 deletions bwd/src/grader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ void Grader::visit (ade::iFunctor* func)
return stat.graphsize_[a] > stat.graphsize_[b];
});

std::unordered_map<const ade::iTensor*,TensT> grads = {
std::unordered_map<const ade::iTensor*,ade::TensT> grads = {
{func, {rules_->data(1, func->shape())}},
};
for (ade::iFunctor* parent : parents)
{
ade::Opcode opcode = parent->get_opcode();
TensT& gargs = grads[parent];
ade::TensT& gargs = grads[parent];
ade::TensptrT bwd(gargs.size() > 1 ? gargs[0] :
ade::TensptrT(ade::Functor::get(rules_->sum_opcode(), to_args(gargs))));

Expand All @@ -60,7 +60,7 @@ void Grader::visit (ade::iFunctor* func)
ordered.sort();
for (size_t i : ordered)
{
TensT args;
ade::TensT args;
ade::MappedTensor& child = children[i];
ade::CoordPtrT mapper(child.mapper_->reverse());
for (size_t j = 0; j < nchildren; ++j)
Expand Down Expand Up @@ -95,7 +95,7 @@ void Grader::visit (ade::iFunctor* func)
to_args(grads[target_]))));
}

ade::ArgsT to_args (TensT tens)
ade::ArgsT to_args (ade::TensT tens)
{
ade::ArgsT args;
std::transform(tens.begin(), tens.end(), std::back_inserter(args),
Expand Down
2 changes: 1 addition & 1 deletion bwd/test/test_grader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct MockRuleSet final : public age::iRuleSet
return ade::Opcode{"*", 1};
}

ade::TensptrT grad_rule (size_t code, age::TensT args, size_t idx) override
ade::TensptrT grad_rule (size_t code, ade::TensT args, size_t idx) override
{
// grad of sum is prod and grad of prod is sum
if (code)
Expand Down

0 comments on commit 67158fc

Please sign in to comment.