move over bwd rehaul from cortenn
raggledodo committed Jan 4, 2019
1 parent 19fd12d commit 0eb9733
Showing 8 changed files with 173 additions and 114 deletions.
13 changes: 4 additions & 9 deletions age/templates/grader_tmpl.py
@@ -35,12 +35,8 @@ def sortkey(dic):
 		return ade::Opcode{{"{sum}", {sum}}};
 	}}
-	ade::Opcode prod_opcode (void) override
-	{{
-		return ade::Opcode{{"{prod}", {prod}}};
-	}}
-	ade::TensptrT grad_rule (ade::iFunctor* fwd, ade::TensT args, size_t idx) override;
+	ade::TensptrT chain_rule (ade::iFunctor* fwd,
+		ade::MappedTensor bwd, ade::TensT args, size_t idx) override;
 }};
 }}
@@ -52,16 +48,15 @@ def sortkey(dic):

 header.sum = ('data.sum', lambda sum: sum)

-header.prod = ('data.prod', lambda prod: prod)
-
 # EXPORT
 source = template.AGE_FILE(FILENAME, template.SOURCE_EXT,
 '''#ifdef _GENERATED_GRADER_HPP
 namespace age
 {{
-ade::TensptrT RuleSet::grad_rule (ade::iFunctor* fwd, ade::TensT args, size_t idx)
+ade::TensptrT RuleSet::chain_rule (ade::iFunctor* fwd,
+	ade::MappedTensor bwd, ade::TensT args, size_t idx)
 {{
 switch (fwd->get_opcode().code_)
 {{
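For orientation, here is a hand-expanded sketch of the source file this template now emits. Only the signature and the dispatch skeleton appear in the diff; the EMINEM case body and the default branch are invented for illustration (real case bodies are spliced in from each project's config, and the mock's actual EMINEM rule differs):

#ifdef _GENERATED_GRADER_HPP
namespace age
{
ade::TensptrT RuleSet::chain_rule (ade::iFunctor* fwd,
	ade::MappedTensor bwd, ade::TensT args, size_t idx)
{
	switch (fwd->get_opcode().code_)
	{
		// hypothetical rule: d(fwd)/d(x) = 1 * bwd, so consolidate the
		// mapped upstream gradient under a sum and pass it through
		case EMINEM:
			return ade::TensptrT(ade::Functor::get(sum_opcode(), {bwd}));
		default:
			throw std::runtime_error("no chain rule for opcode");
	}
}
}
#endif // _GENERATED_GRADER_HPP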
1 change: 0 additions & 1 deletion age/test/mock.json
@@ -17,7 +17,6 @@
 	},
 	"data": {
 		"sum": "EMINEM",
-		"prod": "KHALED",
 		"data_out": "SweetPotato&",
 		"data_in": "Pomegranate&",
 		"scalarize": "ade::LeafptrT(new MockTensor(scalar, shape))"
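Cross-referencing the template above: the {sum} placeholder is filled from this file's data.sum entry, so with this config the generated header's opcode accessor expands to the following (hand-expanded here for illustration, not part of the diff):

ade::Opcode sum_opcode (void) override
{
	return ade::Opcode{"EMINEM", EMINEM};
}

Dropping the "prod" entry is safe because nothing generates prod_opcode anymore.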
11 changes: 4 additions & 7 deletions age/test/ptest.py
@@ -350,12 +350,8 @@
 		return ade::Opcode{"ADDITION", ADDITION};
 	}
-	ade::Opcode prod_opcode (void) override
-	{
-		return ade::Opcode{"MULTIPLICATION", MULTIPLICATION};
-	}
-	ade::TensptrT grad_rule (ade::iFunctor* fwd, ade::TensT args, size_t idx) override;
+	ade::TensptrT chain_rule (ade::iFunctor* fwd,
+		ade::MappedTensor bwd, ade::TensT args, size_t idx) override;
 };
 }
@@ -368,7 +364,8 @@
 namespace age
 {
-ade::TensptrT RuleSet::grad_rule (ade::iFunctor* fwd, ade::TensT args, size_t idx)
+ade::TensptrT RuleSet::chain_rule (ade::iFunctor* fwd,
+	ade::MappedTensor bwd, ade::TensT args, size_t idx)
 {
 switch (fwd->get_opcode().code_)
 {
15 changes: 4 additions & 11 deletions age/test/test_grader.cpp
@@ -32,15 +32,6 @@ TEST(AGE, RulesSum)
 }


-TEST(AGE, RulesProd)
-{
-	age::RuleSet rule;
-	auto code = rule.prod_opcode();
-	EXPECT_STREQ("KHALED", code.name_.c_str());
-	EXPECT_EQ(age::KHALED, code.code_);
-}
-
-
 TEST(AGE, GraderEminem)
 {
 	age::RuleSet rule;
@@ -49,7 +40,8 @@ TEST(AGE, GraderEminem)
 	ade::Functor* fwd = ade::Functor::get(ade::Opcode{"EMINEM", age::EMINEM},
 		{ade::identity_map(arg)});
 	size_t idx = 42;
-	rule.grad_rule(fwd, {arg}, idx);
+	// bwd is never used here, so any mapped tensor will do
+	rule.chain_rule(fwd, ade::identity_map(arg), {arg}, idx);
 	EXPECT_EQ(idx, mock->scalar_);
 	delete fwd;
 }
@@ -63,7 +55,8 @@ TEST(AGE, GraderKhaled)
 	ade::Functor* fwd = ade::Functor::get(ade::Opcode{"KHALED", age::KHALED},
 		{ade::identity_map(arg)});
 	size_t idx = 63;
-	rule.grad_rule(fwd, {arg}, idx);
+	// bwd is never used here, so any mapped tensor will do
+	rule.chain_rule(fwd, ade::identity_map(arg), {arg}, idx);
 	EXPECT_EQ(idx + khaled_constant, mock->scalar_);
 	delete fwd;
 }
12 changes: 6 additions & 6 deletions bwd/grader.hpp
@@ -27,12 +27,12 @@ struct iRuleSet
 	/// Return opcode representing n-ary sum
 	virtual ade::Opcode sum_opcode (void) = 0;

-	/// Return opcode representing binary multiplication
-	virtual ade::Opcode prod_opcode (void) = 0;
-
-	/// Return chain rule of operation with respect to argument at idx
-	/// specified by code given args
-	virtual ade::TensptrT grad_rule (ade::iFunctor* fwd, ade::TensT args, size_t idx) = 0;
+	/// Return d(fwd)/d(x) given:
+	///   bwd = d(args[idx])/d(x)
+	/// Generally,
+	///   d(fwd)/d(x) = rule(fwd, args, idx) * reduction_consolidation(bwd)
+	virtual ade::TensptrT chain_rule (ade::iFunctor* fwd,
+		ade::MappedTensor bwd, ade::TensT args, size_t idx) = 0;
 };

 /// Traveler to obtain derivative of accepted node with respect to target
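This interface change is the heart of the rehaul: grad_rule returned only the local gradient and left the caller to multiply in the upstream term, while chain_rule receives the mapped upstream gradient bwd and returns the finished product. Below is a minimal sketch of an implementation, assuming a hypothetical MULTIPLICATION opcode and a hypothetical local_rule helper (neither is prescribed by the interface); iRuleSet's other members are elided.

struct SketchRuleSet : public age::iRuleSet
{
	// ... data, sum_opcode, etc. as before ...

	ade::TensptrT chain_rule (ade::iFunctor* fwd,
		ade::MappedTensor bwd, ade::TensT args, size_t idx) override
	{
		// local = d(fwd)/d(args[idx]): the per-opcode logic grad_rule used to return
		ade::TensptrT local = local_rule(fwd, args, idx);
		// reduction_consolidation(bwd): wrap the mapped gradient under a sum
		ade::TensptrT consolidated(ade::Functor::get(sum_opcode(), {bwd}));
		// chain: d(fwd)/d(x) = local * consolidated
		return ade::TensptrT(ade::Functor::get(
			ade::Opcode{"MULTIPLICATION", MULTIPLICATION},
			{ade::identity_map(local), ade::identity_map(consolidated)}));
	}
};

This is exactly the product Grader used to assemble inline (see grader.cpp below); moving it behind the interface is what lets prod_opcode disappear.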
26 changes: 6 additions & 20 deletions bwd/src/grader.cpp
@@ -8,7 +8,6 @@ namespace age
 void Grader::visit (ade::iFunctor* func)
 {
 	const ade::Opcode sum = rules_->sum_opcode();
-	const ade::Opcode prod = rules_->prod_opcode();
 	if (func == target_)
 	{
 		derivatives_.emplace(func, rules_->data(1, target_->shape()));
@@ -50,8 +49,7 @@ void Grader::visit (ade::iFunctor* func)
 	{
 		ade::TensT& gargs = grads[parent];
 		ade::TensptrT bwd = gargs.size() > 1 ? ade::TensptrT(
-			ade::Functor::get(sum, to_args(gargs))) :
-			gargs[0];
+			ade::Functor::get(sum, to_args(gargs))) : gargs[0];

 		auto& grad_indices = pathmap[parent];
 		ade::ArgsT children = parent->get_children();
@@ -81,31 +79,19 @@ void Grader::visit (ade::iFunctor* func)
 				args.push_back(ade::TensptrT(ade::Functor::get(sum, {
 					ade::MappedTensor(
 						ade::TensptrT(ade::Functor::get(sum, {kid})),
-						revshaper,
-						revmapper,
-						revcoorder)
+						revshaper, revmapper, revcoorder)
 				})));
 			}
 		}
-		// pass down forward-gradient pair
-		ade::TensptrT grad(rules_->grad_rule(parent, args, i));
-
-		grads[child.get_tensor().get()].push_back(ade::TensptrT(
-			ade::Functor::get(prod, {
-				ade::identity_map(grad),
-				ade::identity_map(ade::TensptrT(
-					ade::Functor::get(sum, {
-						ade::MappedTensor(bwd, revshaper,
-							revmapper, revcoorder)
-					})
-				)),
-			})));
+		ade::MappedTensor lhs(bwd, revshaper, revmapper, revcoorder);

+		grads[child.get_tensor().get()].push_back(rules_->chain_rule(parent, lhs, args, i));
 	}
 }
 auto finalgargs = grads[target_];
 derivatives_.emplace(func, finalgargs.size() > 1 ? ade::TensptrT(
-	ade::Functor::get(rules_->sum_opcode(),
-		to_args(finalgargs))) : finalgargs[0]);
+	ade::Functor::get(sum, to_args(finalgargs))) : finalgargs[0]);
 }

 ade::ArgsT to_args (ade::TensT tens)
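Pulling the fragmented hunk together, the per-child gradient update went from the removed form to the added form as follows (reassembled from the lines above; the enclosing loop is elided):

// before: Grader multiplied the local gradient with the remapped,
// sum-consolidated upstream gradient itself, hence prod_opcode
ade::TensptrT grad(rules_->grad_rule(parent, args, i));
grads[child.get_tensor().get()].push_back(ade::TensptrT(
	ade::Functor::get(prod, {
		ade::identity_map(grad),
		ade::identity_map(ade::TensptrT(ade::Functor::get(sum, {
			ade::MappedTensor(bwd, revshaper, revmapper, revcoorder)
		}))),
	})));

// after: Grader only remaps the upstream gradient into the child's
// coordinate space; chaining is now the ruleset's responsibility
ade::MappedTensor lhs(bwd, revshaper, revmapper, revcoorder);
grads[child.get_tensor().get()].push_back(
	rules_->chain_rule(parent, lhs, args, i));

The multiplication has not disappeared; it has moved behind iRuleSet, which is also why the RulesProd test and mock.json's "prod" entry are gone.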
(diffs for the remaining 2 changed files did not load)
