New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generative Adversarial Network #1066
Closed
Closed
Changes from 4 commits
Commits
Show all changes
57 commits
Select commit
Hold shift + click to select a range
13b5200
Inital Commit
kris-singh ac7418e
Fix Minor Error
kris-singh ffd1cf8
Add test
kris-singh c34883f
Add crossEntropy Konstantin
kris-singh d5b23c2
FixSegfault and Training
kris-singh e121b81
Fix the Gradient function and other chnages
kris-singh 2d85c49
Minor style fixes
kris-singh df48591
Fix reset function
kris-singh cfaea05
Refactor Train Method
kris-singh d714edb
Refactor GenerateData
kris-singh e14e2fe
Random
kris-singh 9f66c33
Fails: Diffrent sized predictors for G & D
kris-singh cb2ec96
Single Optimizer
kris-singh f4752ca
Merge branch 'master' into Gan
kris-singh b4b0239
Fix errors
kris-singh d14404c
Minor Fixes
kris-singh a35a445
Merge branch 'openmp_sgd_fix' of https://github.com/zoq/mlpack into Gan
kris-singh 8d872b9
Fix error with disIteration
kris-singh 3a366a1
temp commit
kris-singh 0528c2e
Temp
kris-singh 83df332
Change Intilisation of network
kris-singh 1cd7938
Refactor Code & Delete GenerateNoise Function
kris-singh 87f9e72
Minor Fix
kris-singh f754852
Refactor GAN.
lozhnikov 6059798
Minor
kris-singh 8c267b6
Add noise function and get rid of batch information
kris-singh 9d32965
Fix commits
kris-singh a4e9f4b
Merge branch 'master' into Gan
kris-singh e7ca020
Add preActivationStep & noiseFunction
kris-singh edcf4d2
Noise function as LambdaFunction or Class
kris-singh 526050b
Intial Commit
kris-singh df05c04
Minor Fix
kris-singh 4bcf594
Make the layer working
kris-singh 3227197
Style Fix
kris-singh c6c8da8
Correct Formula
kris-singh ddc1bd8
Fix indexing in image
kris-singh 1c1ada0
Fix bilinear Function Now Working
kris-singh 1a8b6db
Fix bilinear Function Now Working
kris-singh 8527d34
Generator Optimises -log(D(g(z)) and other minor changes
kris-singh 93ab8b2
Add comments and tests(Temp Commit)
kris-singh 04c51ca
Merge branch 'ResizeLayer' of https://github.com/kris-singh/mlpack in…
kris-singh ea407ba
Fix HEADER Gaurd
kris-singh 1bae148
Add test
kris-singh e99b1cf
Merge branch 'ResizeLayer' into Gan
kris-singh fe58004
Temp
kris-singh 5b0fe4c
Change Everything to 1d Matrices
kris-singh 25cbbd2
Merge branch 'ResizeLayer' into Gan
kris-singh 09d88f4
Fix Tests
kris-singh 5d276ce
Merge branch 'ResizeLayer' into Gan
kris-singh 29d63f5
Add Resize Layer
kris-singh ea46274
Add test
kris-singh 30b18fd
Add Data
kris-singh fa619fb
Style Fix
kris-singh 07972dd
Fix depth for bilinear function
kris-singh 1c94976
Style Fixes
kris-singh 32f741b
Style Fix
kris-singh 521bc31
Fix comments and spaces
kris-singh File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
set(SOURCES | ||
ffn.hpp | ||
ffn_impl.hpp | ||
gan_impl.hpp | ||
gan.hpp | ||
rnn.hpp | ||
rnn_impl.hpp | ||
) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/** | ||
* @file gan.hpp | ||
* @author Kris Singh | ||
* | ||
* mlpack is free software; you may redistribute it and/or modify it under the | ||
* terms of the 3-clause BSD license. You should have received a copy of the | ||
* 3-clause BSD license along with mlpack. If not, see | ||
* http://www.opensource.org/licenses/BSD-3-Clause for more information. | ||
*/ | ||
#ifndef MLPACK_METHODS_ANN_GAN_HPP | ||
#define MLPACK_METHODS_ANN_GAN_HPP | ||
|
||
#include <mlpack/core.hpp> | ||
#include <mlpack/prereqs.hpp> | ||
#include <mlpack/core/math/random.hpp> | ||
|
||
#include <mlpack/methods/ann/layer/layer.hpp> | ||
#include <mlpack/methods/ann/layer/base_layer.hpp> | ||
#include <mlpack/methods/ann/ffn.hpp> | ||
#include <mlpack/methods/ann/visitor/output_parameter_visitor.hpp> | ||
|
||
|
||
#include <mlpack/methods/ann/activation_functions/softplus_function.hpp> | ||
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp> | ||
#include <mlpack/core/dists/gaussian_distribution.hpp> | ||
#include <mlpack/methods/ann/init_rules/random_init.hpp> | ||
|
||
using namespace mlpack; | ||
using namespace mlpack::ann; | ||
using namespace mlpack::optimization; | ||
using namespace mlpack::math; | ||
using namespace mlpack::distribution; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you remove |
||
|
||
namespace mlpack { | ||
namespace ann /** Restricted Boltzmann Machine. */ { | ||
template< | ||
typename Generator = FFN<>, | ||
typename Discriminator = FFN<>, | ||
typename IntializerType = RandomInitialization> | ||
class GenerativeAdversarialNetwork | ||
{ | ||
public: | ||
GenerativeAdversarialNetwork(arma::mat trainData, arma::mat trainLables, | ||
IntializerType initializeRule, | ||
Generator& generator, | ||
Discriminator& discriminator, | ||
size_t batchSize, | ||
size_t generatorInSize); | ||
|
||
void Reset(); | ||
|
||
template<typename OptimizerType> | ||
void Train(OptimizerType& Optimizer); | ||
|
||
double Evaluate(const arma::mat& parameters, | ||
const size_t i, | ||
const bool deterministic = true); | ||
/** | ||
* Gradient function | ||
*/ | ||
void Gradient(const arma::mat& parameters, const size_t i, | ||
arma::mat& gradient); | ||
|
||
/** | ||
* This function does forward pass through the GAN | ||
* network. | ||
* | ||
* @param input the noise input | ||
*/ | ||
void Forward(arma::mat&& input); | ||
|
||
/** | ||
* This function predicts the output of the network | ||
* on the given input | ||
* | ||
* @param input the input the discriminator network | ||
* @param output result of the discriminator network | ||
*/ | ||
void Predict(arma::mat&& input, arma::mat& output); | ||
|
||
/** | ||
* Generate function generates random noise | ||
* samples from a given distribution with | ||
* given args. Samples are stored in a local variable. | ||
* | ||
* @tparam NoiseFunction the distribution to sample from | ||
* @tparam Args the arguments types for args of the distribution | ||
* @param numSamples number of samples to be generated from the distribution | ||
* @param args the aruments of the distribution to samples from | ||
*/ | ||
void Generate(size_t numSamples, arma::mat&& fakeData); | ||
//! Return the number of separable functions (the number of predictor points). | ||
size_t NumFunctions() const { return numFunctions; } | ||
|
||
private: | ||
//! Locally stored Intialiser | ||
IntializerType initializeRule; | ||
//! Locally stored parameters of the network | ||
arma::mat parameter; | ||
//! Locally stored generator | ||
Generator& generator; | ||
//! Locally stored discriminator | ||
Discriminator& discriminator; | ||
//! Locally stored number of data points | ||
size_t numFunctions; | ||
//! Locally stored trainGenerator parmaeter | ||
bool trainGenerator; | ||
//! Locally stored batch size parameter | ||
size_t batchSize; | ||
//! Locally stored input size for generator | ||
size_t generatorInSize; | ||
//! Locally stored reset parmaeter | ||
bool reset; | ||
//! Locally stored parameter for training data | ||
arma::mat predictors; | ||
//! Locally stored responses | ||
arma::mat responses; | ||
//! Locally stored fake data used for training | ||
arma::mat fakeData; | ||
//! Locally stored fake Labels used for training | ||
arma::mat fakeLables; | ||
//! Locally stored train data comprising of real and fake data | ||
arma::mat tempTrainData; | ||
//! Locally stored temp variable comprisiong of read and fake labels | ||
arma::mat tempLabels; | ||
//! Locally-stored output parameter visitor. | ||
OutputParameterVisitor outputParameterVisitor; | ||
//! Locally stored gradient parameters | ||
arma::mat gradient; | ||
//! Locally stored gradient for discriminator | ||
arma::mat gradientDisriminator; | ||
//! Locally stored gradient for generator | ||
arma::mat gradientGenerator; | ||
//! Locally stored output of the generator network | ||
arma::mat ganOutput; | ||
}; | ||
} // namespace ann | ||
} // namespace mlpack | ||
|
||
// Include implementation. | ||
#include "gan_impl.hpp" | ||
|
||
#endif |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct if I am mistaken, but it seems you needn't other includes except this one.