-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* PPO, SAC, DDPG passed * Explore in SAC * Test GYM on server * Sync server changes * pre-commit * Ready to try on server * . * . * . * . * . * Performance OK * Move to tests * Remove old versions * PPO done * Start to test AC * Start to test SAC * SAC test passed * update for some PR comments; Add a MARKDOWN file (#576) Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com> * Use FullyConnected to replace mlp * Update action bound * Pre-commit --------- Co-authored-by: Jinyu-W <53509467+Jinyu-W@users.noreply.github.com> Co-authored-by: Jinyu Wang <wang.jinyu@microsoft.com>
- Loading branch information
1 parent
eb6324c
commit 214383f
Showing
24 changed files
with
604 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from typing import Tuple | ||
|
||
import numpy as np | ||
import torch | ||
from torch.distributions import Normal | ||
from torch.optim import Adam | ||
|
||
from maro.rl.model import ContinuousACBasedNet, VNet | ||
from maro.rl.model.fc_block import FullyConnected | ||
from maro.rl.policy import ContinuousRLPolicy | ||
from maro.rl.training.algorithms import ActorCriticParams, ActorCriticTrainer | ||
|
||
# Architecture settings shared by the actor and critic networks below.
actor_net_conf = {
    "hidden_dims": [64, 64],
    "activation": torch.nn.Tanh,
}
critic_net_conf = {
    "hidden_dims": [64, 64],
    "activation": torch.nn.Tanh,
}
# Separate learning rates for the two optimizers created in the classes below.
actor_learning_rate = 3e-4
critic_learning_rate = 1e-3
|
||
|
||
class MyContinuousACBasedNet(ContinuousACBasedNet):
    """Diagonal-Gaussian policy network with a state-independent, learnable log-std.

    The per-dimension action mean comes from a fully connected trunk; the
    standard deviation is a free parameter shared across all states and is
    optimized jointly with the trunk weights.
    """

    def __init__(self, state_dim: int, action_dim: int) -> None:
        super(MyContinuousACBasedNet, self).__init__(state_dim=state_dim, action_dim=action_dim)

        # Start log-std at -0.5 (std ~= 0.61) for moderate initial exploration.
        initial_log_std = -0.5 * np.ones(action_dim, dtype=np.float32)
        self._log_std = torch.nn.Parameter(torch.as_tensor(initial_log_std))
        self._mu_net = FullyConnected(
            input_dim=state_dim,
            output_dim=action_dim,
            hidden_dims=actor_net_conf["hidden_dims"],
            activation=actor_net_conf["activation"],
        )
        self._optim = Adam(self.parameters(), lr=actor_learning_rate)

    def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        # NOTE(review): `exploring` is ignored — actions are always sampled,
        # which matches on-policy AC usage; confirm against the trainer.
        dist = self._distribution(states)
        sampled_actions = dist.sample()
        return sampled_actions, dist.log_prob(sampled_actions).sum(axis=-1)

    def _get_states_actions_logps_impl(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Log-probability of externally supplied actions under the current policy.
        return self._distribution(states).log_prob(actions).sum(axis=-1)

    def _distribution(self, states: torch.Tensor) -> Normal:
        # Diagonal Gaussian: mean from the trunk, shared std from the parameter.
        return Normal(self._mu_net(states.float()), torch.exp(self._log_std))
|
||
|
||
class MyVCriticNet(VNet):
    """State-value critic: a fully connected network mapping states to scalar V(s)."""

    def __init__(self, state_dim: int) -> None:
        super(MyVCriticNet, self).__init__(state_dim=state_dim)
        self._critic = FullyConnected(
            input_dim=state_dim,
            hidden_dims=critic_net_conf["hidden_dims"],
            output_dim=1,
            activation=critic_net_conf["activation"],
        )
        self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)

    def _get_v_values(self, states: torch.Tensor) -> torch.Tensor:
        # Drop the trailing singleton dim so the result is shaped (batch,).
        raw_values = self._critic(states.float())
        return raw_values.squeeze(-1)
|
||
|
||
def get_ac_policy(
    name: str,
    action_lower_bound: list,
    action_upper_bound: list,
    gym_state_dim: int,
    gym_action_dim: int,
) -> ContinuousRLPolicy:
    """Build a continuous actor-critic policy bounded by the given action range."""
    policy_net = MyContinuousACBasedNet(gym_state_dim, gym_action_dim)
    return ContinuousRLPolicy(
        name=name,
        action_range=(action_lower_bound, action_upper_bound),
        policy_net=policy_net,
    )
|
||
|
||
def get_ac_trainer(name: str, state_dim: int) -> ActorCriticTrainer:
    """Build an actor-critic trainer (gamma=0.99, GAE lambda=0.97, 80 grad iters)."""
    trainer_params = ActorCriticParams(
        get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
        grad_iters=80,
        lam=0.97,
    )
    return ActorCriticTrainer(name=name, reward_discount=0.99, params=trainer_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from maro.rl.training.algorithms import PPOParams, PPOTrainer | ||
|
||
from .ac import MyVCriticNet, get_ac_policy | ||
|
||
# PPO reuses the same Gaussian actor network/policy factory as vanilla AC.
get_ppo_policy = get_ac_policy
|
||
|
||
def get_ppo_trainer(name: str, state_dim: int) -> PPOTrainer:
    """Build a PPO trainer sharing the AC critic, with surrogate clip ratio 0.2."""
    ppo_params = PPOParams(
        get_v_critic_net_func=lambda: MyVCriticNet(state_dim),
        grad_iters=80,
        lam=0.97,
        clip_ratio=0.2,
    )
    return PPOTrainer(name=name, reward_discount=0.99, params=ppo_params)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT license. | ||
|
||
from typing import Tuple | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn.functional as F | ||
from torch.distributions import Normal | ||
from torch.optim import Adam | ||
|
||
from maro.rl.model import ContinuousSACNet, QNet | ||
from maro.rl.model.fc_block import FullyConnected | ||
from maro.rl.policy import ContinuousRLPolicy | ||
from maro.rl.training.algorithms import SoftActorCriticParams, SoftActorCriticTrainer | ||
|
||
# Architecture settings shared by the SAC actor and Q-critic networks below.
actor_net_conf = {
    "hidden_dims": [64, 64],
    "activation": torch.nn.Tanh,
}
critic_net_conf = {
    "hidden_dims": [64, 64],
    "activation": torch.nn.Tanh,
}
actor_learning_rate = 3e-4
critic_learning_rate = 1e-3

# Clamp range for the predicted log-std, keeping exp(log_std) numerically safe.
LOG_STD_MAX = 2
LOG_STD_MIN = -20
|
||
|
||
class MyContinuousSACNet(ContinuousSACNet):
    """Squashed-Gaussian SAC policy network.

    A shared fully connected trunk feeds two linear heads producing the
    per-dimension action mean and log-std. Sampled actions are squashed with
    tanh and scaled by ``action_limit``; the log-probability is corrected for
    the squash using the numerically stable softplus identity.
    """

    def __init__(self, state_dim: int, action_dim: int, action_limit: float) -> None:
        """
        Args:
            state_dim: Dimension of the flat state vector.
            action_dim: Dimension of the action vector.
            action_limit: Scale applied after the tanh squash
                (assumes a symmetric action range — TODO confirm at call site).
        """
        super(MyContinuousSACNet, self).__init__(state_dim=state_dim, action_dim=action_dim)

        # Trunk uses all but the last configured hidden size internally; the
        # last hidden size is the trunk's output feeding both heads.
        self._net = FullyConnected(
            input_dim=state_dim,
            output_dim=actor_net_conf["hidden_dims"][-1],
            hidden_dims=actor_net_conf["hidden_dims"][:-1],
            activation=actor_net_conf["activation"],
            output_activation=actor_net_conf["activation"],
        )
        self._mu = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
        self._log_std = torch.nn.Linear(actor_net_conf["hidden_dims"][-1], action_dim)
        self._action_limit = action_limit
        self._optim = Adam(self.parameters(), lr=actor_learning_rate)

    def _get_actions_with_logps_impl(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return squashed actions and their log-probabilities for `states`."""
        net_out = self._net(states.float())
        mu = self._mu(net_out)
        # State-dependent std, clamped to a numerically safe range.
        log_std = torch.clamp(self._log_std(net_out), LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        pi_distribution = Normal(mu, std)
        # Reparameterized sample while exploring; deterministic mean otherwise.
        pi_action = pi_distribution.rsample() if exploring else mu

        logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
        # Tanh-squash log-prob correction in its numerically stable form.
        # FIX: reduce over axis=-1 (was axis=1) so a 1-D unbatched action tensor
        # also works, consistent with the log_prob reduction on the line above
        # and with the AC policy net in this example set. Identical for the
        # standard (batch, action_dim) case.
        logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=-1)

        pi_action = torch.tanh(pi_action) * self._action_limit

        return pi_action, logp_pi
|
||
|
||
class MyQCriticNet(QNet):
    """Action-value critic: maps a (state, action) pair to a scalar Q(s, a)."""

    def __init__(self, state_dim: int, action_dim: int) -> None:
        super(MyQCriticNet, self).__init__(state_dim=state_dim, action_dim=action_dim)
        self._critic = FullyConnected(
            input_dim=state_dim + action_dim,
            hidden_dims=critic_net_conf["hidden_dims"],
            output_dim=1,
            activation=critic_net_conf["activation"],
        )
        self._optim = Adam(self._critic.parameters(), lr=critic_learning_rate)

    def _get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Concatenate state and action along the feature axis, then squeeze the
        # trailing singleton output dim to get shape (batch,).
        joint_input = torch.cat([states, actions], dim=1)
        return self._critic(joint_input.float()).squeeze(-1)
|
||
|
||
def get_sac_policy(
    name: str,
    action_lower_bound: list,
    action_upper_bound: list,
    gym_state_dim: int,
    gym_action_dim: int,
    action_limit: float,
) -> ContinuousRLPolicy:
    """Build a SAC policy with a squashed-Gaussian actor over the given bounds."""
    sac_net = MyContinuousSACNet(gym_state_dim, gym_action_dim, action_limit)
    return ContinuousRLPolicy(
        name=name,
        action_range=(action_lower_bound, action_upper_bound),
        policy_net=sac_net,
    )
|
||
|
||
def get_sac_trainer(name: str, state_dim: int, action_dim: int) -> SoftActorCriticTrainer:
    """Build a SAC trainer (gamma=0.99, 10 epochs, 10k warm-up steps before training)."""
    sac_params = SoftActorCriticParams(
        get_q_critic_net_func=lambda: MyQCriticNet(state_dim, action_dim),
        num_epochs=10,
        n_start_train=10000,
    )
    return SoftActorCriticTrainer(name=name, reward_discount=0.99, params=sac_params)
Oops, something went wrong.