avoid optimizing during gradient
raggledodo committed Jan 2, 2019
1 parent 18e97c2 commit fc45269
Showing 4 changed files with 118 additions and 139 deletions.
32 changes: 19 additions & 13 deletions ade/cmap.hpp
@@ -42,7 +42,25 @@ struct MappedTensor final
}

/// Return shape of tensor filtered through coordinate mapper
Shape shape (void) const;
Shape shape (void) const
{
ade::Shape shape = tensor_->shape();
CoordT out;
CoordT in;
std::copy(shape.begin(), shape.end(), in.begin());
shaper_->forward(out.begin(), in.begin());
std::vector<DimT> slist(rank_cap);
std::transform(out.begin(), out.end(), slist.begin(),
[](CDimT cd) -> DimT
{
if (cd < 0)
{
cd = -cd - 1;
}
return std::round(cd);
});
return Shape(slist);
}

TensptrT get_tensor (void) const
{
@@ -65,18 +83,6 @@ struct MappedTensor final
return coorder_;
}

/// Return MappedTensor connecting this instance to lhs'
/// shaper and coorder info
MappedTensor connect (MappedTensor lhs) const;

/// Return MappedTensor taking input tens and reverse of
/// this instance's shaper and coorder info
MappedTensor reverse (TensptrT tens) const
{
return MappedTensor(tens, CoordptrT(shaper_->reverse()),
!map_io_, coorder_);
}

private:
/// Tensor reference
TensptrT tensor_;
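
Note (not part of the commit): the shape computation that used to live in cmap.cpp is now inlined in the header as MappedTensor::shape. A minimal standalone sketch of what that computation does, with a hypothetical doubling transform standing in for ade's shaper and the same rank_cap/CoordT conventions assumed:

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr size_t rank_cap = 8;
using CDimT = double;
using DimT = uint16_t;
using CoordT = std::array<CDimT, rank_cap>;

// hypothetical forward transform standing in for shaper_->forward:
// double every dimension (e.g. an "extend"-like map)
void forward (CoordT& out, const CoordT& in)
{
	std::transform(in.begin(), in.end(), out.begin(),
		[](CDimT d) { return d * 2; });
}

int main ()
{
	CoordT in = {2, 3, 1, 1, 1, 1, 1, 1};
	CoordT out;
	forward(out, in);
	std::vector<DimT> slist(rank_cap);
	std::transform(out.begin(), out.end(), slist.begin(),
		[](CDimT cd) -> DimT
		{
			if (cd < 0)
			{
				cd = -cd - 1; // negative coordinates are folded back to a valid size
			}
			return std::round(cd);
		});
	for (DimT d : slist)
	{
		std::cout << d << " ";
	}
	std::cout << "\n"; // prints 4 6 2 2 2 2 2 2
}

The body mirrors the inlined shape(): push the input dimensions through the forward coordinate map, fold negatives, round, and rebuild a Shape from the result.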
42 changes: 0 additions & 42 deletions ade/src/cmap.cpp
@@ -5,48 +5,6 @@
namespace ade
{

static Shape calc_shape (CoordptrT shaper, const Shape& shape)
{
CoordT out;
CoordT in;
std::copy(shape.begin(), shape.end(), in.begin());
shaper->forward(out.begin(), in.begin());
std::vector<DimT> slist(rank_cap);
std::transform(out.begin(), out.end(), slist.begin(),
[](CDimT cd) -> DimT
{
if (cd < 0)
{
cd = -cd - 1;
}
return std::round(cd);
});
return Shape(slist);
}

Shape MappedTensor::shape (void) const
{
return calc_shape(shaper_, tensor_->shape());
}

MappedTensor MappedTensor::connect (MappedTensor lhs) const
{
CoordptrT outshaper(shaper_->connect(*lhs.get_shaper()));
Shape inshape = tensor_->shape();
Shape outshape = calc_shape(outshaper, inshape);
bool outmap_io = inshape.n_elems() > outshape.n_elems();
CoordptrT rhs_coorder = outmap_io == map_io_ ? coorder_ :
CoordptrT(coorder_->reverse());
CoordptrT lhs_coorder = lhs.get_coorder();
if (outmap_io != lhs.map_io())
{
lhs_coorder = CoordptrT(lhs_coorder->reverse());
}
CoordptrT outcoorder(outmap_io ? rhs_coorder->connect(*lhs_coorder) :
lhs_coorder->connect(*rhs_coorder));
return MappedTensor(tensor_, outshaper, outmap_io, outcoorder);
}

MappedTensor identity_map (TensptrT tensor)
{
return MappedTensor(tensor, ade::identity);
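
Note (not part of the commit): the removed connect() decided the mapping direction at gradient-build time by comparing element counts before and after the composed shaper, which is exactly the kind of eager simplification the commit title avoids. A rough standalone sketch of that heuristic, with plain vectors standing in for ade::Shape:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// element count of a shape, analogous to Shape::n_elems
size_t n_elems (const std::vector<size_t>& shape)
{
	return std::accumulate(shape.begin(), shape.end(),
		size_t(1), std::multiplies<size_t>());
}

// connect() picked map_io this way: a composed map that shrinks the
// element count is treated as mapping output -> input
bool choose_map_io (const std::vector<size_t>& inshape,
	const std::vector<size_t>& outshape)
{
	return n_elems(inshape) > n_elems(outshape);
}

int main ()
{
	std::vector<size_t> in = {4, 3};
	std::vector<size_t> out = {4, 1}; // e.g. a reduction along the second dimension
	std::cout << std::boolalpha << choose_map_io(in, out) << "\n"; // true
}

With connect() gone, the gradient pass no longer makes this choice while building the graph; the caller supplies the reversed shaper and io flag explicitly, as the grader.cpp hunk below shows.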
25 changes: 18 additions & 7 deletions bwd/src/grader.cpp
@@ -7,6 +7,8 @@ namespace age

void Grader::visit (ade::iFunctor* func)
{
const ade::Opcode sum = rules_->sum_opcode();
const ade::Opcode prod = rules_->prod_opcode();
if (func == target_)
{
derivatives_.emplace(func, rules_->data(1, target_->shape()));
@@ -48,7 +50,7 @@ void Grader::visit (ade::iFunctor* func)
{
ade::TensT& gargs = grads[parent];
ade::TensptrT bwd = gargs.size() > 1 ? ade::TensptrT(
ade::Functor::get(rules_->sum_opcode(), to_args(gargs))) :
ade::Functor::get(sum, to_args(gargs))) :
gargs[0];

auto& grad_indices = pathmap[parent];
@@ -62,7 +64,9 @@
{
ade::TensT args;
ade::MappedTensor& child = children[i];
ade::MappedTensor mapped_bwd = child.reverse(bwd);
ade::CoordptrT revshaper(child.get_shaper()->reverse());
bool revmapper = !child.map_io();
ade::CoordptrT revcoorder = child.get_coorder();
for (size_t j = 0; j < nchildren; ++j)
{
ade::MappedTensor& kid = children[j];
@@ -74,19 +78,26 @@
else
{
// reverse children[j] to child's shape/coord space
args.push_back(ade::TensptrT(
ade::Functor::get(rules_->sum_opcode(), {
kid.connect(mapped_bwd)})));
args.push_back(ade::TensptrT(ade::Functor::get(sum, {
ade::MappedTensor(
ade::TensptrT(ade::Functor::get(sum, {kid})),
revshaper,
revmapper,
revcoorder)
})));
}
}
// pass down forward-gradient pair
ade::TensptrT grad(rules_->grad_rule(parent, args, i));

grads[child.get_tensor().get()].push_back(ade::TensptrT(
ade::Functor::get(rules_->prod_opcode(), {
ade::Functor::get(prod, {
ade::identity_map(grad),
ade::identity_map(ade::TensptrT(
ade::Functor::get(rules_->sum_opcode(), {mapped_bwd})
ade::Functor::get(sum, {
ade::MappedTensor(bwd, revshaper,
revmapper, revcoorder)
})
)),
})));
}
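
Note (not part of the commit): the behavioural change here is that the backward tensor is no longer pre-composed with each sibling's mapper via connect()/reverse(); the gradient pass now wraps it in a fresh MappedTensor carrying the reversed shaper and leaves the layers un-fused for a later pass to simplify. A standalone analogy (not ade code) of eager composition versus deferred layering:

#include <functional>
#include <iostream>
#include <memory>
#include <vector>

using Coord = std::vector<double>;
using MapFn = std::function<Coord(const Coord&)>;

// a node that wraps another node with a coordinate transform
struct Mapped
{
	std::shared_ptr<Mapped> inner;
	MapFn map;
};

// eager approach: fold both transforms into a single wrapper at build time
std::shared_ptr<Mapped> connect_eager (std::shared_ptr<Mapped> base,
	MapFn outer, MapFn inner)
{
	return std::make_shared<Mapped>(Mapped{base,
		[outer, inner](const Coord& c) { return outer(inner(c)); }});
}

// deferred approach: keep each transform as its own wrapper layer
std::shared_ptr<Mapped> connect_deferred (std::shared_ptr<Mapped> base,
	MapFn outer, MapFn inner)
{
	auto wrapped = std::make_shared<Mapped>(Mapped{base, inner});
	return std::make_shared<Mapped>(Mapped{wrapped, outer});
}

int main ()
{
	auto leaf = std::make_shared<Mapped>(
		Mapped{nullptr, [](const Coord& c) { return c; }});
	MapFn dbl = [](const Coord& c)
		{ Coord o(c); for (auto& d : o) d *= 2; return o; };
	MapFn inc = [](const Coord& c)
		{ Coord o(c); for (auto& d : o) d += 1; return o; };

	auto eager = connect_eager(leaf, dbl, inc);
	auto deferred = connect_deferred(leaf, dbl, inc);

	int layers = 0;
	for (auto n = deferred; n && n->inner; n = n->inner) ++layers;
	// the eager graph has one wrapper above the leaf; the deferred graph keeps two,
	// so a later optimization stage can still decide whether to fuse them
	std::cout << "deferred wrapper layers: " << layers << "\n"; // prints 2
	(void) eager;
}

Both graphs compute the same mapping; the difference is only when (and whether) the composition is collapsed, which matches the commit's intent of not optimizing while the gradient graph is being built.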
