Rl v3 load save #463

Merged: 6 commits, Jan 27, 2022
Changes from all commits
8 changes: 4 additions & 4 deletions examples/rl/cim/algorithms/ac.py
@@ -60,13 +60,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

@@ -95,13 +95,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

4 changes: 2 additions & 2 deletions examples/rl/cim/algorithms/dqn.py
@@ -47,10 +47,10 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
         return {"network": self.state_dict(), "optim": self._optim.state_dict()}

-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
         assert isinstance(net_state, dict)
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])
8 changes: 4 additions & 4 deletions examples/rl/cim/algorithms/maddpg.py
@@ -62,13 +62,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

@@ -97,13 +97,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

8 changes: 4 additions & 4 deletions examples/rl/vm_scheduling/algorithms/ac.py
@@ -64,13 +64,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

@@ -102,13 +102,13 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
             "optim": self._optim.state_dict()
         }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])

4 changes: 2 additions & 2 deletions examples/rl/vm_scheduling/algorithms/dqn.py
@@ -55,10 +55,10 @@ def apply_gradients(self, grad: dict) -> None:
             param.grad = grad[name]
         self._optim.step()

-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
         return {"network": self.state_dict(), "optim": self._optim.state_dict()}

-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
         assert isinstance(net_state, dict)
         self.load_state_dict(net_state["network"])
         self._optim.load_state_dict(net_state["optim"])
2 changes: 1 addition & 1 deletion examples/rl/vm_scheduling/config.py
@@ -41,4 +41,4 @@

 test_seed = 1024

-algorithm = "dqn" # "dqn" or "ac"
+algorithm = "ac" # "dqn" or "ac"
2 changes: 0 additions & 2 deletions maro/cli/local/utils.py
@@ -158,8 +158,6 @@ def get_docker_compose_yml(config: dict, context: str, dockerfile_path: str, ima
         }
         for component, env in config_parser.get_rl_component_env_vars(config, containerized=True).items()
     }
-    # if config["mode"] != "single":
-    #     manifest["services"]["redis"] = {"image": "redis", "container_name": redis_host}

     return manifest

7 changes: 4 additions & 3 deletions maro/rl/distributed/abs_worker.py
@@ -10,20 +10,21 @@
 from zmq.eventloop.zmqstream import ZMQStream

 from maro.rl.utils.common import string_to_bytes
-from maro.utils import Logger
+from maro.utils import DummyLogger, Logger


 class AbsWorker(object):
     def __init__(
         self,
         idx: int,
         router_host: str,
-        router_port: int = 10001
+        router_port: int = 10001,
+        logger: Logger = None
     ) -> None:
         super(AbsWorker, self).__init__()

         self._id = f"worker.{idx}"
-        self._logger = Logger(self._id)
+        self._logger = DummyLogger() if logger is None else logger

         # ZMQ sockets and streams
         self._context = Context.instance()
4 changes: 2 additions & 2 deletions maro/rl/model/abs_net.py
@@ -48,14 +48,14 @@ def _forward_unimplemented(self, *input: Any) -> None: # TODO
         pass

     @abstractmethod
-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
         """
         Get the net's state.
         """
         raise NotImplementedError

     @abstractmethod
-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
         """
         Set the net's state.
         """
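For reference, a minimal sketch of what the renamed interface looks like on a concrete net, mirroring the pattern used by the example nets in this PR. TinyNet, its layer sizes, and the checkpoint path are illustrative assumptions, not code from this change:

# Illustrative only: a toy net following the renamed get_state()/set_state() contract,
# bundling network weights and optimizer state as the example algorithms above do.
import torch
from torch.optim import Adam


class TinyNet(torch.nn.Module):  # stand-in for a concrete AbsNet subclass
    def __init__(self) -> None:
        super().__init__()
        self._fc = torch.nn.Linear(4, 2)
        self._optim = Adam(self.parameters(), lr=1e-3)

    def get_state(self) -> dict:
        return {"network": self.state_dict(), "optim": self._optim.state_dict()}

    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])


net = TinyNet()
torch.save(net.get_state(), "tiny_net.ckpt")  # save network + optimizer together
net.set_state(torch.load("tiny_net.ckpt"))    # restore both in one call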
4 changes: 2 additions & 2 deletions maro/rl/policy/continuous_rl_policy.py
@@ -100,10 +100,10 @@ def train(self) -> None:
         self._policy_net.train()

     def get_state(self) -> object:
-        return self._policy_net.get_net_state()
+        return self._policy_net.get_state()

     def set_state(self, policy_state: object) -> None:
-        self._policy_net.set_net_state(policy_state)
+        self._policy_net.set_state(policy_state)

     def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
         assert isinstance(other_policy, ContinuousRLPolicy)
8 changes: 4 additions & 4 deletions maro/rl/policy/discrete_rl_policy.py
@@ -132,10 +132,10 @@ def train(self) -> None:
         self._q_net.train()

     def get_state(self) -> object:
-        return self._q_net.get_net_state()
+        return self._q_net.get_state()

     def set_state(self, policy_state: object) -> None:
-        self._q_net.set_net_state(policy_state)
+        self._q_net.set_state(policy_state)

     def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
         assert isinstance(other_policy, ValueBasedPolicy)
@@ -194,10 +194,10 @@ def train(self) -> None:
         self._policy_net.train()

     def get_state(self) -> object:
-        return self._policy_net.get_net_state()
+        return self._policy_net.get_state()

     def set_state(self, policy_state: object) -> None:
-        self._policy_net.set_net_state(policy_state)
+        self._policy_net.set_state(policy_state)

     def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
         assert isinstance(other_policy, DiscretePolicyGradient)
12 changes: 7 additions & 5 deletions maro/rl/rollout/batch_env_sampler.py
@@ -88,11 +88,17 @@ def __init__(

         self._ep = 0
         self._segment = 0
+        self._end_of_episode = True

     def sample(
         self, policy_state: Dict[str, object] = None, num_steps: int = -1
     ) -> dict:
-        self._logger.info(f"Collecting simulation data (episode {self._ep}, segment {self._segment})")
+        if self._end_of_episode:
+            self._ep += 1
+            self._segment = 1
+        else:
+            self._segment += 1
+        self._logger.info(f"Collecting roll-out data for episode {self._ep}, segment {self._segment}")
         self._client.connect()
         req = {
             "type": "sample", "policy_state": policy_state, "num_steps": num_steps, "parallelism": self._parallelism
@@ -104,10 +110,6 @@ def sample(
         )
         self._client.close()
         self._end_of_episode = any(res["end_of_episode"] for res in results)
-        if self._end_of_episode:
-            self._ep += 1
-            self._segment = 0
-
         merged_experiences = list(chain(*[res["experiences"] for res in results])) # List[List[ExpElement]]
         return {
             "end_of_episode": self._end_of_episode,
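For clarity, a small standalone trace of the new episode/segment bookkeeping (an illustration, not code from the diff): the counters now advance at the top of sample(), so the first call is reported as episode 1, segment 1, and the episode index only moves on after a segment in which some worker reported end_of_episode.

# Illustration of the counter logic moved to the start of sample().
end_of_episode, ep, segment = True, 0, 0
for ended in [False, False, True, False, True]:  # per-segment end_of_episode results
    if end_of_episode:
        ep, segment = ep + 1, 1
    else:
        segment += 1
    end_of_episode = ended
    print(f"episode {ep}, segment {segment}")
# Prints: 1/1, 1/2, 1/3, 2/1, 2/2 (episode/segment)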
6 changes: 4 additions & 2 deletions maro/rl/rollout/worker.py
@@ -5,6 +5,7 @@

 from maro.rl.distributed import AbsWorker
 from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
+from maro.utils import Logger

 from .env_sampler import AbsEnvSampler

@@ -15,10 +16,11 @@ def __init__(
         idx: int,
         env_sampler_creator: Callable[[], AbsEnvSampler],
         router_host: str,
-        router_port: int = 10001
+        router_port: int = 10001,
+        logger: Logger = None
     ) -> None:
         super(RolloutWorker, self).__init__(
-            idx=idx, router_host=router_host, router_port=router_port
+            idx=idx, router_host=router_host, router_port=router_port, logger=logger
         )
         self._env_sampler = env_sampler_creator()

28 changes: 14 additions & 14 deletions maro/rl/training/algorithms/ac.py
@@ -55,6 +55,7 @@ def extract_ops_params(self) -> Dict[str, object]:
 class DiscreteActorCriticOps(AbsTrainOps):
     def __init__(
         self,
+        name: str,
         device: str,
         get_policy_func: Callable[[], DiscretePolicyGradient],
         get_v_critic_net_func: Callable[[], VNet],
@@ -66,6 +67,7 @@ def __init__(
         min_logp: float = None
     ) -> None:
         super(DiscreteActorCriticOps, self).__init__(
+            name=name,
             device=device,
             is_single_scenario=True,
             get_policy_func=get_policy_func
@@ -140,19 +142,15 @@ def update_actor_with_grad(self, grad_dict: dict) -> None:
         self._policy.train()
         self._policy.apply_gradients(grad_dict)

-    def get_state(self, scope: str = "all") -> dict:
-        ret_dict = {}
-        if scope in ("all", "actor"):
-            ret_dict["policy_state"] = self._policy.get_state()
-        if scope in ("all", "critic"):
-            ret_dict["critic_state"] = self._v_critic_net.get_net_state()
-        return ret_dict
+    def get_state(self) -> dict:
+        return {
+            "policy": self._policy.get_state(),
+            "critic": self._v_critic_net.get_state()
+        }

-    def set_state(self, ops_state_dict: dict, scope: str = "all") -> None:
-        if scope in ("all", "actor"):
-            self._policy.set_state(ops_state_dict["policy_state"])
-        if scope in ("all", "critic"):
-            self._v_critic_net.set_net_state(ops_state_dict["critic_state"])
+    def set_state(self, ops_state_dict: dict) -> None:
+        self._policy.set_state(ops_state_dict["policy"])
+        self._v_critic_net.set_state(ops_state_dict["critic"])

     def _preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
         assert self._is_valid_transition_batch(batch)
@@ -212,8 +210,10 @@ def record(self, env_idx: int, exp_element: ExpElement) -> None:
         )
         memory.put(transition_batch)

-    def get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
-        return DiscreteActorCriticOps(get_policy_func=self._get_policy_func, **self._params.extract_ops_params())
+    def get_local_ops_by_name(self, name: str) -> AbsTrainOps:
+        return DiscreteActorCriticOps(
+            name=name, get_policy_func=self._get_policy_func, **self._params.extract_ops_params()
+        )

     def _get_batch(self) -> TransitionBatch:
         batch_list = [memory.sample(-1) for memory in self._replay_memory_dict.values()]
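To show how the simplified ops interface can be consumed, here is a minimal save/load sketch; the helper names and checkpoint path are assumptions for illustration, not part of this PR:

# Hypothetical checkpoint round-trip over the flattened ops state
# (get_state()/set_state() no longer take a "scope" argument).
import torch

def save_ops(ops, path: str) -> None:
    torch.save(ops.get_state(), path)  # e.g. {"policy": ..., "critic": ...}

def load_ops(ops, path: str) -> None:
    ops.set_state(torch.load(path))

# Usage (names assumed): save_ops(ac_ops, "ac_ops.ckpt"); load_ops(ac_ops, "ac_ops.ckpt")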
36 changes: 17 additions & 19 deletions maro/rl/training/algorithms/ddpg.py
@@ -59,6 +59,7 @@ class DDPGOps(AbsTrainOps):

     def __init__(
         self,
+        name: str,
         device: str,
         get_policy_func: Callable[[], ContinuousRLPolicy],
         get_q_critic_net_func: Callable[[], QNet],
@@ -68,6 +69,7 @@ def __init__(
         soft_update_coef: float = 1.0
     ) -> None:
         super(DDPGOps, self).__init__(
+            name=name,
             device=device,
             is_single_scenario=True,
             get_policy_func=get_policy_func
@@ -145,23 +147,19 @@ def update_actor(self, batch: TransitionBatch) -> None:
         self._policy.train()
         self._policy.step(self._get_actor_loss(batch))

-    def get_state(self, scope: str = "all") -> dict:
-        ret_dict = {}
-        if scope in ("all", "actor"):
-            ret_dict["policy_state"] = self._policy.get_state()
-            ret_dict["target_policy_state"] = self._target_policy.get_state()
-        if scope in ("all", "critic"):
-            ret_dict["critic_state"] = self._q_critic_net.get_net_state()
-            ret_dict["target_critic_state"] = self._target_q_critic_net.get_net_state()
-        return ret_dict
-
-    def set_state(self, ops_state_dict: dict, scope: str = "all") -> None:
-        if scope in ("all", "actor"):
-            self._policy.set_state(ops_state_dict["policy_state"])
-            self._target_policy.set_state(ops_state_dict["target_policy_state"])
-        if scope in ("all", "critic"):
-            self._q_critic_net.set_net_state(ops_state_dict["critic_state"])
-            self._target_q_critic_net.set_net_state(ops_state_dict["target_critic_state"])
+    def get_state(self) -> dict:
+        return {
+            "policy": self._policy.get_state(),
+            "target_policy": self._target_policy.get_state(),
+            "critic": self._q_critic_net.get_state(),
+            "target_critic": self._target_q_critic_net.get_state()
+        }
+
+    def set_state(self, ops_state_dict: dict) -> None:
+        self._policy.set_state(ops_state_dict["policy"])
+        self._target_policy.set_state(ops_state_dict["target_policy"])
+        self._q_critic_net.set_state(ops_state_dict["critic"])
+        self._target_q_critic_net.set_state(ops_state_dict["target_critic"])

     def soft_update_target(self) -> None:
         self._target_policy.soft_update(self._policy, self._soft_update_coef)
@@ -207,8 +205,8 @@ def record(self, env_idx: int, exp_element: ExpElement) -> None:
         )
         self._replay_memory.put(transition_batch)

-    def get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
-        return DDPGOps(get_policy_func=self._get_policy_func, **self._params.extract_ops_params())
+    def get_local_ops_by_name(self, name: str) -> AbsTrainOps:
+        return DDPGOps(name=name, get_policy_func=self._get_policy_func, **self._params.extract_ops_params())

     def _get_batch(self, batch_size: int = None) -> TransitionBatch:
         return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)