Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
saksham189 committed Feb 12, 2019
1 parent b14e35a commit ff70d98
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 10 deletions.
10 changes: 8 additions & 2 deletions src/mlpack/methods/ann/visitor/backward_visitor.hpp
Expand Up @@ -30,8 +30,11 @@ class BackwardVisitor : public boost::static_visitor<void>
public:
//! Execute the Backward() function given the input, error and delta
//! parameter.
BackwardVisitor(arma::mat&& input, arma::mat&& error, arma::mat&& delta)

//! Execute the Backward() function for the layer with the specified index.
BackwardVisitor(arma::mat&& input, arma::mat&& error, arma::mat&& delta,
int layer = -1);
const size_t index);

//! Execute the Backward() function.
template<typename LayerType>
Expand All @@ -48,7 +51,10 @@ class BackwardVisitor : public boost::static_visitor<void>
arma::mat&& delta;

//! The index of the layer to run.
int index;
size_t index;

//! Indicates whether to use index or not
bool hasIndex;

//! Execute the Backward() function if the module does not have Run()
//! check.
Expand Down
19 changes: 16 additions & 3 deletions src/mlpack/methods/ann/visitor/backward_visitor_impl.hpp
Expand Up @@ -19,14 +19,27 @@ namespace mlpack {
namespace ann {

//! BackwardVisitor visitor class.
inline BackwardVisitor::BackwardVisitor(arma::mat&& input,
arma::mat&& error,
arma::mat&& delta) :
input(std::move(input)),
error(std::move(error)),
delta(std::move(delta)),
index(0),
hasIndex(false)
{
/* Nothing to do here. */
}

inline BackwardVisitor::BackwardVisitor(arma::mat&& input,
arma::mat&& error,
arma::mat&& delta,
int layer) :
const size_t index) :
input(std::move(input)),
error(std::move(error)),
delta(std::move(delta)),
index(layer)
index(index),
hasIndex(true)
{
/* Nothing to do here. */
}
Expand All @@ -50,7 +63,7 @@ inline typename std::enable_if<
HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
BackwardVisitor::LayerBackward(T* layer, arma::mat& /* input */) const
{
if (index == -1)
if (hasIndex)
{
layer->Backward(std::move(input), std::move(error),
std::move(delta));
Expand Down
10 changes: 8 additions & 2 deletions src/mlpack/methods/ann/visitor/gradient_visitor.hpp
Expand Up @@ -30,7 +30,10 @@ class GradientVisitor : public boost::static_visitor<void>
public:
//! Executes the Gradient() method of the given module using the input and
//! delta parameter.
GradientVisitor(arma::mat&& input, arma::mat&& delta, int layer = -1);
GradientVisitor(arma::mat&& input, arma::mat&& delta);

//! Executes the Gradient() method for the layer with the specified index.
GradientVisitor(arma::mat&& input, arma::mat&& delta, const size_t index);

//! Executes the Gradient() method.
template<typename LayerType>
Expand All @@ -44,7 +47,10 @@ class GradientVisitor : public boost::static_visitor<void>
arma::mat&& delta;

//! Index of the layer to run.
int index;
size_t index;

//! Indicates whether to use index or not
bool hasIndex;

//! Execute the Gradient() function if the module implements the Gradient()
//! function.
Expand Down
17 changes: 14 additions & 3 deletions src/mlpack/methods/ann/visitor/gradient_visitor_impl.hpp
Expand Up @@ -19,11 +19,22 @@ namespace mlpack {
namespace ann {

//! GradientVisitor visitor class.
inline GradientVisitor::GradientVisitor(arma::mat&& input, arma::mat&& delta) :
input(std::move(input)),
delta(std::move(delta)),
index(0),
hasIndex(false)
{
/* Nothing to do here. */
}


inline GradientVisitor::GradientVisitor(arma::mat&& input, arma::mat&& delta,
int layer) :
const size_t index) :
input(std::move(input)),
delta(std::move(delta)),
index(layer)
index(index),
hasIndex(true)
{
/* Nothing to do here. */
}
Expand All @@ -50,7 +61,7 @@ inline typename std::enable_if<
HasRunCheck<T, bool&(T::*)(void)>::value, void>::type
GradientVisitor::LayerGradients(T* layer, arma::mat& /* input */) const
{
if (index == -1)
if (hasIndex)
{
layer->Gradient(std::move(input), std::move(delta),
std::move(layer->Gradient()));
Expand Down

0 comments on commit ff70d98

Please sign in to comment.