Skip to content

Commit

Permalink
Swig fixes (#1331)
Browse files Browse the repository at this point in the history
* Fix swig bindings

 - move dynet::detail to expr.h so swig bindings can use it
 - two small fixes (include nodes.h, drop default parameter)

* swig example: set learning rate instead of updateEpoch
  • Loading branch information
akoehn authored and neubig committed Mar 30, 2018
1 parent 060f451 commit 20fd674
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 23 deletions.
3 changes: 2 additions & 1 deletion contrib/swig/dynet_swig.i
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include <stdexcept>
#include "param-init.h"
#include "model.h"
#include "nodes.h"
#include "tensor.h"
#include "dynet.h"
#include "training.h"
Expand Down Expand Up @@ -618,7 +619,7 @@ inline Expression concatenate_cols(const std::vector<Expression>& xs) {
return detail::f<Concatenate>(xs, 1);
};

inline Expression concatenate(const std::vector<Expression>& xs, unsigned d = 0) {
inline Expression concatenate(const std::vector<Expression>& xs, unsigned d) {
return detail::f<Concatenate>(xs, d);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ object XorScala {
ComputationGraph.backward(loss_expr)
sgd.update()
}
sgd.updateEpoch()
sgd.learningRate *= 0.998f
loss /= 4
println("iter = " + iter + ", loss = " + loss)
}
Expand Down
21 changes: 0 additions & 21 deletions dynet/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,27 +280,6 @@ Expression to_device(const Expression & x, Device *device) {
// Functions with variable argument lengths //
////////////////////////////////////////////////

namespace detail {
template <typename F, typename T>
inline Expression f(const T& xs) {
DYNET_ARG_CHECK(xs.size() > 0, "Zero-size argument passed to function");
ComputationGraph *pg = xs.begin()->pg;
std::vector<VariableIndex> xis(xs.size());
int i = 0;
for (auto xi = xs.begin(); xi != xs.end(); ++xi) xis[i++] = xi->i;
return Expression(pg, pg->add_function<F>(xis));
}
template <typename F, typename T, typename T1>
inline Expression f(const T& xs, const T1& arg1) {
DYNET_ARG_CHECK(xs.size() > 0, "Zero-size argument passed to function");
ComputationGraph *pg = xs.begin()->pg;
std::vector<VariableIndex> xis(xs.size());
int i = 0;
for (auto xi = xs.begin(); xi != xs.end(); ++xi) xis[i++] = xi->i;
return Expression(pg, pg->add_function<F>(xis, arg1));
}
} // namespace detail

Expression affine_transform(const std::initializer_list<Expression> &xs) { return detail::f<AffineTransform>(xs); }
Expression affine_transform(const std::vector<Expression> &xs) { return detail::f<AffineTransform>(xs); }

Expand Down
22 changes: 22 additions & 0 deletions dynet/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,28 @@ inline Expression operator*(float y, const Expression& x) { return x * y; }
*/
inline Expression operator/(const Expression& x, float y) { return x * (1.f / y); }

namespace detail {
template <typename F, typename T>
inline Expression f(const T& xs) {
DYNET_ARG_CHECK(xs.size() > 0, "Zero-size argument passed to function");
ComputationGraph *pg = xs.begin()->pg;
std::vector<VariableIndex> xis(xs.size());
int i = 0;
for (auto xi = xs.begin(); xi != xs.end(); ++xi) xis[i++] = xi->i;
return Expression(pg, pg->add_function<F>(xis));
}
template <typename F, typename T, typename T1>
inline Expression f(const T& xs, const T1& arg1) {
DYNET_ARG_CHECK(xs.size() > 0, "Zero-size argument passed to function");
ComputationGraph *pg = xs.begin()->pg;
std::vector<VariableIndex> xis(xs.size());
int i = 0;
for (auto xi = xs.begin(); xi != xs.end(); ++xi) xis[i++] = xi->i;
return Expression(pg, pg->add_function<F>(xis, arg1));
}
} // namespace detail


/**
* \ingroup arithmeticoperations
* \brief Matrix division
Expand Down

0 comments on commit 20fd674

Please sign in to comment.