Skip to content

Commit

Permalink
remove redundencies
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Jul 24, 2019
1 parent 66d5451 commit e8c31c1
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 67 deletions.
13 changes: 10 additions & 3 deletions ade/traveler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
namespace ade
{

/// Extremely generic traveler that visits every node in the graph once
struct OnceTraveler : public iTraveler
{
virtual ~OnceTraveler (void) = default;
Expand All @@ -42,10 +43,16 @@ struct OnceTraveler : public iTraveler
}
}

protected:
virtual void visit_leaf (iLeaf* leaf) = 0;
virtual void visit_leaf (iLeaf* leaf) {} // do nothing

virtual void visit_func (iFunctor* func) = 0;
virtual void visit_func (iFunctor* func)
{
auto& children = func->get_children();
for (auto child : children)
{
child.get_tensor()->accept(*this);
}
}

std::unordered_set<iTensor*> visited_;
};
Expand Down
7 changes: 3 additions & 4 deletions cfg/ead.yml
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ api:
fwd[ade::rank_cap][dimension] =
extent - arg->shape().at(dimension);
}),
std::make_shared<ead::CoordMap>(ead::SLICE, slicings, false)
std::make_shared<ead::CoordMap>(slicings, false)
)
});
- template: typename T
Expand Down Expand Up @@ -514,7 +514,7 @@ api:
fwd[ade::rank_cap][dimension] =
padding.first + padding.second;
}),
std::make_shared<ead::CoordMap>(ead::PAD, paddings, false)
std::make_shared<ead::CoordMap>(paddings, false)
)
});
- template: typename T
Expand Down Expand Up @@ -618,8 +618,7 @@ api:
return ead::make_functor<T>(ade::Opcode{"CONV",::age::CONV}, {
ead::FuncArg<T>(input, input_shaper, nullptr),
ead::FuncArg<T>(kernel, kernel_shaper,
std::make_shared<ead::CoordMap>(
ead::CONV, kernel_dims, true)),
std::make_shared<ead::CoordMap>(kernel_dims, true)),
});
- template: typename T
name: reduce_sum_1d
Expand Down
2 changes: 1 addition & 1 deletion dbg/grpc/session.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ struct InteractiveSession final : public ead::iSession
}

// basic copy over from session::update
ead::Traveler traveler;
ade::OnceTraveler traveler;
for (auto& tens : targeted)
{
tens->accept(traveler);
Expand Down
21 changes: 0 additions & 21 deletions ead/coord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,11 @@
namespace ead
{

enum TransCode
{
EXTEND = 0,
PERMUTE,
REDUCE,
CONV,
SLICE,
PAD,
};

struct CoordMap final : public ade::iCoordMap
{
CoordMap (ade::CoordT indices, bool bijective) :
indices_(indices), bijective_(bijective) {}

// todo: deprecate
CoordMap (TransCode transcode, ade::CoordT indices, bool bijective) :
transcode_(transcode), indices_(indices), bijective_(bijective) {}

ade::iCoordMap* connect (const ade::iCoordMap& rhs) const override
{
return nullptr;
Expand Down Expand Up @@ -53,14 +39,7 @@ struct CoordMap final : public ade::iCoordMap
return bijective_;
}

TransCode transcode (void) const
{
return transcode_;
}

private:
TransCode transcode_;

ade::CoordT indices_;

bool bijective_;
Expand Down
Binary file modified ead/data/graph.pb
Binary file not shown.
4 changes: 2 additions & 2 deletions ead/grader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ NodeptrT<T> reduce_grad (const ade::FuncArg& child,
bcast[d] = shape.at(d);
}
}
revcoord = std::make_shared<CoordMap>(EXTEND, bcast, false);
revcoord = std::make_shared<CoordMap>(bcast, false);
}
return make_functor<T>(ade::Opcode{"EXTEND",age::EXTEND}, {
FuncArg<T>(bwd, revshaper, revcoord)
Expand All @@ -66,7 +66,7 @@ NodeptrT<T> permute_grad (ade::iFunctor* fwd,
{
order[dims[i]] = i;
}
revcoord = std::make_shared<CoordMap>(PERMUTE, order, true);
revcoord = std::make_shared<CoordMap>(order, true);
}
return make_functor<T>(ade::Opcode{"PERMUTE",age::PERMUTE},{
FuncArg<T>(bwd, revshaper, revcoord)
Expand Down
9 changes: 3 additions & 6 deletions ead/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ struct EADSaver final : public pbm::iSaver
}
ade::CoordT coord;
mapper->forward(coord.begin(), coord.begin());
std::vector<double> out(coord.begin(), coord.end());
out.push_back(static_cast<CoordMap*>(mapper.get())->transcode());
return out;
return std::vector<double>(coord.begin(), coord.end());
}
};

Expand Down Expand Up @@ -158,16 +156,15 @@ struct EADLoader final : public pbm::iLoader
{
return nullptr;
}
if (ade::rank_cap + 1 != coord.size())
if (ade::rank_cap + 1 < coord.size())
{
logs::fatal("cannot deserialize non-vector coordinate map");
}
bool is_bijective = false == estd::has(non_bijectives, age::get_op(opname));
ade::CoordT indices;
auto cit = coord.begin();
std::copy(cit, cit + ade::rank_cap, indices.begin());
TransCode tcode = (TransCode) coord[ade::rank_cap];
return std::make_shared<CoordMap>(tcode, indices, is_bijective);
return std::make_shared<CoordMap>(indices, is_bijective);
}
};

Expand Down
28 changes: 1 addition & 27 deletions ead/session.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,6 @@ struct SizeT
operator size_t() const { return d; }
};

// todo: give this more reasons for existence
struct Traveler final : public ade::iTraveler
{
/// Implementation of iTraveler
void visit (ade::iLeaf* leaf) override
{
visited_.emplace(leaf);
}

/// Implementation of iTraveler
void visit (ade::iFunctor* func) override
{
if (false == estd::has(visited_, func))
{
visited_.emplace(func);
auto& children = func->get_children();
for (auto& child : children)
{
child.get_tensor()->accept(*this);
}
}
}

TensSetT visited_;
};

// for each leaf node, iteratively update the parents
// don't update parent node if it is part of ignored set
struct Session final : public iSession
Expand Down Expand Up @@ -154,7 +128,7 @@ struct Session final : public iSession

void update_target (TensSetT target, TensSetT updated = {}) override
{
Traveler targetted;
ade::OnceTraveler targetted;
for (auto& tens : target)
{
tens->accept(targetted);
Expand Down
6 changes: 3 additions & 3 deletions ead/src/coord.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ CoordptrT reduce (std::vector<ade::RankT> red_dims)
auto it = rdims.begin();
std::fill(it, rdims.end(), ade::rank_cap);
std::copy(red_dims.begin(), red_dims.end(), it);
return std::make_shared<CoordMap>(REDUCE, rdims, false);
return std::make_shared<CoordMap>(rdims, false);
}

CoordptrT extend (ade::RankT rank, std::vector<ade::DimT> ext)
Expand All @@ -52,7 +52,7 @@ CoordptrT extend (ade::RankT rank, std::vector<ade::DimT> ext)
auto it = bcast.begin();
std::fill(it, bcast.end(), 1);
std::copy(ext.begin(), ext.end(), it + rank);
return std::make_shared<CoordMap>(EXTEND, bcast, false);
return std::make_shared<CoordMap>(bcast, false);
}

CoordptrT permute (std::vector<ade::RankT> dims)
Expand Down Expand Up @@ -84,7 +84,7 @@ CoordptrT permute (std::vector<ade::RankT> dims)

ade::CoordT order;
std::copy(dims.begin(), dims.end(), order.begin());
return std::make_shared<CoordMap>(PERMUTE, order, true);
return std::make_shared<CoordMap>(order, true);
}

}
Expand Down
29 changes: 29 additions & 0 deletions opt/parse/test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,15 @@ static std::vector<double> vectorize (::NumList* lst)
TEST(PARSE, SymbolFail)
{
const char* symbs = "symbol Apple Banana;\nsymbol Citrus Zucchini;";
const char* long_symb = "symbol AppleBananaCitrusZucchiniLemonGrapefruit;";

::PtrList* stmts = nullptr;
int status = ::parse_str(&stmts, symbs);
EXPECT_EQ(1, status);

::PtrList* stmts2 = nullptr;
int status2 = ::parse_str(&stmts2, long_symb);
EXPECT_EQ(1, status2);
}


Expand Down Expand Up @@ -143,6 +148,7 @@ TEST(PARSE, EdgeDef)
const char* shape_edge = "F(X={shaper:[4,5,6,7,8,9,10,11]})=>1;\n";
const char* coord_edge = "F(X={coorder:[8,8,8,8,8,8,8,8]})=>2;\n";
const char* both_edges = "F(X={coorder:[8,8,8,8,8,8,8,8],shaper:[4,5,6,7,8,9,10,11]})=>3;\n";
const char* both_edges2 = "F(X={shaper:[4,5,6,7,8,9,10,11],coorder:[8,8,8,8,8,8,8,8]})=>3;\n";
std::vector<double> expect_shaper = {4,5,6,7,8,9,10,11};
std::vector<double> expect_coorder = {8,8,8,8,8,8,8,8};

Expand Down Expand Up @@ -211,6 +217,29 @@ TEST(PARSE, EdgeDef)
coorder = vectorize(arg->coorder_);
EXPECT_ARREQ(expect_shaper, shaper);
EXPECT_ARREQ(expect_coorder, coorder);

ASSERT_EQ(0, ::parse_str(&stmts, both_edges2));
EXPECT_EQ(nullptr, stmts->head_->next_);
stmt = (::Statement*) stmts->head_->val_;
ASSERT_EQ(::CONVERSION, stmt->type_);
conv = (::Conversion*) stmt->val_;
src = conv->source_;
dest = conv->dest_;
ASSERT_EQ(::SCALAR, dest->type_);
EXPECT_EQ(3, dest->val_.scalar_);
ASSERT_EQ(::BRANCH, src->type_);
branch = src->val_.branch_;
EXPECT_STREQ("F", branch->label_);
EXPECT_EQ(nullptr, branch->args_->head_->next_);
arg = (::Arg*) branch->args_->head_->val_;
ASSERT_EQ(::ANY, arg->subgraph_->type_);
EXPECT_STREQ("X", arg->subgraph_->val_.any_);
ASSERT_NE(nullptr, arg->shaper_);
ASSERT_NE(nullptr, arg->coorder_);
shaper = vectorize(arg->shaper_);
coorder = vectorize(arg->coorder_);
EXPECT_ARREQ(expect_shaper, shaper);
EXPECT_ARREQ(expect_coorder, coorder);
}


Expand Down

0 comments on commit e8c31c1

Please sign in to comment.