New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generative Adversarial Network #1066
Changes from 56 commits
13b5200
ac7418e
ffd1cf8
c34883f
d5b23c2
e121b81
2d85c49
df48591
cfaea05
d714edb
e14e2fe
9f66c33
cb2ec96
f4752ca
b4b0239
d14404c
a35a445
8d872b9
3a366a1
0528c2e
83df332
1cd7938
87f9e72
f754852
6059798
8c267b6
9d32965
a4e9f4b
e7ca020
edcf4d2
526050b
df05c04
4bcf594
3227197
c6c8da8
ddc1bd8
1c1ada0
1a8b6db
8527d34
93ab8b2
04c51ca
ea407ba
1bae148
e99b1cf
fe58004
5b0fe4c
25cbbd2
09d88f4
5d276ce
29d63f5
ea46274
30b18fd
fa619fb
07972dd
1c94976
32f741b
521bc31
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
/** | ||
* @file gan.hpp | ||
* @author Kris Singh | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_GAN_HPP | ||
#define MLPACK_METHODS_ANN_GAN_HPP | ||
|
||
#include <mlpack/core.hpp> | ||
|
||
#include <mlpack/methods/ann/layer/layer.hpp> | ||
#include <mlpack/methods/ann/layer/base_layer.hpp> | ||
#include <mlpack/methods/ann/ffn.hpp> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct if I am mistaken, but it seems you needn't other includes except this one. |
||
#include <mlpack/methods/ann/visitor/output_parameter_visitor.hpp> | ||
#include <mlpack/methods/ann/visitor/reset_visitor.hpp> | ||
#include <mlpack/methods/ann/visitor/weight_size_visitor.hpp> | ||
#include <mlpack/methods/ann/visitor/weight_set_visitor.hpp> | ||
#include <mlpack/methods/ann/init_rules/random_init.hpp> | ||
|
||
using namespace mlpack; | ||
using namespace mlpack::ann; | ||
using namespace mlpack::optimization; | ||
using namespace mlpack::math; | ||
using namespace mlpack::distribution; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove |
||
|
||
namespace mlpack { | ||
namespace ann /** artifical neural network **/ { | ||
template< | ||
typename Model, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pedantic style issue: incorrect spacing. |
||
typename InitializationRuleType, | ||
class Noise> | ||
class GAN | ||
{ | ||
public: | ||
/** | ||
* Constructor for GAN class | ||
* | ||
* @tparam Model The class type of generator and discriminator. | ||
* @tparam InitializationRuleType Type of Intializer. | ||
* @param generator Generator network. | ||
* @param trainData The real data. | ||
* @param noiseData The data generated from randomly. | ||
* @param discriminator Discriminator network. | ||
* @param batchSize BatchSize to be used for training. | ||
* @param noiseInSize Input size of the generator network. | ||
* @param disIteration Ratio of number of training step for Disc to Gen | ||
*/ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment is outdated. |
||
GAN(arma::mat& trainData, | ||
Model& generator, | ||
Model& discriminator, | ||
InitializationRuleType initializeRule, | ||
Noise noiseFunction, | ||
size_t noiseDim, | ||
size_t batchSize, | ||
size_t generatorUpdateStep, | ||
size_t preTrainSize); | ||
|
||
// Reset function | ||
void Reset(); | ||
|
||
// Train function | ||
template<typename OptimizerType> | ||
void Train(OptimizerType& Optimizer); | ||
|
||
/** | ||
* Evaluate function for the GAN | ||
* gives the perfomance of the gan | ||
* on the current input. | ||
* | ||
* @param parameters The parameters of the network | ||
* @param i The idx of the current input | ||
*/ | ||
double Evaluate(const arma::mat& parameters, const size_t i); | ||
|
||
/** | ||
* Gradient function for gan. | ||
* This function is passes the gradient based | ||
* on which network is being trained ie generator or Discriminator. | ||
* | ||
* @param parameters present parameters of the network | ||
* @param i index of the predictors | ||
* @param gradient variable to store the present gradient | ||
*/ | ||
void Gradient(const arma::mat& parameters, const size_t i, | ||
arma::mat& gradient); | ||
|
||
/** | ||
* This function does forward pass through the GAN | ||
* network. | ||
* | ||
* @param input Sampled noise | ||
*/ | ||
void Forward(arma::mat&& input); | ||
|
||
/** | ||
* This function predicts the output of the network | ||
* on the given input. | ||
* | ||
* @param input the input the discriminator network | ||
* @param output result of the discriminator network | ||
*/ | ||
void Predict(arma::mat&& input, arma::mat& output); | ||
|
||
//! Return the parameters of the network. | ||
const arma::mat& Parameters() const { return parameter; } | ||
//! Modify the parameters of the network | ||
arma::mat& Parameters() { return parameter; } | ||
|
||
//! Return the number of separable functions (the number of predictor points). | ||
size_t NumFunctions() const { return numFunctions; } | ||
|
||
//! Serialize the model. | ||
template<typename Archive> | ||
void Serialize(Archive& ar, const unsigned int /* version */); | ||
|
||
private: | ||
//! Locally stored parameter for training data. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment is incorrect. |
||
arma::mat predictors; | ||
//! Locally stored parameters of the network. | ||
arma::mat parameter; | ||
|
||
//! Locally stored generator network. | ||
Model& generator; | ||
//! Locally stored discriminator network. | ||
Model& discriminator; | ||
//! Locally stored Intialiser. | ||
InitializationRuleType initializeRule; | ||
//! Locally stored Noise function | ||
Noise noiseFunction; | ||
size_t noiseDim; | ||
//! Locally stored number of data points. | ||
size_t numFunctions; | ||
|
||
//! Locally stored batch size parameter. | ||
size_t batchSize; | ||
|
||
//! Locally stored offset for predictors and noise data. | ||
size_t offset; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought we don't use this variable anymore. |
||
//! Locally stored number of iterations that have been completed. | ||
size_t counter; | ||
//! Locally stored batch number which is being processed. | ||
size_t currentBatch; | ||
|
||
size_t generatorUpdateStep; | ||
|
||
size_t preTrainSize; | ||
|
||
//! Locally stored reset parmaeter. | ||
bool reset; | ||
//! Locally stored delta visitor. | ||
DeltaVisitor deltaVisitor; | ||
//! Locally stored responses. | ||
arma::mat responses; | ||
//! Locally stored current input. | ||
arma::mat currentInput; | ||
//! Locally stored current target. | ||
arma::mat currentTarget; | ||
//! Locally-stored output parameter visitor. | ||
OutputParameterVisitor outputParameterVisitor; | ||
//! Locally-stored weight size visitor. | ||
WeightSizeVisitor weightSizeVisitor; | ||
//! Locally-stored reset visitor. | ||
ResetVisitor resetVisitor; | ||
//! Locally stored gradient parameters. | ||
arma::mat gradient; | ||
//! Locally stored gradient for discriminator. | ||
arma::mat gradientDiscriminator; | ||
|
||
arma::mat noiseGradientDiscriminator; | ||
|
||
arma::mat noise; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems the variable is not used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is still used. Look at the evaluate function and the gradient function. |
||
//! Locally stored gradient for generator. | ||
arma::mat gradientGenerator; | ||
//! Locally stored output of the generator network. | ||
arma::mat ganOutput; | ||
}; | ||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// Include implementation. | ||
#include "gan_impl.hpp" | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand you need that in order to debug the implementation, but we should remove that in the future.