New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Addition of q_networks #2317
Addition of q_networks #2317
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @nishantkr18, just some minor changes. Nice work 👍
Hey @nishantkr18, The macOS failure seems unrelated. I think if you rebased it should be fixed. Thanks a lot. 👍 |
This reverts commit 5786015.
I have created a new folder for q_networks which currently contains the simple DQN by the name Please have a look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks nice!
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
src/mlpack/methods/reinforcement_learning/q_networks/vanilla_dqn.hpp
Outdated
Show resolved
Hide resolved
model.Add<Linear<>>(64, 32); | ||
model.Add<ReLULayer<>>(); | ||
model.Add<Linear<>>(32, 3); | ||
VanillaDQN<> model(4, 64, 32, 3); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we should show in a single test, it's also possible to manually specifiy the network, completely without using VanillaDQN
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm. I think u mean this:
FFN<MeanSquaredError<>, GaussianInitialization>
network(MeanSquaredError<>(),
GaussianInitialization(0, 0.001));
network.Add<Linear<>>(6, 256);
network.Add<ReLULayer<>>();
network.Add<Linear<>>(256, 3);
// Create custom network type
VanillaDQN<decltype(network)> model(std::move(network));
This is taken from the test for DoublePoleCartWithDQN
, showing how we can manually specify the network. But I could add a separate test for CartPole with DQN as well, showing how to manually specify.?
I'm afraid we wont be able to directly pass the network to QLearning
(ie without using VanillaDQN
), as I've used the method ResetParametersIfEmpty()
in QLearning which is not present in FNN. Would that be fine?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could rename ResetParametersIfEmpty
to ResetParameters
and revert the change, that way we can still pass a vanilla network, maybe I missed something? There is already a test that uses the copy constructor, but would be great if we could provide backward compatibility.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, you are right.. Actually I didn't notice that Parameters()
had already been added to SimpleDQN
..
Anyways, now I've done the necessary changes.. Kindly have a look..
…qn.hpp Co-Authored-By: Marcus Edel <marcus.edel@fu-berlin.de>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work incorporating the changes!
Hi @nishantkr18 , I was having a look at the code and I had a few questions. Wouldn't structuring the network this way be too restrictive? As in a What would the process be if I wanted to use a DQN of a different architecture, say 3-layer |
Hi @sriramsk1999 ! One can easily pass custom network architectures directly into |
Ah okay, I think I understand now. I thought that If I understood you correctly, it is supplementary to the existing method and acts as a shortcut when using the commonly used architectures. Thanks for the clarification and nice work. :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for putting this together, no more comments from my side.
Hi! I would like to suggest some changes in the structure of DQN structure. Instead of passing FFN model directly into agent, we can pass it to a QNetwork, and that in turn will be passed to the agent. This would have two advantages:
The tests would then change from
to
I have added the Qnetwork file without proper documentation and completion as of now, just to get some reviews and suggestions..