Skip to content

Commit

Permalink
fix bwd for MappedTensors with dissimilar shaper and coorder
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Jan 2, 2019
1 parent 51506e4 commit 18e97c2
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 48 deletions.
34 changes: 14 additions & 20 deletions ade/cmap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct MappedTensor final
logs::fatal("cannot map a null tensor");
}
map_io_ = tensor_->shape().n_elems() > shape().n_elems();
if (shaper == ade::identity || map_io_)
if (shaper == identity || map_io_)
{
coorder_ = shaper;
}
Expand All @@ -42,25 +42,7 @@ struct MappedTensor final
}

/// Return shape of tensor filtered through coordinate mapper
Shape shape (void) const
{
const Shape& shape = tensor_->shape();
CoordT out;
CoordT in;
std::copy(shape.begin(), shape.end(), in.begin());
shaper_->forward(out.begin(), in.begin());
std::vector<DimT> slist(rank_cap);
std::transform(out.begin(), out.end(), slist.begin(),
[](CDimT cd) -> DimT
{
if (cd < 0)
{
cd = -cd - 1;
}
return std::round(cd);
});
return Shape(slist);
}
Shape shape (void) const;

TensptrT get_tensor (void) const
{
Expand All @@ -83,6 +65,18 @@ struct MappedTensor final
return coorder_;
}

/// Return MappedTensor connecting this instance to lhs'
/// shaper and coorder info
MappedTensor connect (MappedTensor lhs) const;

/// Return MappedTensor wrapping input tens, using the inverse of this
/// instance's shaper, the negated map_io flag, and the same coorder
MappedTensor reverse (TensptrT tens) const
{
	CoordptrT backward_shaper(shaper_->reverse());
	return MappedTensor(tens, backward_shaper, !map_io_, coorder_);
}

private:
/// Tensor reference
TensptrT tensor_;
Expand Down
42 changes: 42 additions & 0 deletions ade/src/cmap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,48 @@
namespace ade
{

/// Return the Shape produced by passing shape's dimensions through
/// shaper's forward coordinate transform
static Shape calc_shape (CoordptrT shaper, const Shape& shape)
{
	// feed the input dimensions through the forward mapping
	CoordT input;
	CoordT output;
	std::copy(shape.begin(), shape.end(), input.begin());
	shaper->forward(output.begin(), input.begin());
	// convert each mapped coordinate to a dimension; negative
	// coordinates are folded back via -cd - 1 before rounding
	auto to_dim = [](CDimT cd) -> DimT
	{
		return std::round(cd < 0 ? -cd - 1 : cd);
	};
	std::vector<DimT> dims(rank_cap);
	std::transform(output.begin(), output.end(), dims.begin(), to_dim);
	return Shape(dims);
}

/// Return shape of the held tensor filtered through this shaper
Shape MappedTensor::shape (void) const
{
	const Shape& tens_shape = tensor_->shape();
	return calc_shape(shaper_, tens_shape);
}

/// Return a MappedTensor holding this instance's tensor, whose shaper and
/// coorder are the composition of this instance's mappers with lhs' mappers
MappedTensor MappedTensor::connect (MappedTensor lhs) const
{
	// compose shapers: this shaper applied first, then lhs' shaper
	CoordptrT outshaper(shaper_->connect(*lhs.get_shaper()));
	Shape inshape = tensor_->shape();
	Shape outshape = calc_shape(outshaper, inshape);
	// map_io convention (mirrors the constructor): true when the input
	// tensor has more elements than the mapped output shape
	bool outmap_io = inshape.n_elems() > outshape.n_elems();
	// if the composed mapping's direction disagrees with this instance's
	// direction, flip this coorder so both coorders run the same way
	CoordptrT rhs_coorder = outmap_io == map_io_ ? coorder_ :
		CoordptrT(coorder_->reverse());
	// likewise flip lhs' coorder when its direction disagrees
	CoordptrT lhs_coorder = lhs.get_coorder();
	if (outmap_io != lhs.map_io())
	{
		lhs_coorder = CoordptrT(lhs_coorder->reverse());
	}
	// composition order depends on direction: input->output applies this
	// coorder first, output->input applies lhs' first
	// NOTE(review): presumably connect order matches shaper composition
	// semantics -- confirm against CoordMap::connect
	CoordptrT outcoorder(outmap_io ? rhs_coorder->connect(*lhs_coorder) :
		lhs_coorder->connect(*rhs_coorder));
	return MappedTensor(tensor_, outshaper, outmap_io, outcoorder);
}

MappedTensor identity_map (TensptrT tensor)
{
return MappedTensor(tensor, ade::identity);
Expand Down
10 changes: 3 additions & 7 deletions bwd/src/grader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void Grader::visit (ade::iFunctor* func)
{
ade::TensT args;
ade::MappedTensor& child = children[i];
ade::CoordptrT bwd_shaper(child.get_shaper()->reverse());
ade::MappedTensor mapped_bwd = child.reverse(bwd);
for (size_t j = 0; j < nchildren; ++j)
{
ade::MappedTensor& kid = children[j];
Expand All @@ -73,12 +73,10 @@ void Grader::visit (ade::iFunctor* func)
}
else
{
ade::CoordptrT shaper(kid.get_shaper()->connect(*bwd_shaper));
// reverse children[j] to child's shape/coord space
args.push_back(ade::TensptrT(
ade::Functor::get(rules_->sum_opcode(), {
ade::MappedTensor(tens, shaper),
})));
kid.connect(mapped_bwd)})));
}
}
// pass down forward-gradient pair
Expand All @@ -88,9 +86,7 @@ void Grader::visit (ade::iFunctor* func)
ade::Functor::get(rules_->prod_opcode(), {
ade::identity_map(grad),
ade::identity_map(ade::TensptrT(
ade::Functor::get(rules_->sum_opcode(), {
ade::MappedTensor(bwd, bwd_shaper),
})
ade::Functor::get(rules_->sum_opcode(), {mapped_bwd})
)),
})));
}
Expand Down
Loading

0 comments on commit 18e97c2

Please sign in to comment.