This repository has been archived by the owner on Feb 29, 2020. It is now read-only.

Commit 0b53765: fix reduce_prod
raggledodo committed Jan 4, 2019
1 parent a27d0fc commit 0b53765
Showing 5 changed files with 20 additions and 37 deletions.
bwd/src/grader.cpp (4 changes: 1 addition & 3 deletions)
@@ -79,9 +79,7 @@ void Grader::visit (ade::iFunctor* func)
args.push_back(ade::TensptrT(ade::Functor::get(sum, {
ade::MappedTensor(
ade::TensptrT(ade::Functor::get(sum, {kid})),
- revshaper,
- revmapper,
- revcoorder)
+ revshaper, revmapper, revcoorder)
})));
}
}
llo/cfg/llo.json (16 changes: 1 addition & 15 deletions)
@@ -79,7 +79,7 @@
},
"PROD": {
"operation": "llo::mul((T*)out,shape,llo::to_refs<T>(in))",
"derivative": "mul(llo::grad_prod(idx,args),ade::TensptrT(ade::Functor::get(ade::Opcode{\"PROD\",PROD},{bwd})))"
"derivative": "mul(llo::grad_prod(fwd,idx,args),ade::TensptrT(ade::Functor::get(ade::Opcode{\"PROD\",PROD},{bwd})))"
},
"DIV": {
"operation": "llo::div((T*)out,shape,llo::to_ref<T>(in[0]),llo::to_ref<T>(in[1]))",
@@ -458,20 +458,6 @@
}],
"out": "ade::TensptrT(ade::Functor::get(ade::Opcode{\"SUM\",SUM},{ade::extend_map(arg1,arg2,arg3)}))"
},
- {
- "name": "pextend",
- "args": [{
- "dtype": "ade::TensptrT",
- "name": "arg1"
- }, {
- "dtype": "uint8_t",
- "name": "arg2"
- }, {
- "dtype": "std::vector<uint8_t>",
- "name": "arg3"
- }],
- "out": "ade::TensptrT(ade::Functor::get(ade::Opcode{\"PROD\",PROD},{ade::extend_map(arg1,arg2,arg3)}))"
- },
{
"name": "reduce_sum",
"args": [{
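
For reference, the new derivative string in the PROD entry above passes the forward PROD node (fwd) into llo::grad_prod. The change leans on the standard product-rule identity: the partial derivative of a product with respect to one factor is the product of the remaining factors, which equals the whole product divided by that factor whenever the factor is nonzero. In LaTeX form:

\frac{\partial}{\partial x_i} \prod_{j} x_j \;=\; \prod_{j \neq i} x_j \;=\; \frac{\prod_{j} x_j}{x_i}, \qquad x_i \neq 0

This also lines up with dropping the pextend generator above: instead of a PROD-flavoured extend, the new grad_prod in llo/src/helper.cpp re-extends the reshaped forward product with an ordinary SUM functor before the division.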
llo/helper.hpp (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@ ade::TensptrT mtens_mul (ade::TensptrT lhs, ade::MappedTensor rhs);

/// Return the gradient for prod operation assuming the target derived wrt is
/// index gradidx and arguments are tens
- ade::TensptrT grad_prod (size_t gradidx, ade::TensT tens)
+ ade::TensptrT grad_prod (ade::iFunctor* fwd, size_t gradidx, ade::TensT tens)

/// Return the gradient for min operation assuming the target derived wrt is
/// index gradidx and arguments are tens
llo/src/helper.cpp (22 changes: 14 additions & 8 deletions)
@@ -15,15 +15,21 @@ ade::TensptrT mtens_mul (ade::TensptrT lhs, ade::MappedTensor rhs)
}));
}

- ade::TensptrT grad_prod (size_t gradidx, ade::TensT tens)
+ ade::TensptrT grad_prod (ade::iFunctor* fwd, size_t gradidx, ade::TensT tens)
{
- ade::Shape shape = tens[gradidx]->shape();
- tens.erase(tens.begin() + gradidx);
- if (tens.size() > 0)
- {
- return age::prod(tens);
- }
- return llo::data(1, shape);
+ auto fwd_children = fwd->get_children();
+ ade::TensptrT fwd_cpy(ade::Functor::get(
+ fwd->get_opcode(), fwd_children));
+
+ auto& fwd_child = fwd_children[gradidx];
+ ade::MappedTensor fwd_mapped(fwd_cpy,
+ ade::CoordptrT(fwd_child.get_shaper()->reverse()),
+ !fwd_child.map_io(), fwd_child.get_coorder());
+
+ ade::TensptrT fwd_extended(
+ ade::Functor::get(ade::Opcode{"SUM", age::SUM}, {fwd_mapped}));
+
+ return age::div(fwd_extended, tens[gradidx]);
}

ade::TensptrT grad_min (ade::iFunctor* fwd, size_t gradidx, ade::TensT tens)
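
In plain terms, the rewritten grad_prod above copies the forward PROD node, reverse-maps that copy back to the shape of the argument at gradidx through a SUM functor, and divides by the argument itself, instead of multiplying together the remaining arguments as the old body did. A minimal standalone sketch of the same idea on flat scalars (hypothetical names, not the ade/llo API; it assumes the differentiated factor is nonzero, which the division form requires):

#include <cassert>
#include <cstddef>
#include <vector>

// Scalar illustration of the grad_prod rewrite:
// d/dx_i (x_0 * x_1 * ... * x_{n-1}) = (full forward product) / x_i.
double grad_prod_scalar (const std::vector<double>& xs, std::size_t gradidx)
{
    assert(gradidx < xs.size());
    double fwd = 1.0; // stands in for reusing the forward PROD result
    for (double x : xs)
    {
        fwd *= x;
    }
    assert(xs[gradidx] != 0); // the division form breaks down at zero
    return fwd / xs[gradidx];
}

int main ()
{
    std::vector<double> xs = {2.0, 3.0, 4.0};
    // d/dx_1 (2 * 3 * 4) = 2 * 4 = 8
    assert(grad_prod_scalar(xs, 1) == 8.0);
    return 0;
}

One caveat worth noting: unlike the old product-of-the-rest form, the division form is undefined when the differentiated argument contains zeros.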
llo/test/ptest.py (13 changes: 3 additions & 10 deletions)
@@ -367,7 +367,6 @@ def test_extend(self):
expected_out = np.array(list(data) * 3).reshape([3, 2])
var = llo.variable(data, 'var')

- # extend's derivative equates to reduce_sum
out = age.extend(var, 1, [3])
fout = llo.evaluate(out)
self._array_eq(expected_out, fout)
@@ -376,18 +375,12 @@ def test_extend(self):
der = llo.evaluate(ex)
self._array_eq(np.array([3, 3]), der)

- # pextend's derivative equates to reduce_prod
- out2 = age.pextend(var, 1, [3])
- fout2 = llo.evaluate(out2)
- self._array_eq(expected_out, fout2)
-
- ex2 = llo.derive(out2, var)
- der2 = llo.evaluate(ex2)
- self._array_eq(np.array([1, 1]), der2)

def test_rsum(self):
self._common_reduce(age.reduce_sum0, age.reduce_sum, tf.reduce_sum)

+ def test_rprod(self):
+ self._common_reduce(age.reduce_prod0, age.reduce_prod, tf.reduce_prod)

def test_rmin(self):
self._common_reduce(age.reduce_min0, age.reduce_min, tf.reduce_min)

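
The new test_rprod routes reduce_prod through the same _common_reduce harness used by the other reductions, comparing against tf.reduce_prod, and replaces the removed pextend assertions. Independently of that harness, the divide-by-argument gradient is easy to sanity-check numerically: the product is linear in each individual factor, so a central finite difference recovers fwd / x_i essentially exactly. A small self-contained check (plain C++, hypothetical names, not part of the llo test suite):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Product of every element: a full reduce_prod down to a scalar.
static double reduce_prod (const std::vector<double>& xs)
{
    double out = 1.0;
    for (double x : xs)
    {
        out *= x;
    }
    return out;
}

int main ()
{
    std::vector<double> xs = {1.5, -2.0, 0.25, 3.0};
    double fwd = reduce_prod(xs);
    double eps = 1e-6;
    for (std::size_t i = 0; i < xs.size(); ++i)
    {
        // analytic gradient in the commit's formulation: fwd / x_i
        double analytic = fwd / xs[i];

        // central finite difference as an independent reference
        std::vector<double> hi = xs, lo = xs;
        hi[i] += eps;
        lo[i] -= eps;
        double numeric = (reduce_prod(hi) - reduce_prod(lo)) / (2 * eps);

        assert(std::fabs(analytic - numeric) < 1e-4);
    }
    return 0;
}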
