Commit: More strict shape check

lihuoran committed Oct 28, 2021
1 parent a71bdfa · commit 50c6743
Showing 13 changed files with 183 additions and 90 deletions (diffs for 7 of the 13 files loaded below).
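Across the loaded diffs, ad-hoc assertions of the form len(x.shape) == 2 and x.shape[1] == dim are replaced by calls to maro.rl.utils.match_shape, and the per-network checks are consolidated behind two abstract contracts: ShapeCheckMixin._policy_net_shape_check for policy networks and CriticMixin._critic_net_shape_check for critics. match_shape itself is not among the changed files shown here, so the following is only a minimal sketch of the semantics its call sites appear to assume (None acting as a wildcard for the batch dimension):

from typing import Optional, Tuple

import torch


def match_shape(tensor: torch.Tensor, shape: Tuple[Optional[int], ...]) -> bool:
    # Assumed semantics: True iff `tensor` has exactly len(shape) dimensions
    # and every non-None entry of `shape` equals the corresponding dimension.
    if tensor.dim() != len(shape):
        return False
    return all(expected is None or expected == actual for expected, actual in zip(shape, tensor.shape))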
maro/rl/modeling_v2/ac_network.py (7 changes: 5 additions & 2 deletions)

@@ -1,5 +1,5 @@
 from abc import ABCMeta
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch.distributions import Categorical
@@ -13,7 +13,7 @@ def __init__(self, state_dim: int, action_num: int) -> None:
         super(DiscreteActorCriticNet, self).__init__(state_dim=state_dim, action_dim=1)
         self._action_num = action_num
 
-    def action_num(self) -> int:
+    def _get_action_num(self) -> int:
         return self._action_num
 
     def _get_actions_and_logps_exploration_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -29,3 +29,6 @@ def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
 class DiscreteVActorCriticNet(VCriticMixin, DiscreteActorCriticNet, metaclass=ABCMeta):
     def __init__(self, state_dim: int, action_num: int) -> None:
         super(DiscreteVActorCriticNet, self).__init__(state_dim=state_dim, action_num=action_num)
+
+    def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        return self._policy_net_shape_check(states, actions)
maro/rl/modeling_v2/base_model.py (42 changes: 29 additions & 13 deletions)

@@ -2,10 +2,12 @@
 # Licensed under the MIT license.
 
 from abc import abstractmethod
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
+from maro.rl.utils import match_shape
+
 
 class AbsCoreModel(torch.nn.Module):
     """TODO
@@ -75,6 +77,8 @@ def soft_update(self, other_model: torch.nn.Module, tau: float) -> None:
 
 
 class SimpleNetwork(AbsCoreModel):
+    """Simple neural network that has one input and one output.
+    """
     def __init__(self, input_dim: int, output_dim: int) -> None:
         super(SimpleNetwork, self).__init__()
         self._input_dim = input_dim
@@ -96,7 +100,13 @@ def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         pass
 
 
-class PolicyNetwork(AbsCoreModel):
+class ShapeCheckMixin:
+    @abstractmethod
+    def _policy_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        pass
+
+
+class PolicyNetwork(ShapeCheckMixin, AbsCoreModel):
     def __init__(self, state_dim: int, action_dim: int) -> None:
         super(PolicyNetwork, self).__init__()
         self._state_dim = state_dim
@@ -110,16 +120,17 @@ def state_dim(self) -> int:
     def action_dim(self) -> int:
         return self._action_dim
 
-    def _is_valid_state_shape(self, states: torch.Tensor) -> bool:
-        return len(states.shape) == 2 and states.shape[1] == self.state_dim
-
-    def _is_valid_action_shape(self, actions: torch.Tensor) -> bool:
-        return len(actions.shape) == 2 and actions.shape[1] == self.action_dim
+    def _policy_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        return all([
+            match_shape(states, (None, self.state_dim)),
+            actions is None or match_shape(actions, (None, self.action_dim)),
+            actions is None or states.shape[0] == actions.shape[0]
+        ])
 
     def get_actions(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
-        assert self._is_valid_state_shape(states)
+        assert self._policy_net_shape_check(states, None)
         ret = self._get_actions_impl(states, exploring)
-        assert ret.shape == (states.shape[0], self._action_dim)
+        assert match_shape(ret, (states.shape[0], self._action_dim))
         return ret
 
     @abstractmethod
@@ -128,12 +139,16 @@ def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
 
 
 class DiscretePolicyNetworkMixin:
-    @abstractmethod
+    @property
     def action_num(self) -> int:
-        pass
+        return self._get_action_num()
+
+    @abstractmethod
+    def _get_action_num(self) -> int:
+        pass
 
 
-class DiscreteProbPolicyNetworkMixin(DiscretePolicyNetworkMixin):
+class DiscreteProbPolicyNetworkMixin(DiscretePolicyNetworkMixin, ShapeCheckMixin):
     def get_probs(self, states: torch.Tensor) -> torch.Tensor:
         """Get probabilities of all possible actions.
@@ -143,8 +158,9 @@ def get_probs(self, states: torch.Tensor) -> torch.Tensor:
         Returns:
             probability matrix: [batch_size, action_num]
         """
+        self._policy_net_shape_check(states, None)
         ret = self._get_probs_impl(states)
-        assert ret.shape == (states.shape[0], self.action_num())
+        assert match_shape(ret, (states.shape[0], self.action_num))
         return ret
 
     @abstractmethod
@@ -174,8 +190,8 @@ def get_actions_and_logps(self, states: torch.Tensor, exploring: bool) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         if exploring:
             actions, logps = self._get_actions_and_logps_exploration_impl(states)
-            assert actions.shape == (states.shape[0],)
-            assert logps.shape == (states.shape[0],)
+            assert match_shape(actions, (states.shape[0],))
+            assert match_shape(logps, (states.shape[0],))
             return actions, logps
         else:
             action_prob = self.get_logps(states)  # [batch_size, num_actions]
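Using the match_shape sketch above, the consolidated check in PolicyNetwork._policy_net_shape_check behaves as follows (the dimensions here are illustrative, not from the commit):

import torch

states = torch.rand(8, 10)   # [batch_size=8, state_dim=10]
actions = torch.rand(8, 2)   # [batch_size=8, action_dim=2]

# Mirrors PolicyNetwork._policy_net_shape_check for state_dim=10, action_dim=2:
ok = all([
    match_shape(states, (None, 10)),                         # any batch size, state_dim must be 10
    actions is None or match_shape(actions, (None, 2)),      # skipped entirely when actions is None
    actions is None or states.shape[0] == actions.shape[0],  # batch sizes must agree
])
assert ok  # e.g. actions = torch.rand(7, 2) would fail the batch-size condition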
maro/rl/modeling_v2/critic_model.py (52 changes: 34 additions & 18 deletions)

@@ -1,12 +1,19 @@
 from abc import ABCMeta, abstractmethod
-from typing import List
+from typing import List, Optional
 
 import torch
 
 from maro.rl.modeling_v2.base_model import AbsCoreModel
+from maro.rl.utils import match_shape
 
 
-class VCriticMixin:
+class CriticMixin:
+    @abstractmethod
+    def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        pass
+
+
+class VCriticMixin(CriticMixin):
     def v_critic(self, states: torch.Tensor) -> torch.Tensor:
         """
         Args:
@@ -15,16 +22,17 @@ def v_critic(self, states: torch.Tensor) -> torch.Tensor:
         Returns:
             v values for critic: [batch_size]
         """
+        assert self._critic_net_shape_check(states, None)
         ret = self._get_v_critic(states)
-        assert ret.shape == (states.shape[0],)
+        assert match_shape(ret, (states.shape[0],))
         return ret
 
     @abstractmethod
     def _get_v_critic(self, states: torch.Tensor) -> torch.Tensor:
         pass
 
 
-class QCriticMixin:
+class QCriticMixin(CriticMixin):
     @abstractmethod
     def q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
         """
@@ -37,20 +45,15 @@ def q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
         Returns:
             q values for critic: [batch_size]
         """
-        assert states.shape[0] == actions.shape[0]
-        assert self._is_valid_action_shape(actions)
+        assert self._critic_net_shape_check(states, actions)
         ret = self._get_q_critic(states, actions)
-        assert ret.shape == (states.shape[0],)
+        assert match_shape(ret, (states.shape[0],))
         return ret
 
     @abstractmethod
     def _get_q_critic(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
         pass
 
-    @abstractmethod
-    def _is_valid_action_shape(self, actions: torch.Tensor) -> bool:
-        pass
-
 
 class CriticNetwork(AbsCoreModel, metaclass=ABCMeta):
     def __init__(self, state_dim: int) -> None:
@@ -62,13 +65,16 @@ def state_dim(self) -> int:
         return self._state_dim
 
     def _is_valid_state_shape(self, states: torch.Tensor) -> bool:
-        return len(states.shape) == 2 and states.shape[1] == self.state_dim
+        return match_shape(states, (None, self.state_dim))
 
 
 class VCriticNetwork(VCriticMixin, CriticNetwork, metaclass=ABCMeta):
     def __init__(self, state_dim: int) -> None:
         super(VCriticNetwork, self).__init__(state_dim=state_dim)
 
+    def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        return self._is_valid_state_shape(states)
+
 
 class QCriticNetwork(QCriticMixin, CriticNetwork, metaclass=ABCMeta):
     def __init__(self, state_dim: int, action_dim: int) -> None:
@@ -79,8 +85,15 @@ def __init__(self, state_dim: int, action_dim: int) -> None:
     def action_dim(self) -> int:
         return self._action_dim
 
+    def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
+        return all([
+            self._is_valid_state_shape(states),
+            self._is_valid_action_shape(actions),
+            states.shape[0] == actions.shape[0]
+        ])
+
     def _is_valid_action_shape(self, actions: torch.Tensor) -> bool:
-        return len(actions.shape) == 2 and actions.shape[1] == self.action_dim
+        return match_shape(actions, (None, self.action_dim))
 
 
 class DiscreteQCriticNetwork(QCriticNetwork):
@@ -107,7 +120,7 @@ def q_critic_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
         """
         assert self._is_valid_state_shape(states)
         ret = self._get_q_critic_for_all_actions(states)
-        assert ret.shape == (states.shape[0], self.action_num)
+        assert match_shape(ret, (states.shape[0], self.action_num))
         return ret
 
     @abstractmethod
@@ -129,13 +142,16 @@ def action_dim(self) -> int:
     def agent_num(self) -> int:
         return self._agent_num
 
-    def _is_valid_action_shape(self, actions: torch.Tensor) -> bool:
+    def _critic_net_shape_check(self, states: torch.Tensor, actions: Optional[torch.Tensor]) -> bool:
         return all([
-            len(actions.shape) == 3,
-            actions.shape[1] == self.agent_num,
-            actions.shape[2] == self.action_dim
+            self._is_valid_state_shape(states),
+            actions is None or self._is_valid_action_shape(actions),
+            actions is None or states.shape[0] == actions.shape[0]
         ])
 
+    def _is_valid_action_shape(self, actions: torch.Tensor) -> bool:
+        return match_shape(actions, (None, self.agent_num, self.action_dim))
+
 
 class MultiDiscreteQCriticNetwork(MultiQCriticNetwork, metaclass=ABCMeta):
     def __init__(self, state_dim: int, action_nums: List[int]) -> None:
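One behavioral detail in MultiQCriticNetwork: the joint-action tensor is now validated against a single three-dimensional template instead of three separate comparisons. A quick illustration, again relying on the match_shape sketch above with made-up dimensions:

import torch

agent_num, action_dim = 3, 2
joint_actions = torch.rand(16, agent_num, action_dim)  # [batch, agent, action]

assert match_shape(joint_actions, (None, agent_num, action_dim))  # accepted
assert not match_shape(joint_actions, (None, action_dim))         # a 2-D template rejects 3-D input
assert not match_shape(torch.rand(16, 4, action_dim), (None, agent_num, action_dim))  # wrong agent count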
maro/rl/modeling_v2/pg_network.py (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@ def _get_actions_and_logps_exploration_impl(self, states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         logps = action_probs.log_prob(actions)
         return actions, logps
 
-    def action_num(self) -> int:
+    def _get_action_num(self) -> int:
         return self._action_num
 
     def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
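The action_num() -> _get_action_num() renames in this file and its neighbors all serve the new read-only property on DiscretePolicyNetworkMixin: callers now write net.action_num without parentheses, and subclasses override the private hook rather than the property. A minimal standalone reproduction of the pattern (names here are invented for illustration):

from abc import ABC, abstractmethod


class DiscreteMixin(ABC):
    @property
    def action_num(self) -> int:
        # The public property delegates to a private hook, so concrete
        # subclasses implement a plain method instead of redeclaring
        # the property themselves.
        return self._get_action_num()

    @abstractmethod
    def _get_action_num(self) -> int:
        pass


class TinyNet(DiscreteMixin):
    def _get_action_num(self) -> int:
        return 5


assert TinyNet().action_num == 5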
maro/rl/modeling_v2/q_network.py (13 changes: 6 additions & 7 deletions)

@@ -3,6 +3,7 @@
 import torch
 
 from .base_model import DiscretePolicyNetworkMixin, PolicyNetwork
+from ..utils import match_shape
 
 
 class QNetwork(PolicyNetwork):
@@ -19,11 +20,9 @@ def q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
         Returns:
             Q-values of shape [batch_size].
         """
-        assert self._is_valid_state_shape(states)
-        assert self._is_valid_action_shape(actions)
-        assert states.shape[0] == actions.shape[0]
+        assert self._policy_net_shape_check(states, actions)
         ret = self._get_q_values(states, actions)
-        assert ret.shape == (states.shape[0],)
+        assert match_shape(ret, (states.shape[0],))
         return ret
 
     @abstractmethod
@@ -49,16 +48,16 @@ def q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
         Returns:
             Q-value matrix of shape [batch_size, action_num]
         """
-        assert self._is_valid_state_shape(states)
+        assert self._policy_net_shape_check(states, None)
         ret = self._get_q_values_for_all_actions(states)
-        assert ret.shape == (states.shape[0], self.action_num())
+        assert match_shape(ret, (states.shape[0], self.action_num))
         return ret
 
     @abstractmethod
     def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
         pass
 
-    def action_num(self) -> int:
+    def _get_action_num(self) -> int:
         return self._action_num
 
     def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
maro/rl/policy_v2/ac.py (12 changes: 6 additions & 6 deletions)

@@ -51,7 +51,7 @@ def __init__(
 
         self._buffer = defaultdict(lambda: Buffer(size=self._max_trajectory_len))
 
-    def __call__(self, states: np.ndarray) -> List[dict]:
+    def _call_impl(self, states: np.ndarray) -> List[dict]:
         """Return a list of action information dict given a batch of states.
 
         An action information dict contains the action itself, the corresponding log-P value and the corresponding
@@ -62,9 +62,6 @@ def __call__(self, states: np.ndarray) -> List[dict]:
             {"action": action, "logp": logp, "value": value} for action, logp, value in zip(actions, logps, values)
         ]
 
-    def _is_valid_state_shape(self, states: np.ndarray) -> bool:
-        return len(states.shape) == 2 and states.shape[1] == self._ac_net.state_dim
-
     def get_actions_with_logps_and_values(self, states: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         self._ac_net.eval()
         states = torch.from_numpy(states).to(self._device)
@@ -112,8 +109,11 @@ def learn_with_data_parallel(self, batch: dict, worker_id_list: list) -> None:
             _ = self.get_batch_loss(sub_batch, explicit_grad=True)
         self.update(loss_info_by_policy[self._name])
 
-    def action_num(self) -> int:
-        return self._ac_net.action_num()
+    def _get_action_num(self) -> int:
+        return self._ac_net.action_num
+
+    def _get_state_dim(self) -> int:
+        return self._ac_net.state_dim
 
     def record(
         self,
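The __call__ -> _call_impl rename here (and in dqn.py below) points at a template method on the policy base class; that base-class diff is among the files not loaded on this page, so the wrapper below is purely an assumption about its shape, suggested by the _get_state_dim hooks this commit adds:

import numpy as np


class PolicyBaseSketch:  # hypothetical stand-in for the unshown base class
    def __call__(self, states: np.ndarray):
        # Plausible motivation for the rename: the base class can run a
        # shared state-shape check before dispatching to the subclass.
        assert states.ndim == 2 and states.shape[1] == self._get_state_dim()
        return self._call_impl(states)

    def _call_impl(self, states: np.ndarray):
        raise NotImplementedError

    def _get_state_dim(self) -> int:
        raise NotImplementedError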
maro/rl/policy_v2/dqn.py (17 changes: 7 additions & 10 deletions)

@@ -200,7 +200,7 @@ def __init__(
         self._q_net_version = 0
         self._target_q_net_version = 0
 
-        self._num_actions = self._q_net.action_num()
+        self._num_actions = self._q_net.action_num
         self._reward_discount = reward_discount
         self._num_epochs = num_epochs
         self._update_target_every = update_target_every
@@ -226,7 +226,7 @@ def __init__(
             opt[1](self._exploration_params, opt[0], **opt[2]) for opt in exploration_scheduling_options
         ]
 
-    def __call__(self, states: np.ndarray) -> Iterable:
+    def _call_impl(self, states: np.ndarray) -> Iterable:
         if self._replay_memory.size < self._warmup:
             return np.random.randint(self._num_actions, size=(states.shape[0] if len(states.shape) > 1 else 1,))
@@ -243,12 +243,6 @@ def __call__(self, states: np.ndarray) -> Iterable:
         else:
             return self._exploration_func(states, actions.cpu().numpy(), self._num_actions, **self._exploration_params)
 
-    def _is_valid_state_shape(self, states: np.ndarray) -> bool:
-        return len(states.shape) == 2 and states.shape[1] == self._q_net.state_dim
-
-    def _is_valid_action_shape(self, actions: np.ndarray) -> bool:
-        return len(actions.shape) == 2 and actions.shape[1] == 1
-
     def data_parallel(self, *args, **kwargs) -> None:
         raise NotImplementedError  # TODO
@@ -259,8 +253,11 @@ def _get_q_values(self, states: np.ndarray, actions: np.ndarray) -> np.ndarray:
         q_matrix = self.q_values_for_all_actions(states)  # [batch_size, action_num]
         return np.take_along_axis(q_matrix, actions, axis=1)
 
-    def action_num(self) -> int:
-        return self._q_net.action_num()
+    def _get_action_num(self) -> int:
+        return self._q_net.action_num
+
+    def _get_state_dim(self) -> int:
+        return self._q_net.state_dim
 
     def record(
         self,
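Unrelated to the shape checks but visible in the context above, _get_q_values picks one Q-value per row with np.take_along_axis; a quick illustration with made-up numbers:

import numpy as np

q_matrix = np.array([[0.1, 0.9, 0.3],
                     [0.5, 0.2, 0.8]])  # [batch_size=2, action_num=3]
actions = np.array([[1], [2]])          # one chosen action index per row

print(np.take_along_axis(q_matrix, actions, axis=1))  # [[0.9], [0.8]]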