SpikeSlabRBM #1046

Closed
wants to merge 62 commits into from
Changes from 54 commits

Commits (62)
5a51869
Binary Rbm
kris-singh Jun 14, 2017
4836743
Make minor fixes with changed version of sharing parameters
kris-singh Jun 16, 2017
529a299
Add test for training the model
kris-singh Jun 17, 2017
8c8a236
Add weight tranpose in hidden layer
kris-singh Jun 18, 2017
0ea3dfd
Add Batch Training& Gibbs
kris-singh Jun 19, 2017
6bf07ef
Fix Style
kris-singh Jun 19, 2017
eca7e32
Minor Style Fix
kris-singh Jun 19, 2017
acabb9f
Merge Master
kris-singh Jun 19, 2017
44b0866
Minor Style Fix and Merge with Master
kris-singh Jun 19, 2017
d223942
Change the evaluation function
kris-singh Jun 21, 2017
debb798
Fix training
kris-singh Jun 21, 2017
8a40736
Add Monitor Cost & Change Sampling Procedure
kris-singh Jun 23, 2017
5dbc0eb
Change the Gradient Function, Remove Copy Refrences
kris-singh Jun 24, 2017
5fcd80c
Add Tests
kris-singh Jun 27, 2017
f78c6bc
Merge branch 'master' into Binary_RBM
kris-singh Jun 28, 2017
7393841
Use minibatch sgd
kris-singh Jun 28, 2017
6d517a0
Merge master
kris-singh Jun 29, 2017
c9cb31b
Style Fix
kris-singh Jun 29, 2017
d43efa9
Delete cdk_test
kris-singh Jun 29, 2017
030903a
Fix intializer
kris-singh Jun 29, 2017
0a962fd
Fix serialisation test
kris-singh Jun 30, 2017
65a5cde
Minor Style Fix
kris-singh Jun 30, 2017
88355b9
Add ssRbm layer
kris-singh Jul 3, 2017
2d7a1b8
Fix the classification test
kris-singh Jul 4, 2017
a3eda16
Rename dataset, Remove old dataset
kris-singh Jul 7, 2017
8b575c3
Add Binary Data
kris-singh Jul 8, 2017
8724892
Remove std::cout form network test
kris-singh Jul 8, 2017
9c69a8f
Merge branch 'ssRBM' of https://github.com/kris-singh/mlpack into ssRBM
kris-singh Jul 8, 2017
e23b556
Refatctor BinaryRbm & Add ssRBM
kris-singh Jul 12, 2017
bc82890
Style Fix
kris-singh Jul 12, 2017
7058511
Fix ClassifierTest
kris-singh Jul 12, 2017
2b72481
Fix serialisation test
kris-singh Jul 16, 2017
45c661f
Add ssRBM train Test
kris-singh Jul 16, 2017
256b982
Fix RBM Test
kris-singh Jul 16, 2017
55ab816
Change Reset function
kris-singh Jul 19, 2017
6b4b75e
Fix Lozhnikov Comments
kris-singh Jul 21, 2017
dee6784
Merge branch 'master' into ssRBM
kris-singh Jul 21, 2017
e714528
Fix Reset Function
kris-singh Jul 21, 2017
0a942a8
Add Classification Test
kris-singh Jul 21, 2017
05c04e5
Merge remote-tracking branch 'upstream/master' into ssRBM
lozhnikov Jul 24, 2017
b4882b0
Merged RBM layers and RBM policies. Fixed the spike slab RBM classifi…
lozhnikov Jul 24, 2017
9c1c463
Move the RBM implementation back to mlpack/methods/ann
lozhnikov Jul 24, 2017
26f980b
Fixed an error in SpikeSlabRBMPolicy::SampleVisible().
lozhnikov Jul 25, 2017
f995c35
Fixed an error in SpikeSlabRBMPolicy::VisibleMean().
lozhnikov Jul 25, 2017
f06bc3b
Add Mikhail's Fix
kris-singh Jul 26, 2017
c8e12d5
Fix comments from mikhail
kris-singh Jul 28, 2017
09d1e1c
Add template types to policy
kris-singh Aug 2, 2017
5669fe4
Fix slabMean
kris-singh Aug 2, 2017
9f24c35
Fix Template Types and Remove ScalarVisiblePenalty
kris-singh Aug 4, 2017
5616e4f
Some Stype changes and Comments fix
kris-singh Aug 7, 2017
e92c53a
Add Float Test
kris-singh Aug 7, 2017
7425f1f
Fix Float test
kris-singh Aug 7, 2017
52aaaec
Fix style errors
kris-singh Aug 7, 2017
a5f4278
Fix comments
kris-singh Aug 7, 2017
bbe6433
Style Fixes
kris-singh Aug 7, 2017
e0e85ac
Change parameters to ssRBM Test add std::fabs to freeEnergy
kris-singh Aug 7, 2017
d2979bd
Fix test and add trun_exp
kris-singh Aug 10, 2017
59c13d7
Merge master
kris-singh Aug 10, 2017
06a9756
Minor Fix
kris-singh Aug 11, 2017
67a4002
Minor Fix of comments and FreeEnergy function loop removal
kris-singh Aug 18, 2017
bfc2adf
Reduce Test time
kris-singh Aug 25, 2017
57199b9
Fix test
kris-singh Aug 27, 2017
11 changes: 11 additions & 0 deletions src/mlpack/core/math/random.hpp
@@ -60,6 +60,17 @@ inline double Random(const double lo, const double hi)
return lo + (hi - lo) * randUniformDist(randGen);
}

/**
* Generates a Bernoulli sample: returns 1 with probability equal to the
* input, and 0 otherwise.
*/
inline double RandBernoulli(const double input)
{
if (Random() < input)
return 1;
else
return 0;
}

/**
* Generates a uniform random integer.
*/
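A minimal usage sketch of the new RandBernoulli() (the surrounding setup is illustrative, not part of the patch): each call is an independent draw that returns 1 with the given probability.

#include <mlpack/core.hpp>

int main()
{
  const double p = 0.3; // probability of returning 1
  arma::vec samples(10);
  // Fill the vector with independent Bernoulli(p) draws.
  samples.imbue([&]() { return mlpack::math::RandBernoulli(p); });
  samples.print("Bernoulli samples:");
  return 0;
}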
3 changes: 3 additions & 0 deletions src/mlpack/methods/ann/CMakeLists.txt
@@ -3,6 +3,8 @@
set(SOURCES
ffn.hpp
ffn_impl.hpp
rbm.hpp
rbm_impl.hpp
rnn.hpp
rnn_impl.hpp
)
@@ -21,3 +23,4 @@ add_subdirectory(activation_functions)
add_subdirectory(init_rules)
add_subdirectory(layer)
add_subdirectory(convolution_rules)
add_subdirectory(rbm)
10 changes: 6 additions & 4 deletions src/mlpack/methods/ann/init_rules/gaussian_init.hpp
@@ -47,13 +47,14 @@ class GaussianInitialization
* @param rows Number of rows.
* @param cols Number of columns.
*/
void Initialize(arma::mat& W,
template<typename eT>
void Initialize(arma::Mat<eT>& W,
const size_t rows,
const size_t cols)
{
if (W.is_empty())
{
W = arma::mat(rows, cols);
W = arma::Mat<eT>(rows, cols);
}
W.imbue( [&]() { return arma::as_scalar(RandNormal(mean, variance)); } );
}
@@ -66,12 +67,13 @@
* @param cols Number of columns.
* @param slice Numbers of slices.
*/
void Initialize(arma::cube & W,
template<typename eT>
void Initialize(arma::Cube<eT> & W,
const size_t rows,
const size_t cols,
const size_t slices)
{
W = arma::cube(rows, cols, slices);
W = arma::Cube<eT>(rows, cols, slices);

for (size_t i = 0; i < slices; i++)
Initialize(W.slice(i), rows, cols);
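A brief sketch of what the templating above enables (assuming GaussianInitialization's constructor takes mean and variance, as its member names suggest): the same rule can now fill single-precision matrices, which the float test added later in this PR relies on.

#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>

int main()
{
  mlpack::ann::GaussianInitialization init(0.0, 1.0); // mean, variance

  // Before this change, Initialize() only accepted arma::mat (double).
  arma::Mat<float> weights;
  init.Initialize(weights, 4, 3); // filled with N(0, 1) draws

  arma::Cube<float> tensor;
  init.Initialize(tensor, 4, 3, 2); // initialized one slice at a time
  return 0;
}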
1 change: 1 addition & 0 deletions src/mlpack/methods/ann/layer/layer.hpp
@@ -27,4 +27,5 @@
#include "concat.hpp"
#include "vr_class_reward.hpp"


Member: Not sure the extra line here is necessary.

#endif
2 changes: 1 addition & 1 deletion src/mlpack/methods/ann/layer/linear.hpp
@@ -45,7 +45,7 @@ class Linear
* @param inSize The number of input units.
* @param outSize The number of output units.
*/
Linear(const size_t inSize, const size_t outSize);;
Linear(const size_t inSize, const size_t outSize);

/*
* Reset the layer parameter.
183 changes: 183 additions & 0 deletions src/mlpack/methods/ann/rbm.hpp
@@ -0,0 +1,183 @@
/**
* @file rbm.hpp
* @author Kris Singh
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_RBM_HPP
#define MLPACK_METHODS_ANN_RBM_HPP

#include <mlpack/core.hpp>
#include <mlpack/prereqs.hpp>
Member: core.hpp already includes prereqs.hpp, so we can remove the header here.

#include <mlpack/core/math/random.hpp>
Member: Same as for prereqs.hpp, core.hpp already includes random.hpp.


#include <mlpack/methods/ann/activation_functions/softplus_function.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>
Member: Looks like this can be removed, since it's unused, maybe I missed something?


namespace mlpack {
namespace ann /** Artificial neural networks. */ {

template<typename InitializationRuleType, typename RBMPolicy>
Member: Can you comment on both template parameters, using @tparam ...?

class RBM
{
public:
using NetworkType = RBM<InitializationRuleType, RBMPolicy>;
typedef typename RBMPolicy::ElemType eT;
Member: I would go with typedef typename VecType::elem_type ElemType; to be consistent with the rest of the codebase. For more information, take a look at: https://github.com/mlpack/mlpack/blob/master/doc/policies/elemtype.hpp
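Applied to this class, the suggested convention would read roughly like this (a sketch; RBMPolicy is the policy template parameter declared above):

typedef typename RBMPolicy::ElemType ElemType; // instead of the terse eT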


/*
* Initialize all the parameters of the network
* using the initialization rule.
*
* @tparam InitializationRuleType Rule to initialize the parameters of the network.
* @param predictors Training data.
* @param numSteps Number of Gibbs sampling steps.
Contributor: It seems the comment is out of date since the number of arguments has been changed.

Contributor: You didn't describe mSteps.

* @param useMonitoringCost Whether to use the monitoring cost as the evaluation function.
* @param persistence Indicates whether to use persistent CD.
*/
RBM(arma::Mat<eT> predictors, InitializationRuleType initializeRule,
Member: Can you put the InitializationRuleType on a new line?

RBMPolicy rbmPolicy,
const size_t numSteps = 1,
const size_t mSteps = 1,
const bool useMonitoringCost = true,
const bool persistence = false);

// Reset the network
void Reset();

/*
* Train the network using the Optimizer with the given set of args.
* the optimiser sets the parameters of the network for providing
* most likely parameters given the inputs
* @param: predictors data points
Contributor: Maybe that's not an issue but I am not sure that doxygen processes the colon in @param:.

Member: Also, can you use proper grammar and punctuation? For more information take a look at https://github.com/mlpack/mlpack/wiki/DesignGuidelines#comments; in most cases this means adding . at the end of the line and starting with an upper-case letter.

* @param: optimizer Optimizer type
*/
template<typename OptimizerType>
void Train(const arma::Mat<eT>& predictors, OptimizerType& optimizer);

/**
* Evaluate the RBM network with the given parameters.
* The function is needed for monitoring the progress of the network.
*
* @param parameters Matrix model parameters.
* @param i Index of point to use for objective function evaluation.
*/
double Evaluate(const arma::Mat<eT>& parameters, const size_t i);

/**
* This function calculates the free energy of the model.
* @param: input data point
Contributor: The colon is not needed.

*/
double FreeEnergy(arma::Mat<eT>&& input);

/*
* This function samples the hidden
* layer given the visible layer.
*
* @param input visible layer input
* @param output the sampled hidden layer
*/
void SampleHidden(arma::Mat<eT>&& input, arma::Mat<eT>&& output);

/*
* This function samples the visible
* layer given the hidden layer.
*
* @param input hidden layer
* @param output the sampled visible layer
*/
void SampleVisible(arma::Mat<eT>&& input, arma::Mat<eT>&& output);

/*
* This function does the k-step
* Gibbs sampling.
*
* @param input: input to the gibbs function
* @param output: stores the negative sample
* @param steps: number of gibbs sampling steps
Contributor: Again, the colons are not needed. Maybe doxygen processes that correctly but I am not sure. Next time, I'll not comment on that. So, make sure that doxygen works correctly or remove all unnecessary colons.

*/
void Gibbs(arma::Mat<eT>&& input, arma::Mat<eT>&& output,
Member: Can you align the parameter with the rest, and use a new line for the output parameter if the line ends up being longer than 80 characters?

size_t steps = SIZE_MAX);

/*
* Calculates the gradients for the RBM network.
*
* @param parameters The current parameters of the network.
* @param input Index of the visible layer/data point.
* @param output Stores the gradients.
*/
void Gradient(arma::Mat<eT>& parameters, const size_t input,
arma::Mat<eT>& output);

//! Return the number of separable functions (the number of predictor points).
size_t NumFunctions() const { return numFunctions; }

//! Return the number of steps of Gibbs sampling.
size_t NumSteps() const { return numSteps; }

//! Return the parameters of the network.
const arma::Mat<eT>& Parameters() const { return parameter; }
//! Modify the parameters of the network.
arma::Mat<eT>& Parameters() { return parameter; }

//! Return the RBM policy for the network.
const RBMPolicy& Policy() const { return rbmPolicy; }
//! Modify the RBM policy for the network.
RBMPolicy& Policy() { return rbmPolicy; }

//! Serialize the model.
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);

private:
//! Locally-stored parameters of the network.
arma::Mat<eT> parameter;
//! Policy type of RBM
RBMPolicy rbmPolicy;
//! The matrix of data points (predictors).
arma::Mat<eT> predictors;
// Initializer
Contributor: Could you replace // by //! in order to allow doxygen to process the comments above?

InitializationRuleType initializeRule;
//! Locally-stored state of the persistent CD-k.
arma::Mat<eT> state;
//! Locally-stored number of data points.
size_t numFunctions;
//! Locally-stored number of steps in Gibbs sampling.
size_t numSteps;
//! Locally-stored number of negative samples.
size_t mSteps;
//! Locally-stored flag for using the monitoring cost.
bool useMonitoringCost;
//! Locally-stored flag for persistent CD-k.
bool persistence;
//! Locally-stored reset variable.
bool reset;

//! Locally-stored reconstructed output from the hidden layer.
arma::Mat<eT> hiddenReconstruction;
//! Locally-stored reconstructed output from the visible layer.
arma::Mat<eT> visibleReconstruction;

//! Locally-stored negative samples from the Gibbs distribution.
arma::Mat<eT> negativeSamples;
//! Locally-stored gradients from the negative phase.
arma::Mat<eT> negativeGradient;
//! Locally-stored temporary negative gradient used for the negative phase.
arma::Mat<eT> tempNegativeGradient;
//! Locally-stored gradient for the positive phase.
arma::Mat<eT> positiveGradient;
//! Locally-stored temporary output of the Gibbs chain.
arma::Mat<eT> gibbsTemporary;
//! Locally-stored output of the preActivation function used in FreeEnergy.
arma::Mat<eT> preActivation;
};
} // namespace ann
} // namespace mlpack

#include "rbm_impl.hpp"

#endif // MLPACK_METHODS_ANN_RBM_HPP
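To make the interface above concrete, here is a hedged usage sketch. BinaryRBMPolicy comes from this PR's rbm/ directory, but its template and constructor signature, the layer sizes, and the commented-out optimizer are assumptions for illustration, not a confirmed API.

#include <mlpack/core.hpp>
#include <mlpack/methods/ann/rbm.hpp>
#include <mlpack/methods/ann/rbm/binary_rbm_policy.hpp>

#include <utility>

using namespace mlpack::ann;

void TrainSketch(arma::mat& data) // one training point per column
{
  GaussianInitialization init(0, 0.1);

  // Hypothetical policy construction; see binary_rbm_policy.hpp for the
  // real signature.
  BinaryRBMPolicy<arma::mat> policy(data.n_rows, 128);

  RBM<GaussianInitialization, BinaryRBMPolicy<arma::mat>> model(
      data, init, policy,
      1,      // numSteps: Gibbs steps per gradient update (CD-1).
      1,      // mSteps: number of negative samples.
      true,   // useMonitoringCost.
      false); // persistence: plain CD-k rather than persistent CD.

  // Train() accepts any mlpack optimizer that drives the decomposable
  // interface declared above (NumFunctions()/Evaluate()/Gradient()):
  //   OptimizerType optimizer(/* ... */);
  //   model.Train(data, optimizer);

  // FreeEnergy() returns the unnormalized negative log-probability of a
  // point; for a binary RBM this is
  //   F(v) = -b^T v - sum_j softplus(c_j + w_j^T v),
  // which is why softplus_function.hpp is included by rbm.hpp.
  arma::mat point = data.col(0);
  const double f = model.FreeEnergy(std::move(point));
  (void) f; // silence unused-variable warning in this sketch
}

The RBM class is the decomposable-function side of that contract: Evaluate() is what lets the optimizer monitor progress, which is the role the useMonitoringCost flag toggles.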
17 changes: 17 additions & 0 deletions src/mlpack/methods/ann/rbm/CMakeLists.txt
@@ -0,0 +1,17 @@
# Define the files we need to compile
# Anything not in this list will not be compiled into mlpack.
set(SOURCES
binary_rbm_policy.hpp
binary_rbm_policy_impl.hpp
spike_slab_rbm_policy.hpp
spike_slab_rbm_policy_impl.hpp
)

# Add directory name to sources.
set(DIR_SRCS)
foreach(file ${SOURCES})
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
endforeach()
# Append sources (with directory name) to list of all mlpack sources (used at
# the parent scope).
set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)