Commit: Code check
buptchan committed Aug 5, 2021
1 parent 887b90b commit 8402253
Showing 7 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion maro/rl/algorithms/__init__.py
@@ -5,8 +5,8 @@
 from .ac import ActorCritic
 from .ddpg import DDPG
 from .dqn import DQN
-from .pg import PolicyGradient
 from .index import get_algorithm_cls, get_algorithm_model_cls
+from .pg import PolicyGradient

 __all__ = [
     "AbsAlgorithm", "ActorCritic", "DDPG", "DQN", "PolicyGradient", "get_algorithm_cls", "get_algorithm_model_cls"
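These exports are typically consumed by name through the lookup helpers re-exported here. A minimal usage sketch, assuming a string-keyed registry (the key "dqn" is a guess, not confirmed by this diff):

from maro.rl.algorithms import DQN, get_algorithm_cls

algorithm_cls = get_algorithm_cls("dqn")  # hypothetical key; presumably returns a class such as DQN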
3 changes: 1 addition & 2 deletions maro/rl/algorithms/abs_algorithm.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT license.

 from abc import ABC, abstractmethod
-from typing import Union

 from maro.rl.experience import ExperienceSet
 from maro.rl.exploration import AbsExploration
@@ -50,7 +49,7 @@ def get_state(self, inference: bool = True):
         Args:
             inference (bool): If True, the returned state is for inference purposes only. This parameter
-                may be ignored for some algorithms.
+                may be ignored for some algorithms.
         """
         pass

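To make the `inference` flag concrete, here is a minimal sketch of a concrete subclass; it assumes `get_state` and `learn` are the abstract methods (the real base class may require more overrides than shown):

from maro.rl.algorithms import AbsAlgorithm

class MyAlgorithm(AbsAlgorithm):
    def get_state(self, inference: bool = True):
        # When only inference is needed, a lighter-weight state (e.g., policy
        # weights without optimizer state) could be returned here.
        ...

    def learn(self, data):
        ...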
6 changes: 3 additions & 3 deletions maro/rl/algorithms/ac.py
@@ -105,12 +105,12 @@ def _get_loss(self, batch: ExperienceSet):

     def learn(self, data: Union[ExperienceSet, dict]):
         assert self.ac_net.trainable, "ac_net needs to have at least one optimizer registered."
-        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
+        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
         if isinstance(data, ExperienceSet):
-            self.ac_net.train()
+            self.ac_net.train()
             loss = self._get_loss(data)
             self.ac_net.step(loss)
-        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
+        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
         else:
             self.ac_net.apply(data)

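The same ExperienceSet-or-gradient-dict dispatch recurs in the DDPG, DQN and PolicyGradient hunks below. A usage sketch of the two call paths (experience_memory and remote_worker are hypothetical stand-ins, not names from this commit):

# Local learning: a batch of experiences, so the loss is computed and back-propagated.
batch = experience_memory.get()            # experience_memory: hypothetical ExperienceSet source
algorithm.learn(batch)

# Distributed learning: a dict of gradients (e.g., computed on a worker), applied directly.
grad_dict = remote_worker.get_gradients()  # remote_worker: hypothetical gradient producer
algorithm.learn(grad_dict)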
4 changes: 2 additions & 2 deletions maro/rl/algorithms/ddpg.py
@@ -93,12 +93,12 @@ def _get_loss(self, batch: ExperienceSet):

     def learn(self, data: Union[ExperienceSet, dict]):
         assert self.ac_net.trainable, "ac_net needs to have at least one optimizer registered."
-        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
+        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
         if isinstance(data, ExperienceSet):
             self.ac_net.train()
             loss = self._get_loss(data)
             self.ac_net.step(loss)
-        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
+        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
         else:
             self.ac_net.apply(data)

6 changes: 3 additions & 3 deletions maro/rl/algorithms/dqn.py
@@ -93,14 +93,14 @@ def _get_loss(self, experience_batch: ExperienceSet):

     def learn(self, data: Union[ExperienceSet, dict]):
         assert self.q_net.trainable, "q_net needs to have at least one optimizer registered."
-        # If data is an ExperienceSet, compute the DQN loss from the batch and backprop it through the network.
+        # If data is an ExperienceSet, compute the DQN loss from the batch and backprop it through the network.
         if isinstance(data, ExperienceSet):
             self.q_net.train()
             loss = self._get_loss(data)
             self.q_net.step(loss)
-        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
+        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
         else:
-            self.q_net.apply(data)
+            self.q_net.apply(data)

     def post_update(self, update_index: int):
         # soft-update target network
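`post_update` mentions a soft update of the target network. For illustration, the generic form of such an update is sketched below; this is the textbook Polyak-averaging rule with a hypothetical coefficient `tau`, not the exact MARO implementation:

import torch.nn as nn

def soft_update(target: nn.Module, online: nn.Module, tau: float) -> None:
    # Move each target parameter a fraction tau toward the corresponding online parameter.
    for t, s in zip(target.parameters(), online.parameters()):
        t.data.copy_((1.0 - tau) * t.data + tau * s.data)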
4 changes: 2 additions & 2 deletions maro/rl/algorithms/pg.py
@@ -57,12 +57,12 @@ def learn(self, data: Union[ExperienceSet, dict]):
             which they are generated during the simulation. Otherwise, the return values may be meaningless.
         """
         assert self.policy_net.trainable, "policy_net needs to have at least one optimizer registered."
-        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
+        # If data is an ExperienceSet, compute the loss from the batch and backprop it through the network.
         if isinstance(data, ExperienceSet):
             self.policy_net.train()
             loss = self._get_loss(data)
             self.policy_net.step(loss)
-        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
+        # Otherwise treat the data as a dict of gradients that can be applied directly to the network.
         else:
             self.policy_net.apply(data)

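The docstring warns that experiences must arrive in the order they were generated because returns are accumulated backward over the trajectory. A generic sketch of that computation (gamma is a hypothetical discount factor; this is illustrative, not the module's own `_get_loss`):

def discounted_returns(rewards: list, gamma: float = 0.99) -> list:
    # Walk the trajectory backward so each step's return includes all later rewards.
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return list(reversed(returns))

# discounted_returns([1, 0, 1]) -> [1 + 0.99**2, 0.99, 1.0]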
2 changes: 1 addition & 1 deletion maro/rl/model/core_model.py
@@ -120,7 +120,7 @@ def apply(self, grad_dict: dict):
     def step(self, loss: torch.Tensor):
         """Use the loss to back-propagate gradients and apply them to the underlying parameters.

-        This is equivalent to a chained ``get_gradients`` and ``apply``.
+        This is equivalent to a chained ``get_gradients`` and ``apply``.

         Args:
             loss: Result of a computation graph that involves the underlying parameters.
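The documented equivalence can be made explicit. A sketch assuming a model built from this class, where `get_gradients` is the companion method the docstring itself references:

# One-shot form: back-propagate the loss and update parameters in a single call.
model.step(loss)

# Equivalent two-step form: extract the gradients, then apply them.
grad_dict = model.get_gradients(loss)
model.apply(grad_dict)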
