Basic DQN #1014
Changes from 8 commits
@@ -51,15 +51,18 @@ class FFN

   /**
    * Create the FFN object with the given predictors and responses set (this is
-   * the set that is used to train the network) and the given optimizer.
+   * the set that is used to train the network).
    * Optionally, specify which initialize rule and performance function should
    * be used.
    *
+   * If you want to pass in a parameter and discard the original parameter
+   * object, be sure to use std::move to avoid unnecessary copy.
+   *
    * @param outputLayer Output layer used to evaluate the network.
    * @param initializeRule Optional instantiated InitializationRule object
    *        for initializing the network parameter.
    */
-  FFN(OutputLayerType&& outputLayer = OutputLayerType(),
+  FFN(OutputLayerType outputLayer = OutputLayerType(),
      InitializationRuleType initializeRule = InitializationRuleType());

  //! Copy constructor.
@@ -73,19 +76,22 @@ class FFN

   /**
    * Create the FFN object with the given predictors and responses set (this is
-   * the set that is used to train the network) and the given optimizer.
+   * the set that is used to train the network).
    * Optionally, specify which initialize rule and performance function should
    * be used.
    *
+   * If you want to pass in a parameter and discard the original parameter
+   * object, be sure to use std::move to avoid unnecessary copy.
+   *
    * @param predictors Input training variables.
    * @param responses Outputs results from input training variables.
    * @param outputLayer Output layer used to evaluate the network.
    * @param initializeRule Optional instantiated InitializationRule object
    *        for initializing the network parameter.
    */
-  FFN(const arma::mat& predictors,
-      const arma::mat& responses,
-      OutputLayerType&& outputLayer = OutputLayerType(),
+  FFN(arma::mat predictors,
+      arma::mat responses,
+      OutputLayerType outputLayer = OutputLayerType(),
      InitializationRuleType initializeRule = InitializationRuleType());

  //! Destructor to release allocated memory.
@@ -99,6 +105,9 @@ class FFN
   * optimization. If this is not what you want, then you should access the
   * parameters vector directly with Parameters() and modify it as desired.
   *
+  * If you want to pass in a parameter and discard the original parameter
+  * object, be sure to use std::move to avoid unnecessary copy.
+  *
   * @tparam OptimizerType Type of optimizer to use to train the model.
   * @param predictors Input training variables.
   * @param responses Outputs results from input training variables.
@@ -109,8 +118,8 @@ class FFN
      mlpack::optimization::RMSProp,
      typename... OptimizerTypeArgs
  >
-  void Train(const arma::mat& predictors,
-             const arma::mat& responses,
+  void Train(arma::mat predictors,
+             arma::mat responses,
             OptimizerType<NetworkType, OptimizerTypeArgs...>& optimizer);

  /**
@@ -122,24 +131,30 @@ class FFN
   * optimization. If this is not what you want, then you should access the
   * parameters vector directly with Parameters() and modify it as desired.
   *
+  * If you want to pass in a parameter and discard the original parameter
+  * object, be sure to use std::move to avoid unnecessary copy.
+  *
   * @tparam OptimizerType Type of optimizer to use to train the model.
   * @param predictors Input training variables.
   * @param responses Outputs results from input training variables.
   */
  template<
      template<typename...> class OptimizerType = mlpack::optimization::RMSProp
  >
-  void Train(const arma::mat& predictors, const arma::mat& responses);
+  void Train(arma::mat predictors, arma::mat responses);
  /**
   * Predict the responses to a given set of predictors. The responses will
   * reflect the output of the given output layer as returned by the
   * output layer function.
   *
+  * If you want to pass in a parameter and discard the original parameter
+  * object, be sure to use std::move to avoid unnecessary copy.
+  *
   * @param predictors Input predictors.
   * @param results Matrix to put output predictions of responses into.
   */
-  void Predict(const arma::mat& predictors, arma::mat& results);
+  void Predict(arma::mat predictors, arma::mat& results);
Review discussion on this change:

- Sounds good to me. I think we should use a reference here; there might be compilers which are not going to optimize it this way.
- But returning a non-const reference to a local variable is undefined behavior. And I think …
- I was talking about …
- Oh. The reason why I didn't use …
- OK, that makes sense. Then I think what we need to do here is to implement …
- Sounds like a good plan to me.
- So will you do this, or should I make it in this PR?
- If you like you can do it here; don't feel obligated, just let me know.
- Yeah, I can do it here.
  /**
   * Evaluate the feedforward network with the given parameters. This function

@@ -226,7 +241,7 @@ class FFN
   * @param predictors Input data variables.
   * @param responses Outputs results from input data variables.
   */
-  void ResetData(const arma::mat& predictors, const arma::mat& responses);
+  void ResetData(arma::mat predictors, arma::mat responses);

  /**
   * The Backward algorithm (part of the Forward-Backward algorithm). Computes
Review discussion on the data-taking constructor:

- This may be a question more for @zoq: what is the intended use of this constructor and the version that takes `const arma::mat&`s? I don't see it used in this code, and I'm not sure why a user would want to initialize the network with some data but not train it. Maybe there is some reason I haven't thought of. :)
- Good point, I can't think of any situation where that might be necessary. Maybe I just looked at the LinearRegression/LogisticRegression class and thought it's a good idea.