Skip to content

Commit

Permalink
gaussian error function
Browse files Browse the repository at this point in the history
  • Loading branch information
redpony committed Dec 8, 2015
1 parent d2ef1d9 commit e1f5f3c
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions cnn/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Expression colwise_add(const Expression& x, const Expression& bias) { return Exp
Expression contract3d_1d(const Expression& x, const Expression& y) { return Expression(x.pg, x.pg->add_function<InnerProduct3D_1D>({x.i, y.i})); }
Expression contract3d_1d(const Expression& x, const Expression& y, const Expression& b) { return Expression(x.pg, x.pg->add_function<InnerProduct3D_1D>({x.i, y.i, b.i})); }

Expression erf(const Expression& x) { return Expression(x.pg, x.pg->add_function<Erf>({x.i})); }
Expression tanh(const Expression& x) { return Expression(x.pg, x.pg->add_function<Tanh>({x.i})); }
Expression log(const Expression& x) { return Expression(x.pg, x.pg->add_function<Log>({x.i})); }
Expression exp(const Expression& x) { return Expression(x.pg, x.pg->add_function<Exp>({x.i})); }
Expand Down
1 change: 1 addition & 0 deletions cnn/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Expression contract3d_1d(const Expression& x, const Expression& y);
// z_ij = x_ijk * y_k + b_ij
Expression contract3d_1d(const Expression& x, const Expression& y, const Expression& b);

Expression erf(const Expression& x);
Expression tanh(const Expression& x);
Expression exp(const Expression& x);
Expression square(const Expression& x);
Expand Down
12 changes: 12 additions & 0 deletions cnn/functors.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ struct FNegate {
}
};

struct FErf {
CNN_DEVICE_FUNC inline float operator()(float x) const {
return erff(x);
}
};

struct FTanh {
CNN_DEVICE_FUNC inline float operator()(float x) const {
#ifdef FAST_TANH
Expand All @@ -103,6 +109,12 @@ struct FMaxBackwardInv {
}
};

struct FErfBackward {
CNN_DEVICE_FUNC inline float operator()(float x, float d) const {
return 1.1283791670955125738961589f * expf(-x * x) * d;
}
};

struct FTanhBackward {
CNN_DEVICE_FUNC inline float operator()(float t, float d) const {
return (1.f - t * t) * d;
Expand Down
11 changes: 11 additions & 0 deletions cnn/nodes-common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,17 @@ Dim Average::dim_forward(const vector<Dim>& xs) const {
return d;
}

string Erf::as_string(const vector<string>& arg_names) const {
ostringstream s;
s << "erf(" << arg_names[0] << ')';
return s.str();
}

Dim Erf::dim_forward(const vector<Dim>& xs) const {
assert(xs.size() == 1);
return xs[0];
}

string Tanh::as_string(const vector<string>& arg_names) const {
ostringstream s;
s << "tanh(" << arg_names[0] << ')';
Expand Down
14 changes: 14 additions & 0 deletions cnn/nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,20 @@ void Average::backward_impl(const vector<const Tensor*>& xs,
*dEdxi += (*dEdf / xs.size());
}

void Erf::forward_impl(const vector<const Tensor*>& xs, Tensor& fx) const {
auto x = **xs[0];
(*fx) = x.unaryExpr(FErf());
}

void Erf::backward_impl(const vector<const Tensor*>& xs,
const Tensor& fx,
const Tensor& dEdf,
unsigned i,
Tensor& dEdxi) const {
auto x = **xs[0];
*dEdxi += x.binaryExpr(*dEdf, FErfBackward());
}

void Tanh::forward_impl(const vector<const Tensor*>& xs, Tensor& fx) const {
#if HAVE_CUDA
gpu::vtanh(fx.d.size(), xs[0]->v, fx.v);
Expand Down
13 changes: 13 additions & 0 deletions cnn/nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ struct ConstantMinusX : public Node {
real c;
};

// y = erf x_1
struct Erf : public Node {
explicit Erf(const std::initializer_list<VariableIndex>& a) : Node(a) {}
std::string as_string(const std::vector<std::string>& arg_names) const override;
Dim dim_forward(const std::vector<Dim>& xs) const override;
void forward_impl(const std::vector<const Tensor*>& xs, Tensor& fx) const override;
void backward_impl(const std::vector<const Tensor*>& xs,
const Tensor& fx,
const Tensor& dEdf,
unsigned i,
Tensor& dEdxi) const override;
};

// y = tanh x_1
struct Tanh : public Node {
explicit Tanh(const std::initializer_list<VariableIndex>& a) : Node(a) {}
Expand Down

0 comments on commit e1f5f3c

Please sign in to comment.