Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generative Adversarial Network #1066

Closed
wants to merge 57 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
13b5200
Inital Commit
kris-singh Jul 19, 2017
ac7418e
Fix Minor Error
kris-singh Jul 19, 2017
ffd1cf8
Add test
kris-singh Jul 20, 2017
c34883f
Add crossEntropy Konstantin
kris-singh Jul 20, 2017
d5b23c2
FixSegfault and Training
kris-singh Jul 21, 2017
e121b81
Fix the Gradient function and other chnages
kris-singh Jul 23, 2017
2d85c49
Minor style fixes
kris-singh Jul 23, 2017
df48591
Fix reset function
kris-singh Jul 23, 2017
cfaea05
Refactor Train Method
kris-singh Jul 26, 2017
d714edb
Refactor GenerateData
kris-singh Jul 26, 2017
e14e2fe
Random
kris-singh Jul 26, 2017
9f66c33
Fails: Diffrent sized predictors for G & D
kris-singh Jul 31, 2017
cb2ec96
Single Optimizer
kris-singh Aug 1, 2017
f4752ca
Merge branch 'master' into Gan
kris-singh Aug 1, 2017
b4b0239
Fix errors
kris-singh Aug 1, 2017
d14404c
Minor Fixes
kris-singh Aug 2, 2017
a35a445
Merge branch 'openmp_sgd_fix' of https://github.com/zoq/mlpack into Gan
kris-singh Aug 3, 2017
8d872b9
Fix error with disIteration
kris-singh Aug 3, 2017
3a366a1
temp commit
kris-singh Aug 3, 2017
0528c2e
Temp
kris-singh Aug 4, 2017
83df332
Change Intilisation of network
kris-singh Aug 8, 2017
1cd7938
Refactor Code & Delete GenerateNoise Function
kris-singh Aug 9, 2017
87f9e72
Minor Fix
kris-singh Aug 11, 2017
f754852
Refactor GAN.
lozhnikov Aug 12, 2017
6059798
Minor
kris-singh Aug 14, 2017
8c267b6
Add noise function and get rid of batch information
kris-singh Aug 15, 2017
9d32965
Fix commits
kris-singh Aug 15, 2017
a4e9f4b
Merge branch 'master' into Gan
kris-singh Aug 16, 2017
e7ca020
Add preActivationStep & noiseFunction
kris-singh Aug 18, 2017
edcf4d2
Noise function as LambdaFunction or Class
kris-singh Aug 18, 2017
526050b
Intial Commit
kris-singh Aug 18, 2017
df05c04
Minor Fix
kris-singh Aug 20, 2017
4bcf594
Make the layer working
kris-singh Aug 20, 2017
3227197
Style Fix
kris-singh Aug 20, 2017
c6c8da8
Correct Formula
kris-singh Aug 20, 2017
ddc1bd8
Fix indexing in image
kris-singh Aug 20, 2017
1c1ada0
Fix bilinear Function Now Working
kris-singh Aug 20, 2017
1a8b6db
Fix bilinear Function Now Working
kris-singh Aug 20, 2017
8527d34
Generator Optimises -log(D(g(z)) and other minor changes
kris-singh Aug 21, 2017
93ab8b2
Add comments and tests(Temp Commit)
kris-singh Aug 22, 2017
04c51ca
Merge branch 'ResizeLayer' of https://github.com/kris-singh/mlpack in…
kris-singh Aug 22, 2017
ea407ba
Fix HEADER Gaurd
kris-singh Aug 22, 2017
1bae148
Add test
kris-singh Aug 22, 2017
e99b1cf
Merge branch 'ResizeLayer' into Gan
kris-singh Aug 23, 2017
fe58004
Temp
kris-singh Aug 23, 2017
5b0fe4c
Change Everything to 1d Matrices
kris-singh Aug 24, 2017
25cbbd2
Merge branch 'ResizeLayer' into Gan
kris-singh Aug 24, 2017
09d88f4
Fix Tests
kris-singh Aug 24, 2017
5d276ce
Merge branch 'ResizeLayer' into Gan
kris-singh Aug 24, 2017
29d63f5
Add Resize Layer
kris-singh Aug 24, 2017
ea46274
Add test
kris-singh Aug 25, 2017
30b18fd
Add Data
kris-singh Aug 25, 2017
fa619fb
Style Fix
kris-singh Aug 25, 2017
07972dd
Fix depth for bilinear function
kris-singh Aug 25, 2017
1c94976
Style Fixes
kris-singh Aug 25, 2017
32f741b
Style Fix
kris-singh Aug 26, 2017
521bc31
Fix comments and spaces
kris-singh Aug 27, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mlpack/methods/ann/CMakeLists.txt
Expand Up @@ -3,6 +3,8 @@
set(SOURCES
ffn.hpp
ffn_impl.hpp
gan_impl.hpp
gan.hpp
rnn.hpp
rnn_impl.hpp
)
Expand Down
4 changes: 3 additions & 1 deletion src/mlpack/methods/ann/ffn.hpp
Expand Up @@ -345,11 +345,13 @@ class FFN

//! Locally-stored copy visitor
CopyVisitor copyVisitor;

template<typename Generator, typename Discriminator, typename IntializerType>
friend class GenerativeAdversarialNetwork;
}; // class FFN

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "ffn_impl.hpp"

Expand Down
143 changes: 143 additions & 0 deletions src/mlpack/methods/ann/gan.hpp
@@ -0,0 +1,143 @@
/**
* @file gan.hpp
* @author Kris Singh
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_ANN_GAN_HPP
#define MLPACK_METHODS_ANN_GAN_HPP

#include <mlpack/core.hpp>
#include <mlpack/prereqs.hpp>
#include <mlpack/core/math/random.hpp>

#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/layer/base_layer.hpp>
#include <mlpack/methods/ann/ffn.hpp>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct if I am mistaken, but it seems you needn't other includes except this one.

#include <mlpack/methods/ann/visitor/output_parameter_visitor.hpp>


#include <mlpack/methods/ann/activation_functions/softplus_function.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/core/dists/gaussian_distribution.hpp>
#include <mlpack/methods/ann/init_rules/random_init.hpp>

using namespace mlpack;
using namespace mlpack::ann;
using namespace mlpack::optimization;
using namespace mlpack::math;
using namespace mlpack::distribution;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you remove using namespaces? Usually we don't use them outside of tests.


namespace mlpack {
namespace ann /** Restricted Boltzmann Machine. */ {
template<
typename Generator = FFN<>,
typename Discriminator = FFN<>,
typename IntializerType = RandomInitialization>
class GenerativeAdversarialNetwork
{
public:
GenerativeAdversarialNetwork(arma::mat trainData, arma::mat trainLables,
IntializerType initializeRule,
Generator& generator,
Discriminator& discriminator,
size_t batchSize,
size_t generatorInSize);

void Reset();

template<typename OptimizerType>
void Train(OptimizerType& Optimizer);

double Evaluate(const arma::mat& parameters,
const size_t i,
const bool deterministic = true);
/**
* Gradient function
*/
void Gradient(const arma::mat& parameters, const size_t i,
arma::mat& gradient);

/**
* This function does forward pass through the GAN
* network.
*
* @param input the noise input
*/
void Forward(arma::mat&& input);

/**
* This function predicts the output of the network
* on the given input
*
* @param input the input the discriminator network
* @param output result of the discriminator network
*/
void Predict(arma::mat&& input, arma::mat& output);

/**
* Generate function generates random noise
* samples from a given distribution with
* given args. Samples are stored in a local variable.
*
* @tparam NoiseFunction the distribution to sample from
* @tparam Args the arguments types for args of the distribution
* @param numSamples number of samples to be generated from the distribution
* @param args the aruments of the distribution to samples from
*/
void Generate(size_t numSamples, arma::mat&& fakeData);
//! Return the number of separable functions (the number of predictor points).
size_t NumFunctions() const { return numFunctions; }

private:
//! Locally stored Intialiser
IntializerType initializeRule;
//! Locally stored parameters of the network
arma::mat parameter;
//! Locally stored generator
Generator& generator;
//! Locally stored discriminator
Discriminator& discriminator;
//! Locally stored number of data points
size_t numFunctions;
//! Locally stored trainGenerator parmaeter
bool trainGenerator;
//! Locally stored batch size parameter
size_t batchSize;
//! Locally stored input size for generator
size_t generatorInSize;
//! Locally stored reset parmaeter
bool reset;
//! Locally stored parameter for training data
arma::mat predictors;
//! Locally stored responses
arma::mat responses;
//! Locally stored fake data used for training
arma::mat fakeData;
//! Locally stored fake Labels used for training
arma::mat fakeLables;
//! Locally stored train data comprising of real and fake data
arma::mat tempTrainData;
//! Locally stored temp variable comprisiong of read and fake labels
arma::mat tempLabels;
//! Locally-stored output parameter visitor.
OutputParameterVisitor outputParameterVisitor;
//! Locally stored gradient parameters
arma::mat gradient;
//! Locally stored gradient for discriminator
arma::mat gradientDisriminator;
//! Locally stored gradient for generator
arma::mat gradientGenerator;
//! Locally stored output of the generator network
arma::mat ganOutput;
};
} // namespace ann
} // namespace mlpack

// Include implementation.
#include "gan_impl.hpp"

#endif