Skip to content

Commit

Permalink
introduce a version of matmul which does not depend on first arg
Browse files Browse the repository at this point in the history
  • Loading branch information
yoavg committed May 25, 2017
1 parent 55266dc commit 46c9a58
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 3 additions & 1 deletion dynet/dynet.h
Original file line number Diff line number Diff line change
Expand Up @@ -636,8 +636,10 @@ struct Node {

Device* device; /**< pointer to the device, or null to inherit device from first input, or default when there is no input */

unsigned matmul_count; /**< how many matmul nodes built through Expression operator* use this node as their first argument; MatrixMultiply's autobatch_sig/autobatch_concat compare it against a threshold to guess whether the node is shared */

protected:
// NOTE(review): this page renders a diff, so the default constructor shows up twice.
// This first form looks like the pre-commit line; it leaves matmul_count uninitialized.
Node() : args(), device(default_device) {}
// Default constructor: no arguments, device set to default_device, matmul_count starting at 0.
Node() : args(), device(default_device), matmul_count(0) {}
// Builds a node whose arguments are the given list of variable indices, on default_device.
// matmul_count must start at 0 here as well: operator* increments it and MatrixMultiply's
// autobatching reads it, so leaving it uninitialized would yield an indeterminate value.
explicit Node(const std::initializer_list<VariableIndex>& a) : args(a), device(default_device), matmul_count(0) {}
// Builds a node whose arguments are copied from any container exposing begin()/end(),
// on default_device. matmul_count is explicitly zeroed because operator* increments it
// and MatrixMultiply's autobatching reads it; without this it would be indeterminate.
template <typename T>
explicit Node(const T&c) : args(c.begin(), c.end()), device(default_device), matmul_count(0) {}
Expand Down
2 changes: 1 addition & 1 deletion dynet/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Expression operator+(const Expression& x, real y) { return y + x; }
// Difference of two expressions, built as the left operand plus the negated right operand.
Expression operator-(const Expression& x, const Expression& y) {
  const Expression negated_rhs = -y;
  return x + negated_rhs;
}
// Scalar minus expression: adds a ConstantMinusX node over y (with constant x) to y's graph.
Expression operator-(real x, const Expression& y) {
  const auto node_index = y.pg->add_function<ConstantMinusX>({y.i}, x);
  return Expression(y.pg, node_index);
}
// Expression minus scalar, obtained by negating (scalar - expression).
Expression operator-(const Expression& x, real y) {
  const Expression reversed_difference = y - x;
  return -reversed_difference;
}
// NOTE(review): this page renders a diff, so operator* shows up twice. This first form
// looks like the pre-commit line: it adds the MatrixMultiply node without any bookkeeping.
Expression operator*(const Expression& x, const Expression& y) { return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
// Matrix product x * y: increments matmul_count on the node behind the left operand x,
// then adds a MatrixMultiply node over (x, y) to x's graph. Only the left operand is counted.
Expression operator*(const Expression& x, const Expression& y) { x.pg->nodes[x.i]->matmul_count++; return Expression(x.pg, x.pg->add_function<MatrixMultiply>({x.i, y.i})); }
// Expression times a float constant: adds a ConstScalarMultiply node over x to x's graph.
Expression operator*(const Expression& x, float y) {
  const auto scaled_index = x.pg->add_function<ConstScalarMultiply>({x.i}, y);
  return Expression(x.pg, scaled_index);
}
Expression cmult(const Expression& x, const Expression& y) {
if (x.dim().batch_size() == 1)
Expand Down
15 changes: 12 additions & 3 deletions dynet/nodes-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -879,16 +879,25 @@ int MatrixMultiply::autobatch_sig(const ComputationGraph & cg, SigMap &sm) const
// TODO do we want to treat different dimensions of first/second arg differently?
if(dim.bd == 1) {
Sig s(nt::matmul);
s.add_node(args[0]);
// if arg0 is likely to be shared (its matmul_count is above the threshold), include the arg0 node itself plus the dim of arg1 in the sig.
// otherwise, include the dims of both args in the sig.
if (cg.nodes[args[0]]->matmul_count > 2) { //TODO why 2? can we set a better number?
s.add_node(args[0]); s.add_dim(cg.nodes[args[1]]->dim);
} else {
s.add_dim(cg.nodes[args[0]]->dim); s.add_dim(cg.nodes[args[1]]->dim);
}
return sm.get_idx(s);
} else {
return 0; // TODO handle the batched case as well? should it differ at all?
}
}

// Builds the per-argument 0/1 flag vector used when autobatching a matrix multiply.
// All flags start at 0. When the result is unbatched (dim.bd == 1) the flag for arg 1 is
// always set, and the flag for arg 0 is set only when arg 0's matmul_count is <= 2, which
// is the same case in which autobatch_sig keys on arg 0's dim rather than on the node itself.
// NOTE(review): presumably a flag of 1 means "concatenate this argument across the batched
// nodes" -- confirm against the autobatching executor.
std::vector<int> MatrixMultiply::autobatch_concat(const ComputationGraph & cg) const {
// NOTE(review): the next two statements look like the pre-commit body shown by the diff
// view; the statements after them are the post-commit replacement.
vector<int> ret(args.size(), 0);
if (dim.bd == 1) { ret[1] = 1; }
vector<int> ret(2, 0);
if (dim.bd == 1) {
ret[1] = 1;
// the threshold here must stay in sync with the matmul_count test in autobatch_sig
if (cg.nodes[args[0]]->matmul_count <= 2) { ret[0] = 1; }
}
return ret;
}

Expand Down

0 comments on commit 46c9a58

Please sign in to comment.