
Generative Adversarial Network #1066

Status: Closed (wants to merge 57 commits).

Commits (57):
13b5200 Inital Commit (kris-singh, Jul 19, 2017)
ac7418e Fix Minor Error (kris-singh, Jul 19, 2017)
ffd1cf8 Add test (kris-singh, Jul 20, 2017)
c34883f Add crossEntropy Konstantin (kris-singh, Jul 20, 2017)
d5b23c2 FixSegfault and Training (kris-singh, Jul 21, 2017)
e121b81 Fix the Gradient function and other chnages (kris-singh, Jul 23, 2017)
2d85c49 Minor style fixes (kris-singh, Jul 23, 2017)
df48591 Fix reset function (kris-singh, Jul 23, 2017)
cfaea05 Refactor Train Method (kris-singh, Jul 26, 2017)
d714edb Refactor GenerateData (kris-singh, Jul 26, 2017)
e14e2fe Random (kris-singh, Jul 26, 2017)
9f66c33 Fails: Diffrent sized predictors for G & D (kris-singh, Jul 31, 2017)
cb2ec96 Single Optimizer (kris-singh, Aug 1, 2017)
f4752ca Merge branch 'master' into Gan (kris-singh, Aug 1, 2017)
b4b0239 Fix errors (kris-singh, Aug 1, 2017)
d14404c Minor Fixes (kris-singh, Aug 2, 2017)
a35a445 Merge branch 'openmp_sgd_fix' of https://github.com/zoq/mlpack into Gan (kris-singh, Aug 3, 2017)
8d872b9 Fix error with disIteration (kris-singh, Aug 3, 2017)
3a366a1 temp commit (kris-singh, Aug 3, 2017)
0528c2e Temp (kris-singh, Aug 4, 2017)
83df332 Change Intilisation of network (kris-singh, Aug 8, 2017)
1cd7938 Refactor Code & Delete GenerateNoise Function (kris-singh, Aug 9, 2017)
87f9e72 Minor Fix (kris-singh, Aug 11, 2017)
f754852 Refactor GAN. (lozhnikov, Aug 12, 2017)
6059798 Minor (kris-singh, Aug 14, 2017)
8c267b6 Add noise function and get rid of batch information (kris-singh, Aug 15, 2017)
9d32965 Fix commits (kris-singh, Aug 15, 2017)
a4e9f4b Merge branch 'master' into Gan (kris-singh, Aug 16, 2017)
e7ca020 Add preActivationStep & noiseFunction (kris-singh, Aug 18, 2017)
edcf4d2 Noise function as LambdaFunction or Class (kris-singh, Aug 18, 2017)
526050b Intial Commit (kris-singh, Aug 18, 2017)
df05c04 Minor Fix (kris-singh, Aug 20, 2017)
4bcf594 Make the layer working (kris-singh, Aug 20, 2017)
3227197 Style Fix (kris-singh, Aug 20, 2017)
c6c8da8 Correct Formula (kris-singh, Aug 20, 2017)
ddc1bd8 Fix indexing in image (kris-singh, Aug 20, 2017)
1c1ada0 Fix bilinear Function Now Working (kris-singh, Aug 20, 2017)
1a8b6db Fix bilinear Function Now Working (kris-singh, Aug 20, 2017)
8527d34 Generator Optimises -log(D(g(z)) and other minor changes (kris-singh, Aug 21, 2017)
93ab8b2 Add comments and tests(Temp Commit) (kris-singh, Aug 22, 2017)
04c51ca Merge branch 'ResizeLayer' of https://github.com/kris-singh/mlpack in… (kris-singh, Aug 22, 2017)
ea407ba Fix HEADER Gaurd (kris-singh, Aug 22, 2017)
1bae148 Add test (kris-singh, Aug 22, 2017)
e99b1cf Merge branch 'ResizeLayer' into Gan (kris-singh, Aug 23, 2017)
fe58004 Temp (kris-singh, Aug 23, 2017)
5b0fe4c Change Everything to 1d Matrices (kris-singh, Aug 24, 2017)
25cbbd2 Merge branch 'ResizeLayer' into Gan (kris-singh, Aug 24, 2017)
09d88f4 Fix Tests (kris-singh, Aug 24, 2017)
5d276ce Merge branch 'ResizeLayer' into Gan (kris-singh, Aug 24, 2017)
29d63f5 Add Resize Layer (kris-singh, Aug 24, 2017)
ea46274 Add test (kris-singh, Aug 25, 2017)
30b18fd Add Data (kris-singh, Aug 25, 2017)
fa619fb Style Fix (kris-singh, Aug 25, 2017)
07972dd Fix depth for bilinear function (kris-singh, Aug 25, 2017)
1c94976 Style Fixes (kris-singh, Aug 25, 2017)
32f741b Style Fix (kris-singh, Aug 26, 2017)
521bc31 Fix comments and spaces (kris-singh, Aug 27, 2017)
2 changes: 2 additions & 0 deletions src/mlpack/core/optimizers/minibatch_sgd/minibatch_sgd.hpp
@@ -15,6 +15,7 @@
#include <mlpack/prereqs.hpp>
#include <mlpack/core/optimizers/sgd/update_policies/vanilla_update.hpp>
#include <mlpack/core/optimizers/minibatch_sgd/decay_policies/no_decay.hpp>
#include <mlpack/core/optimizers/adam/adam_update.hpp>

namespace mlpack {
namespace optimization {
@@ -204,6 +205,7 @@ class MiniBatchSGDType
};

using MiniBatchSGD = MiniBatchSGDType<VanillaUpdate, NoDecay>;
using AdamBatchSGD = MiniBatchSGDType<AdamUpdate, NoDecay>;

} // namespace optimization
} // namespace mlpack
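
The new `AdamBatchSGD` alias above illustrates the policy-based design of `MiniBatchSGDType`: the update rule is a template parameter, so swapping `VanillaUpdate` for `AdamUpdate` changes the optimizer's behavior without touching the mini-batch loop. A minimal, self-contained sketch of this pattern (hypothetical names, not mlpack's actual implementation):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical update policy: plain gradient descent. An Adam-style policy
// would carry moment estimates as state but expose the same Update() call.
struct VanillaUpdateSketch
{
  void Update(std::vector<double>& iterate,
              const double stepSize,
              const std::vector<double>& gradient)
  {
    for (std::size_t i = 0; i < iterate.size(); ++i)
      iterate[i] -= stepSize * gradient[i];
  }
};

// The optimizer is parameterized on the update policy, mirroring how
// MiniBatchSGDType<AdamUpdate, NoDecay> composes behaviors at compile time.
template<typename UpdatePolicyType>
class MiniBatchSGDSketch
{
 public:
  explicit MiniBatchSGDSketch(const double stepSize) : stepSize(stepSize) { }

  // Apply one mini-batch update through the chosen policy.
  void Step(std::vector<double>& iterate, const std::vector<double>& gradient)
  {
    updatePolicy.Update(iterate, stepSize, gradient);
  }

 private:
  double stepSize;
  UpdatePolicyType updatePolicy;
};

// Analogous to the `using AdamBatchSGD = ...` alias in the diff above.
using VanillaBatchSGDSketch = MiniBatchSGDSketch<VanillaUpdateSketch>;
```

The payoff of this design is that the SGD loop is written once, and new optimizers are just new update-policy types plugged into the same template.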
@@ -92,6 +92,10 @@ double MiniBatchSGDType<
Log::Info << "Mini-batch SGD: iteration " << i << ", objective "
<< overallObjective << "." << std::endl;

Log::Info << "Last objective = " << lastObjective << std::endl;

Log::Info << "Gradient = " << arma::max(gradient) << std::endl;
Review comment (Contributor): I understand you need that in order to debug the implementation, but we should remove that in the future.


if (std::isnan(overallObjective) || std::isinf(overallObjective))
{
Log::Warn << "Mini-batch SGD: converged to " << overallObjective
@@ -127,7 +131,6 @@ double MiniBatchSGDType<
function.Gradient(iterate, offset + j, funcGradient);
gradient += funcGradient;
}

// Now update the iterate.
updatePolicy.Update(iterate, stepSize / batchSize, gradient);

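
The hunk above accumulates the per-point gradients of a batch and then scales the step by `stepSize / batchSize`, which is the same as stepping along the mean gradient of the batch. A standalone sketch of that arithmetic on a toy quadratic objective (hypothetical helper, not mlpack code):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One mini-batch SGD step on f_j(x) = 0.5 * (x - y_j)^2. Per-point gradients
// are summed and the step is scaled by stepSize / batchSize, exactly as in
// the diff above, so the update follows the mean gradient of the batch.
double MiniBatchStep(const double iterate,
                     const std::vector<double>& targets,
                     const double stepSize)
{
  double gradient = 0.0;
  for (std::size_t j = 0; j < targets.size(); ++j)
    gradient += (iterate - targets[j]); // d/dx of 0.5 * (x - y_j)^2.

  // Now update the iterate; dividing by the batch size averages the gradient.
  return iterate - (stepSize / targets.size()) * gradient;
}
```

With targets {1, 2, 3, 6} (mean 3.0) and step size 0.5, a step from 0.0 lands at 1.5, i.e. halfway toward the batch mean, which is what the averaged-gradient view predicts.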
3 changes: 3 additions & 0 deletions src/mlpack/methods/ann/CMakeLists.txt
@@ -3,6 +3,8 @@
set(SOURCES
ffn.hpp
ffn_impl.hpp
gan_impl.hpp
gan.hpp
rnn.hpp
rnn_impl.hpp
)
@@ -22,3 +24,4 @@ add_subdirectory(init_rules)
add_subdirectory(layer)
add_subdirectory(convolution_rules)
add_subdirectory(augmented)
add_subdirectory(image_functions)
4 changes: 3 additions & 1 deletion src/mlpack/methods/ann/ffn.hpp
@@ -360,11 +360,13 @@ class FFN

//! Locally-stored copy visitor
CopyVisitor copyVisitor;

template<typename Model, typename InitializerType, class NoiseType>
friend class GAN;
}; // class FFN

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "ffn_impl.hpp"

187 changes: 187 additions & 0 deletions src/mlpack/methods/ann/gan.hpp
@@ -0,0 +1,187 @@
/**
* @file gan.hpp
* @author Kris Singh
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_GAN_HPP
#define MLPACK_METHODS_ANN_GAN_HPP

#include <mlpack/core.hpp>

#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>
Review comment (Contributor): Correct me if I am mistaken, but it seems you don't need any of the other includes except this one.

#include <mlpack/methods/ann/visitor/output_parameter_visitor.hpp>
#include <mlpack/methods/ann/visitor/reset_visitor.hpp>
#include <mlpack/methods/ann/visitor/weight_size_visitor.hpp>
#include <mlpack/methods/ann/visitor/weight_set_visitor.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>

using namespace mlpack;
using namespace mlpack::ann;
using namespace mlpack::optimization;
using namespace mlpack::math;
using namespace mlpack::distribution;
Review comment (Contributor): Could you remove the using namespace directives? Usually we don't use them outside of tests.


namespace mlpack {
namespace ann /** artificial neural network **/ {
template<
typename Model,
Review comment (Contributor): Pedantic style issue: incorrect spacing.

typename InitializationRuleType,
class Noise>
class GAN
{
public:
/**
* Constructor for the GAN class.
*
* @tparam Model The class type of the generator and discriminator.
* @tparam InitializationRuleType Type of the initializer.
* @param trainData The real data.
* @param generator The generator network.
* @param discriminator The discriminator network.
* @param initializeRule Rule used to initialize the networks' weights.
* @param noiseFunction Function used to generate the noise samples.
* @param noiseDim Input (noise) dimension of the generator network.
* @param batchSize Batch size to be used for training.
* @param generatorUpdateStep Number of discriminator steps per generator update.
* @param preTrainSize Number of discriminator pre-training steps.
*/
Review comment (Contributor): The comment is outdated.

GAN(arma::mat& trainData,
Model& generator,
Model& discriminator,
InitializationRuleType initializeRule,
Noise noiseFunction,
size_t noiseDim,
size_t batchSize,
size_t generatorUpdateStep,
size_t preTrainSize);

// Reset function
void Reset();

// Train function
template<typename OptimizerType>
void Train(OptimizerType& Optimizer);

/**
* Evaluate function for the GAN.
* Gives the performance of the GAN
* on the current input.
*
* @param parameters The parameters of the network.
* @param i The index of the current input.
*/
double Evaluate(const arma::mat& parameters, const size_t i);
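For context on what a GAN evaluation computes: in the standard GAN objective, the discriminator's loss on a real sample x is -log D(x), and on a generated sample G(z) it is -log(1 - D(G(z))). A self-contained sketch with a toy 1-D sigmoid discriminator (hypothetical names for illustration; the real class evaluates the wrapped FFN objectives):

```cpp
#include <cassert>
#include <cmath>

// Toy discriminator: logistic score of a 1-D input (hypothetical).
double DiscriminatorScore(const double x)
{
  return 1.0 / (1.0 + std::exp(-x));
}

// Sketch of evaluating one sample: real samples are scored with -log D(x),
// generated samples with -log(1 - D(G(z))).
double EvaluateSketch(const double sample, const bool isReal)
{
  const double d = DiscriminatorScore(sample);
  return isReal ? -std::log(d) : -std::log(1.0 - d);
}
```

Note the asymmetry: a sample the discriminator confidently scores as real yields a small loss when it is real and a large loss when it was actually generated, which is exactly the signal the two networks compete over.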

/**
* Gradient function for the GAN.
* This function passes back the gradient based
* on which network is being trained, i.e. the generator or the discriminator.
*
* @param parameters Present parameters of the network.
* @param i Index of the predictors.
* @param gradient Variable to store the present gradient.
*/
void Gradient(const arma::mat& parameters, const size_t i,
arma::mat& gradient);

/**
* This function does a forward pass through the GAN
* network.
*
* @param input Sampled noise.
*/
void Forward(arma::mat&& input);
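The forward pass chains the two networks: the sampled noise goes through the generator, and the generator's output goes through the discriminator. A minimal sketch of that composition with toy 1-D networks (hypothetical names, not the class's actual implementation):

```cpp
#include <cassert>
#include <cmath>

// Toy 1-D generator and discriminator (hypothetical, for illustration).
double GeneratorSketch(const double noise) { return 2.0 * noise + 1.0; }

double DiscriminatorSketch(const double x)
{
  return 1.0 / (1.0 + std::exp(-x));
}

// Sketch of a GAN forward pass: ganOutput = generator(noise), then the
// discriminator scores the generated sample.
double ForwardSketch(const double noise)
{
  const double ganOutput = GeneratorSketch(noise);
  return DiscriminatorSketch(ganOutput);
}
```

The discriminator's output is a probability in (0, 1), so the composed pass maps a noise sample to the discriminator's belief that the generated sample is real.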

/**
* This function predicts the output of the network
* on the given input.
*
* @param input The input to the discriminator network.
* @param output Result of the discriminator network.
*/
void Predict(arma::mat&& input, arma::mat& output);

//! Return the parameters of the network.
const arma::mat& Parameters() const { return parameter; }
//! Modify the parameters of the network
arma::mat& Parameters() { return parameter; }

//! Return the number of separable functions (the number of predictor points).
size_t NumFunctions() const { return numFunctions; }

//! Serialize the model.
template<typename Archive>
void Serialize(Archive& ar, const unsigned int /* version */);

private:
//! Locally stored parameter for training data.
Review comment (Contributor): The comment is incorrect.

arma::mat predictors;
//! Locally stored parameters of the network.
arma::mat parameter;

//! Locally stored generator network.
Model& generator;
//! Locally stored discriminator network.
Model& discriminator;
//! Locally stored initializer.
InitializationRuleType initializeRule;
//! Locally stored noise function.
Noise noiseFunction;
//! Locally stored dimension of the noise input to the generator.
size_t noiseDim;
//! Locally stored number of data points.
size_t numFunctions;

//! Locally stored batch size parameter.
size_t batchSize;

//! Locally stored offset for predictors and noise data.
size_t offset;
Review comment (Contributor): I thought we don't use this variable anymore.

//! Locally stored number of iterations that have been completed.
size_t counter;
//! Locally stored batch number which is being processed.
size_t currentBatch;

//! Number of discriminator training steps before each generator update.
size_t generatorUpdateStep;

//! Number of discriminator pre-training steps.
size_t preTrainSize;

//! Locally stored reset parameter.
bool reset;
//! Locally stored delta visitor.
DeltaVisitor deltaVisitor;
//! Locally stored responses.
arma::mat responses;
//! Locally stored current input.
arma::mat currentInput;
//! Locally stored current target.
arma::mat currentTarget;
//! Locally-stored output parameter visitor.
OutputParameterVisitor outputParameterVisitor;
//! Locally-stored weight size visitor.
WeightSizeVisitor weightSizeVisitor;
//! Locally-stored reset visitor.
ResetVisitor resetVisitor;
//! Locally stored gradient parameters.
arma::mat gradient;
//! Locally stored gradient for discriminator.
arma::mat gradientDiscriminator;

//! Locally stored gradient of the discriminator w.r.t. the noise samples.
arma::mat noiseGradientDiscriminator;
//! Locally stored noise samples.
arma::mat noise;
Review comment (Contributor): It seems the variable is not used.
Reply (Contributor, author): It is still used. Look at the Evaluate function and the Gradient function.

//! Locally stored gradient for generator.
arma::mat gradientGenerator;
//! Locally stored output of the generator network.
arma::mat ganOutput;
};
} // namespace ann
} // namespace mlpack

// Include implementation.
#include "gan_impl.hpp"

#endif
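
The `generatorUpdateStep` and `preTrainSize` members suggest the usual alternating GAN schedule: the discriminator is trained alone for some pre-training steps, and afterwards the generator is updated once per `generatorUpdateStep` discriminator steps. A sketch of that scheduling logic under those assumptions (hypothetical helper, not the class's actual Train loop):

```cpp
#include <cassert>
#include <cstddef>

// Sketch of an alternating GAN training schedule (an assumption based on the
// generatorUpdateStep and preTrainSize members, not mlpack's actual loop):
// returns true when iteration i should update the generator rather than the
// discriminator.
bool UpdateGeneratorAt(const std::size_t i,
                       const std::size_t generatorUpdateStep,
                       const std::size_t preTrainSize)
{
  // During pre-training only the discriminator is updated.
  if (i < preTrainSize)
    return false;

  // Afterwards, the generator takes every generatorUpdateStep-th step.
  return ((i - preTrainSize) % generatorUpdateStep) == generatorUpdateStep - 1;
}
```

Giving the discriminator more steps than the generator is a common stabilization trick; the "disIteration" ratio mentioned in the commit history points at the same idea.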