
Addition of q_networks #2317

Merged: 12 commits merged into mlpack:master on Apr 7, 2020

Conversation

@nishantkr18 (Member) commented on Mar 19, 2020:

Hi! I would like to suggest a change to the structure of the DQN code. Instead of passing the FFN model directly into the agent, we can pass it to a QNetwork, which in turn is passed to the agent. This would have two advantages:

  1. It will be easier to add dueling DQN: we would only need to add two extra streams for the Value and Advantage functions and implement the Predict, Forward, Backward and Update functions (a rough sketch of the aggregation step appears at the end of this comment). We could then reuse the same q_learning implementation for vanilla DQN, double DQN and dueling DQN.
  2. We may pass a default FFN as the model in the agent for DQN.
    The tests would then change from
  FFN<MeanSquaredError<>, GaussianInitialization> model(MeanSquaredError<>(),
      GaussianInitialization(0, 0.001));
  model.Add<Linear<>>(4, 128);
  model.Add<ReLULayer<>>();
  model.Add<Linear<>>(128, 128);
  model.Add<ReLULayer<>>();
  model.Add<Linear<>>(128, 2);

  // Set up the policy and replay method.
  GreedyPolicy<CartPole> policy(1.0, 1000, 0.1, 0.99);
  RandomReplay<CartPole> replayMethod(10, 10000);

  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.9;
  config.TargetNetworkSyncInterval() = 100;
  config.ExplorationSteps() = 100;
  config.DoubleQLearning() = false;
  config.StepLimit() = 200;

  // Set up DQN agent.
  QLearning<CartPole, decltype(model), AdamUpdate, decltype(policy)>
      agent(std::move(config), std::move(model), std::move(policy),
      std::move(replayMethod));

to

    // Set up the policy and replay method.
    GreedyPolicy<CartPole> policy(1.0, 1000, 0.1, 0.99);
    RandomReplay<CartPole> replayMethod(10, 10000);
    QNetwork<> network;

    TrainingConfig config;
    config.StepSize() = 0.01;
    config.Discount() = 0.9;
    config.TargetNetworkSyncInterval() = 100;
    config.ExplorationSteps() = 100;
    config.DoubleQLearning() = false;
    config.StepLimit() = 200;

    // Set up DQN agent.
    QLearning<CartPole, decltype(network), AdamUpdate, decltype(policy)>
        agent(std::move(config), std::move(network), std::move(policy),
              std::move(replayMethod));

I have added the QNetwork file without full documentation or completion as of now, just to get some reviews and suggestions.
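
As referenced in point 1 above, here is a rough sketch of the aggregation step a dueling network would need. This is a hypothetical helper, not code from this PR; it assumes the Armadillo headers used by mlpack, a Value stream producing a 1 x batchSize matrix and an Advantage stream producing a numActions x batchSize matrix.

    // Hypothetical sketch (not part of this PR): combine the Value and
    // Advantage streams into Q-values, as a dueling network would.
    // value:     1 x batchSize matrix, V(s) for each state in the batch.
    // advantage: numActions x batchSize matrix, A(s, a) for each action.
    // Returns Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
    arma::mat CombineStreams(const arma::mat& value, const arma::mat& advantage)
    {
      return arma::repmat(value, advantage.n_rows, 1) + advantage -
          arma::repmat(arma::mean(advantage, 0), advantage.n_rows, 1);
    }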

@bisakhmondal (Contributor) left a comment:

Hi @nishantkr18, just some minor changes. Nice work 👍

Review comment on src/mlpack/methods/reinforcement_learning/q_network.hpp (outdated, resolved).
@zoq removed the "s: unanswered" label on Mar 20, 2020
@nishantkr18 changed the title from "Addition of q_network" to "[WIP] Addition of q_network" on Mar 21, 2020
@kartikdutt18 (Member) commented:

Hey @nishantkr18, the macOS failure seems unrelated. I think if you rebase, it should be fixed. Thanks a lot. 👍

@nishantkr18 requested review from zoq and birm on March 27, 2020
@nishantkr18 changed the title from "[WIP] Addition of q_network" to "Addition of q_networks" on Mar 29, 2020
@nishantkr18 (Member, Author) commented on Mar 29, 2020:

I have created a new folder, q_networks, which currently contains the simple DQN under the name VanillaDQN. I'll add the other types of Q-networks in a separate PR. Would that be fine?

Please have a look.

@zoq (Member) left a comment:

This looks nice!

Diff context under review:

    model.Add<Linear<>>(64, 32);
    model.Add<ReLULayer<>>();
    model.Add<Linear<>>(32, 3);
    VanillaDQN<> model(4, 64, 32, 3);
@zoq (Member):
Do you think we should show in a single test that it's also possible to manually specify the network, completely without using VanillaDQN?

@nishantkr18 (Member, Author):

Hmm. I think you mean this:

    FFN<MeanSquaredError<>, GaussianInitialization>
        network(MeanSquaredError<>(),
                GaussianInitialization(0, 0.001));
    network.Add<Linear<>>(6, 256);
    network.Add<ReLULayer<>>();
    network.Add<Linear<>>(256, 3);

    // Create custom network type
    VanillaDQN<decltype(network)> model(std::move(network));

This is taken from the DoublePoleCartWithDQN test, showing how we can manually specify the network. But I could add a separate test for CartPole with DQN as well, showing how to manually specify it?

I'm afraid we won't be able to pass the network directly to QLearning (i.e. without using VanillaDQN), as I've used the method ResetParametersIfEmpty() in QLearning, which is not present in FFN. Would that be fine?

@zoq (Member):

We could rename ResetParametersIfEmpty to ResetParameters and revert the change; that way we can still pass a vanilla network. Maybe I missed something? There is already a test that uses the copy constructor, but it would be great if we could provide backward compatibility.
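
A minimal sketch of the idea, assuming the wrapper simply forwards to the wrapped FFN (hypothetical code, not from this PR): if the wrapper exposes ResetParameters() under the same name FFN already uses, QLearning can call it on either type, so a plain FFN stays a valid network argument.

    // Hypothetical wrapper forwarding ResetParameters() to the wrapped FFN,
    // so QLearning can call network.ResetParameters() whether it holds a
    // plain FFN or this wrapper.
    template<typename NetworkType>
    class SimpleDQN
    {
     public:
      SimpleDQN(NetworkType network) : network(std::move(network)) { }

      // Same name as FFN::ResetParameters(), so both types look alike
      // from QLearning's point of view.
      void ResetParameters() { network.ResetParameters(); }

      const arma::mat& Parameters() const { return network.Parameters(); }

     private:
      NetworkType network;
    };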

@nishantkr18 (Member, Author):
Yeah, you are right. Actually, I didn't notice that Parameters() had already been added to SimpleDQN.
Anyway, I've now made the necessary changes; kindly have a look.

@birm (Member) left a comment:

Great work incorporating the changes!

@sriramsk1999 (Contributor) commented:

Hi @nishantkr18, I was having a look at the code and I had a few questions. Wouldn't structuring the network this way be too restrictive? As in, a SimpleDQN would always be a two-layer network with Linear and ReLU?

What would the process be if I wanted to use a DQN with a different architecture, say a three-layer convolutional network with LeakyReLU? I hope my question is clear :)

@nishantkr18 (Member, Author) commented:

Hi @sriramsk1999! One can easily pass custom network architectures directly into QLearning or via SimpleDQN, as is done here.
The purpose of q_networks is that when other extensions of DQN are added, whose network structures differ from each other, one can use the preexisting networks from q_networks instead of building the entire network structure in the tests themselves. But again, custom network support is still available (see the sketch below).
I hope that makes it clear. Let me know if there are any other questions :)
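
A minimal sketch of the two styles, assuming the same headers and using-declarations as the test snippets above; the layer sizes are illustrative, and the SimpleDQN constructor signature is assumed from the four-argument form shown earlier in the review, not taken from the PR itself.

    // Style 1 (hypothetical sizes): build an arbitrary architecture by hand
    // and wrap it, as in the DoublePoleCartWithDQN snippet above.
    FFN<MeanSquaredError<>, GaussianInitialization> custom(
        MeanSquaredError<>(), GaussianInitialization(0, 0.001));
    custom.Add<Linear<>>(4, 64);
    custom.Add<LeakyReLU<>>();
    custom.Add<Linear<>>(64, 2);
    SimpleDQN<decltype(custom)> model(std::move(custom));

    // Style 2 (assumed constructor): use the ready-made network from
    // q_networks for the common Linear + ReLU case.
    SimpleDQN<> model2(4, 128, 128, 2);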

@sriramsk1999 (Contributor) commented:

Ah okay, I think I understand now. I thought that q_networks was supposed to replace the existing way to add a network to the QLearning agent.

If I understood you correctly, it is supplementary to the existing method and acts as a shortcut for commonly used architectures. Thanks for the clarification, and nice work. :)

@zoq (Member) left a comment:

Thanks for putting this together, no more comments from my side.

@birm merged commit 3382b72 into mlpack:master on Apr 7, 2020
@nishantkr18 deleted the duelingDQN branch on April 7, 2020