trpo_try #567
Conversation
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from .rl_component_bundle import rl_component_bundle
+from rl_component_bundle import rl_component_bundle
Revert this change, otherwise run_rl_example.py won't work.
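For context, a rough sketch of the import behaviour this refers to (assuming run_rl_example.py loads the scenario folder as the package examples.cim.rl, as the other hunks below suggest):

```python
# Relative import: resolves against the enclosing package (examples.cim.rl),
# which is how run_rl_example.py imports the scenario components.
from .rl_component_bundle import rl_component_bundle

# An absolute import such as `from rl_component_bundle import ...` only works
# when the scenario directory itself is on sys.path (e.g. the file is executed
# directly as a script), so the example runner cannot resolve it.
```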
@@ -8,7 +8,7 @@
 from maro.rl.rollout import AbsEnvSampler, CacheElement
 from maro.simulator.scenarios.cim.common import Action, ActionType, DecisionEvent

-from .config import action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
+from config import action_shaping_conf, port_attributes, reward_shaping_conf, state_shaping_conf, vessel_attributes
Revert this change, otherwise run_rl_example.py won't work.
from .algorithms.ppo import get_ppo, get_ppo_policy
from examples.cim.rl.config import action_num, algorithm, env_conf, reward_shaping_conf, state_dim
from examples.cim.rl.env_sampler import CIMEnvSampler
from algorithms.ac import get_ac, get_ac_policy
Revert this change, otherwise run_rl_example.py won't work.
@@ -383,3 +383,5 @@ def to_device(self, device: torch.device) -> None:
     def _to_device_impl(self, device: torch.device) -> None:
         """Implementation of `to_device`."""
         raise NotImplementedError
+
+
Remove unnecessary blank lines (you may run pre-commit run --all to do auto-formatting).
@@ -6,6 +6,7 @@
 from .dqn import DQNParams, DQNTrainer
 from .maddpg import DiscreteMADDPGParams, DiscreteMADDPGTrainer
 from .ppo import DiscretePPOWithEntropyTrainer, PPOParams, PPOTrainer
+from .trpo import *
Do not use from XXX import *, since it is ambiguous; import the needed names explicitly instead.
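One possible explicit form; TRPOParams and TRPOTrainer are assumed names here, so use whatever the new trpo module actually defines:

```python
# Import the TRPO symbols by name so what this package re-exports stays explicit.
# TRPOParams / TRPOTrainer are assumptions about the new module's API.
from .trpo import TRPOParams, TRPOTrainer
```

If the package keeps an __all__ list, the same names can be appended there as well.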
# mask -> one 1 per action
batch = self._get_batch()
# trpo_main.update_params(batch)
for _ in range(self._params.grad_iters):
According to this pseudo-code, we should update the actor first, then update the critic?
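Roughly what that ordering would look like; _update_actor and _update_critic are placeholder names for the two update paths visible in this diff, and the batch/grad_iters usage mirrors the reviewed code:

```python
def train_step(self) -> None:
    batch = self._get_batch()
    # Per the referenced pseudo-code: one policy (actor) update first...
    self._update_actor(batch)
    # ...then several critic fitting iterations on the same batch.
    for _ in range(self._params.grad_iters):
        self._update_critic(batch)
```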
Args:
    batch (TransitionBatch): Batch.
"""
self._v_critic_net.step(self._get_critic_loss(batch))
Call self._v_critic_net.train() before updating the critic net.
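That is, something along these lines, assuming the critic net follows the usual train()/step() pattern:

```python
self._v_critic_net.train()  # switch the critic to training mode first
self._v_critic_net.step(self._get_critic_loss(batch))
```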
""" | ||
loss = self._get_actor_loss(batch) | ||
|
||
self._policy.train_step(loss) |
Call self._policy.train() before updating the policy.
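Likewise here, a small sketch of the suggested order:

```python
loss = self._get_actor_loss(batch)

self._policy.train()  # put the policy into training mode before the update
self._policy.train_step(loss)
```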
def _get_actor_loss(self, batch: TransitionBatch):
    assert isinstance(self._policy, DiscretePolicyGradient) or isinstance(self._policy, ContinuousRLPolicy)
    self._policy.train()
    rewards = ndarray_to_tensor(batch.rewards)
Specify the device in ndarray_to_tensor.
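For example (assuming ndarray_to_tensor takes a device argument and the trainer stores its device as self._device):

```python
# Convert on the trainer's device so later tensor ops do not mix devices.
rewards = ndarray_to_tensor(batch.rewards, device=self._device)
```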
prev_value = 0
prev_advantage = 0

for i in reversed(range(rewards.size(0))):
Should this be done in preprocess_batch?
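A hypothetical sketch of what moving this into a preprocessing hook could look like; the method name, the batch fields (rewards, values, terminals) and the hyperparameters (reward_discount, lam) are assumptions rather than the actual MARO API, and TransitionBatch is taken to be the type already imported in this module:

```python
import numpy as np


def _preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
    """Compute GAE-style advantages once, before the gradient iterations."""
    rewards = np.asarray(batch.rewards)      # per-step rewards, shape (T,)
    values = np.asarray(batch.values)        # critic estimates V(s_t), shape (T,)
    terminals = np.asarray(batch.terminals)  # episode-end flags, shape (T,)

    advantages = np.zeros_like(rewards, dtype=np.float64)
    prev_value, prev_advantage = 0.0, 0.0
    for i in reversed(range(len(rewards))):
        not_done = 1.0 - float(terminals[i])
        # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done) - V(s_t)
        delta = rewards[i] + self._reward_discount * prev_value * not_done - values[i]
        # A_t = delta_t + gamma * lambda * (1 - done) * A_{t+1}
        advantages[i] = delta + self._reward_discount * self._lam * prev_advantage * not_done
        prev_value, prev_advantage = values[i], advantages[i]

    batch.advantages = advantages
    return batch
```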
Description
Linked issue(s)/Pull request(s)
Type of Change
Related Component
Has Been Tested
Needs Follow Up Actions
Checklist