New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SpikeSlabRBM #1046
SpikeSlabRBM #1046
Changes from 54 commits
5a51869
4836743
529a299
8c8a236
0ea3dfd
6bf07ef
eca7e32
acabb9f
44b0866
d223942
debb798
8a40736
5dbc0eb
5fcd80c
f78c6bc
7393841
6d517a0
c9cb31b
d43efa9
030903a
0a962fd
65a5cde
88355b9
2d7a1b8
a3eda16
8b575c3
8724892
9c69a8f
e23b556
bc82890
7058511
2b72481
45c661f
256b982
55ab816
6b4b75e
dee6784
e714528
0a942a8
05c04e5
b4882b0
9c1c463
26f980b
f995c35
f06bc3b
c8e12d5
09d1e1c
5669fe4
9f24c35
5616e4f
e92c53a
7425f1f
52aaaec
a5f4278
bbe6433
e0e85ac
d2979bd
59c13d7
06a9756
67a4002
bfc2adf
57199b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,4 +27,5 @@ | |
#include "concat.hpp" | ||
#include "vr_class_reward.hpp" | ||
|
||
|
||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
/** | ||
* @file rbm.hpp | ||
* @author Kris Singh | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_RBM_HPP | ||
#define MLPACK_METHODS_ANN_RBM_HPP | ||
|
||
#include <mlpack/core.hpp> | ||
#include <mlpack/prereqs.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
#include <mlpack/core/math/random.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as for |
||
|
||
#include <mlpack/methods/ann/activation_functions/softplus_function.hpp> | ||
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp> | ||
#include <mlpack/methods/ann/init_rules/random_init.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like this can be removed, since it's unused, maybe I missed something? |
||
|
||
namespace mlpack { | ||
namespace ann /** Artificial neural networks. */ { | ||
|
||
template<typename InitializationRuleType, typename RBMPolicy> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you comment on both template parameter, using |
||
class RBM | ||
{ | ||
public: | ||
using NetworkType = RBM<InitializationRuleType, RBMPolicy>; | ||
typedef typename RBMPolicy::ElemType eT; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would go with |
||
|
||
/* | ||
* Intalise all the parameters of the network | ||
* using the intialise rule. | ||
* | ||
* @tparam IntialiserType rule to intialise the parameters of the network | ||
* @param predictors training data | ||
* @param numSteps Number of gibbs steps sampling | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems the comment is out of date since the number of arguments has been changed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You didn't describe |
||
* @param useMonitoringCost evaluation function to use | ||
* @param persistence indicates to use persistent CD | ||
*/ | ||
RBM(arma::Mat<eT> predictors, InitializationRuleType initializeRule, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you put the |
||
RBMPolicy rbmPolicy, | ||
const size_t numSteps = 1, | ||
const size_t mSteps = 1, | ||
const bool useMonitoringCost = true, | ||
const bool persistence = false); | ||
|
||
// Reset the network | ||
void Reset(); | ||
|
||
/* | ||
* Train the network using the Opitimzer with given set of args. | ||
* the optimiser sets the parameters of the network for providing | ||
* most likely parameters given the inputs | ||
* @param: predictors data points | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe that's not an issue but I am not sure that doxygen processes the colon. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also can you use proper grammar and punctuation, for more information take a look at: https://github.com/mlpack/mlpack/wiki/DesignGuidelines#comments in most cases this means adding |
||
* @param: optimizer Optimizer type | ||
*/ | ||
template<typename OptimizerType> | ||
void Train(const arma::Mat<eT>& predictors, OptimizerType& optimizer); | ||
|
||
/** | ||
* Evaluate the rbm network with the given parameters. | ||
* The function is needed for monitoring the progress of the network. | ||
* | ||
* @param parameters Matrix model parameters. | ||
* @param i Index of point to use for objective function evaluation. | ||
*/ | ||
double Evaluate(const arma::Mat<eT>& parameters, const size_t i); | ||
|
||
/** | ||
* This function calculates | ||
* the free energy of the model | ||
* @param: input data point | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The colon is not needed. |
||
*/ | ||
double FreeEnergy(arma::Mat<eT>&& input); | ||
|
||
/* | ||
* This functions samples the hidden | ||
* layer given the visible layer | ||
* | ||
* @param input visible layer input | ||
* @param output the sampled hidden layer | ||
*/ | ||
void SampleHidden(arma::Mat<eT>&& input, arma::Mat<eT>&& output); | ||
|
||
/* | ||
* This functions samples the visible | ||
* layer given the hidden layer | ||
* | ||
* @param input hidden layer | ||
* @param output the sampled visible layer | ||
*/ | ||
void SampleVisible(arma::Mat<eT>&& input, arma::Mat<eT>&& output); | ||
|
||
/* | ||
* This function does the k-step | ||
* gibbs sampling. | ||
* | ||
* @param input: input to the gibbs function | ||
* @param output: stores the negative sample | ||
* @param steps: number of gibbs sampling steps | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, the colons are not needed. Maybe doxygen processes that correctly but I am not sure. Next time, I'll not comment that. So, make sure that doxygen works correctly or remove all unnecessary colons. |
||
*/ | ||
void Gibbs(arma::Mat<eT>&& input, arma::Mat<eT>&& output, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you align the parameter with the rest and use a new line for the output parameter, if the line ends up being longer than 80 characters. |
||
size_t steps = SIZE_MAX); | ||
|
||
/* | ||
* Calculates the gradients for the rbm network | ||
* | ||
* @param parameters the current parmaeters of the network | ||
* @param input index the visible layer/data point | ||
* @param output store the gradients | ||
*/ | ||
void Gradient(arma::Mat<eT>& parameters, const size_t input, | ||
arma::Mat<eT>& output); | ||
|
||
//! Return the number of separable functions (the number of predictor points). | ||
size_t NumFunctions() const { return numFunctions; } | ||
|
||
//! Return the number of stes of gibbs sampling. | ||
size_t NumSteps() const { return numSteps; } | ||
|
||
//! Return the parameters of the network | ||
const arma::Mat<eT>& Parameters() const { return parameter; } | ||
//! Modify the parameters of the network | ||
arma::Mat<eT>& Parameters() { return parameter; } | ||
|
||
//! Retutrn the rbm policy for the network | ||
const RBMPolicy& Policy() const { return rbmPolicy; } | ||
//! Modify the rbm policy for the network | ||
RBMPolicy& Policy() { return rbmPolicy; } | ||
|
||
//! Serialize the model. | ||
template<typename Archive> | ||
void Serialize(Archive& ar, const unsigned int /* version */); | ||
|
||
private: | ||
//! Locally stored parameters of the network | ||
arma::Mat<eT> parameter; | ||
//! Policy type of RBM | ||
RBMPolicy rbmPolicy; | ||
//! The matrix of data points (predictors). | ||
arma::Mat<eT> predictors; | ||
// Intialiser | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you replace |
||
InitializationRuleType initializeRule; | ||
//! Locally-stored state of the persistent cdk. | ||
arma::Mat<eT> state; | ||
//! Locally-stored number of data points | ||
size_t numFunctions; | ||
//! Locally-stored number of steps in gibbs sampling | ||
size_t numSteps; | ||
//! Locally-stored number of negative samples | ||
size_t mSteps; | ||
//! Locally-stored monitoring cost | ||
bool useMonitoringCost; | ||
//! Locally-stored persistent cd-k or not | ||
bool persistence; | ||
//! Locally-stored reset variable | ||
bool reset; | ||
|
||
//! Locally-stored reconstructed output from hidden layer | ||
arma::Mat<eT> hiddenReconstruction; | ||
//! Locally-stored reconstructed output from visible layer | ||
arma::Mat<eT> visibleReconstruction; | ||
|
||
//! Locally-stored negative samples from gibbs Distribution | ||
arma::Mat<eT> negativeSamples; | ||
//! Locally-stored gradients from the negative phase | ||
arma::Mat<eT> negativeGradient; | ||
//! Locally-stored temproray negative gradient used for negative phase | ||
arma::Mat<eT> tempNegativeGradient; | ||
//! Locally-stored gradient for positive phase | ||
arma::Mat<eT> positiveGradient; | ||
//! Locally-stored temporary output of gibbs chain | ||
arma::Mat<eT> gibbsTemporary; | ||
//! Locally-stored output of the preActivation function used in FreeEnergy | ||
arma::Mat<eT> preActivation; | ||
}; | ||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
#include "rbm_impl.hpp" | ||
|
||
#endif // MLPACK_METHODS_ANN_RBM_HPP |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Define the files we need to compile | ||
# Anything not in this list will not be compiled into mlpack. | ||
set(SOURCES | ||
binary_rbm_policy.hpp | ||
binary_rbm_policy_impl.hpp | ||
spike_slab_rbm_policy.hpp | ||
spike_slab_rbm_policy_impl.hpp | ||
) | ||
|
||
# Add directory name to sources. | ||
set(DIR_SRCS) | ||
foreach(file ${SOURCES}) | ||
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file}) | ||
endforeach() | ||
# Append sources (with directory name) to list of all mlpack sources (used at | ||
# the parent scope). | ||
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure the extra line here is necessary.