Rl v3 load save (#463)
* added load/save feature

* fixed some bugs

* reverted unwanted changes

* lint

* fixed PR comments

Co-authored-by: ysqyang <v-yangqi@microsoft.com>
Co-authored-by: yaqiu <v-yaqiu@microsoft.com>
3 people committed Jan 27, 2022
1 parent 4faa8f1 commit 680bb52
Showing 24 changed files with 288 additions and 152 deletions.
8 changes: 4 additions & 4 deletions examples/rl/cim/algorithms/ac.py
@@ -60,13 +60,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])

@@ -95,13 +95,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])
4 changes: 2 additions & 2 deletions examples/rl/cim/algorithms/dqn.py
@@ -47,10 +47,10 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
        return {"network": self.state_dict(), "optim": self._optim.state_dict()}

-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
        assert isinstance(net_state, dict)
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])
8 changes: 4 additions & 4 deletions examples/rl/cim/algorithms/maddpg.py
@@ -62,13 +62,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])

@@ -97,13 +97,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])
8 changes: 4 additions & 4 deletions examples/rl/vm_scheduling/algorithms/ac.py
@@ -64,13 +64,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])

@@ -102,13 +102,13 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> dict:
+    def get_state(self) -> dict:
        return {
            "network": self.state_dict(),
            "optim": self._optim.state_dict()
        }

-    def set_net_state(self, net_state: dict) -> None:
+    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])
4 changes: 2 additions & 2 deletions examples/rl/vm_scheduling/algorithms/dqn.py
@@ -55,10 +55,10 @@ def apply_gradients(self, grad: dict) -> None:
            param.grad = grad[name]
        self._optim.step()

-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
        return {"network": self.state_dict(), "optim": self._optim.state_dict()}

-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
        assert isinstance(net_state, dict)
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])
2 changes: 1 addition & 1 deletion examples/rl/vm_scheduling/config.py
@@ -41,4 +41,4 @@

test_seed = 1024

-algorithm = "dqn" # "dqn" or "ac"
+algorithm = "ac" # "dqn" or "ac"
2 changes: 0 additions & 2 deletions maro/cli/local/utils.py
@@ -158,8 +158,6 @@ def get_docker_compose_yml(config: dict, context: str, dockerfile_path: str, ima
        }
        for component, env in config_parser.get_rl_component_env_vars(config, containerized=True).items()
    }
-    # if config["mode"] != "single":
-    #     manifest["services"]["redis"] = {"image": "redis", "container_name": redis_host}

    return manifest
7 changes: 4 additions & 3 deletions maro/rl/distributed/abs_worker.py
@@ -10,20 +10,21 @@
from zmq.eventloop.zmqstream import ZMQStream

from maro.rl.utils.common import string_to_bytes
-from maro.utils import Logger
+from maro.utils import DummyLogger, Logger


class AbsWorker(object):
    def __init__(
        self,
        idx: int,
        router_host: str,
-        router_port: int = 10001
+        router_port: int = 10001,
+        logger: Logger = None
    ) -> None:
        super(AbsWorker, self).__init__()

        self._id = f"worker.{idx}"
-        self._logger = Logger(self._id)
+        self._logger = DummyLogger() if logger is None else logger

        # ZMQ sockets and streams
        self._context = Context.instance()
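The worker now takes an optional logger and falls back to a silent logger when none is supplied, instead of constructing a new Logger per worker. A minimal sketch of that fallback pattern outside MARO (NullLogger here is a hypothetical stand-in, not MARO's actual DummyLogger):

class NullLogger:
    """Hypothetical no-op logger used when the caller injects nothing."""

    def info(self, msg: str) -> None:
        pass

    def warn(self, msg: str) -> None:
        pass


class Worker:
    def __init__(self, idx: int, logger=None) -> None:
        self._id = f"worker.{idx}"
        # Fall back to the silent logger rather than creating one per worker.
        self._logger = NullLogger() if logger is None else logger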
4 changes: 2 additions & 2 deletions maro/rl/model/abs_net.py
@@ -48,14 +48,14 @@ def _forward_unimplemented(self, *input: Any) -> None: # TODO
        pass

    @abstractmethod
-    def get_net_state(self) -> object:
+    def get_state(self) -> object:
        """
        Get the net's state.
        """
        raise NotImplementedError

    @abstractmethod
-    def set_net_state(self, net_state: object) -> None:
+    def set_state(self, net_state: object) -> None:
        """
        Set the net's state.
        """
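With the hooks renamed to get_state/set_state, a concrete net packs its weights and optimizer state into one dict, so a checkpoint can round-trip through torch.save/torch.load. A minimal sketch on a plain torch module, following the same pattern as the example nets above (layer sizes and the optimizer choice are arbitrary):

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._fc = nn.Linear(4, 2)
        self._optim = torch.optim.Adam(self.parameters(), lr=1e-3)

    def get_state(self) -> dict:
        # Weights and optimizer state travel together, as in the diffs above.
        return {"network": self.state_dict(), "optim": self._optim.state_dict()}

    def set_state(self, net_state: dict) -> None:
        self.load_state_dict(net_state["network"])
        self._optim.load_state_dict(net_state["optim"])


# Round-trip through a checkpoint file.
net = TinyNet()
torch.save(net.get_state(), "tiny_net.ckpt")
net.set_state(torch.load("tiny_net.ckpt"))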
4 changes: 2 additions & 2 deletions maro/rl/policy/continuous_rl_policy.py
@@ -100,10 +100,10 @@ def train(self) -> None:
        self._policy_net.train()

    def get_state(self) -> object:
-        return self._policy_net.get_net_state()
+        return self._policy_net.get_state()

    def set_state(self, policy_state: object) -> None:
-        self._policy_net.set_net_state(policy_state)
+        self._policy_net.set_state(policy_state)

    def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
        assert isinstance(other_policy, ContinuousRLPolicy)
8 changes: 4 additions & 4 deletions maro/rl/policy/discrete_rl_policy.py
@@ -132,10 +132,10 @@ def train(self) -> None:
        self._q_net.train()

    def get_state(self) -> object:
-        return self._q_net.get_net_state()
+        return self._q_net.get_state()

    def set_state(self, policy_state: object) -> None:
-        self._q_net.set_net_state(policy_state)
+        self._q_net.set_state(policy_state)

    def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
        assert isinstance(other_policy, ValueBasedPolicy)

@@ -194,10 +194,10 @@ def train(self) -> None:
        self._policy_net.train()

    def get_state(self) -> object:
-        return self._policy_net.get_net_state()
+        return self._policy_net.get_state()

    def set_state(self, policy_state: object) -> None:
-        self._policy_net.set_net_state(policy_state)
+        self._policy_net.set_state(policy_state)

    def soft_update(self, other_policy: RLPolicy, tau: float) -> None:
        assert isinstance(other_policy, DiscretePolicyGradient)
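Since the policies just forward get_state/set_state to their underlying nets, a policy checkpoint is exactly the net's {"network", "optim"} dict. A hedged sketch of a pair of save/load helpers built on that delegation (the helper names are illustrative, not part of MARO's API):

import torch


def save_policy(policy, path: str) -> None:
    # policy.get_state() delegates to the net's get_state(), which returns
    # the combined {"network": ..., "optim": ...} dict shown in the diffs above.
    torch.save(policy.get_state(), path)


def load_policy(policy, path: str) -> None:
    policy.set_state(torch.load(path))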
12 changes: 7 additions & 5 deletions maro/rl/rollout/batch_env_sampler.py
@@ -88,11 +88,17 @@ def __init__(

        self._ep = 0
        self._segment = 0
+        self._end_of_episode = True

    def sample(
        self, policy_state: Dict[str, object] = None, num_steps: int = -1
    ) -> dict:
-        self._logger.info(f"Collecting simulation data (episode {self._ep}, segment {self._segment})")
+        if self._end_of_episode:
+            self._ep += 1
+            self._segment = 1
+        else:
+            self._segment += 1
+        self._logger.info(f"Collecting roll-out data for episode {self._ep}, segment {self._segment}")
        self._client.connect()
        req = {
            "type": "sample", "policy_state": policy_state, "num_steps": num_steps, "parallelism": self._parallelism

@@ -104,10 +110,6 @@ def sample(
        )
        self._client.close()
        self._end_of_episode = any(res["end_of_episode"] for res in results)
-        if self._end_of_episode:
-            self._ep += 1
-            self._segment = 0
-
        merged_experiences = list(chain(*[res["experiences"] for res in results])) # List[List[ExpElement]]
        return {
            "end_of_episode": self._end_of_episode,
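Advancing the counters at the top of sample(), driven by the end_of_episode flag from the previous call, makes every episode start at segment 1 and keeps the segment counter growing until the episode ends. A stripped-down sketch of just that counting behaviour, independent of MARO's classes:

class SegmentCounter:
    """Illustrative reimplementation of the episode/segment bookkeeping."""

    def __init__(self) -> None:
        self.ep = 0
        self.segment = 0
        self.end_of_episode = True  # forces episode 1, segment 1 on the first call

    def on_sample(self, end_of_episode_after: bool) -> tuple:
        if self.end_of_episode:
            self.ep += 1
            self.segment = 1
        else:
            self.segment += 1
        self.end_of_episode = end_of_episode_after  # reported by the roll-out results
        return self.ep, self.segment


counter = SegmentCounter()
assert counter.on_sample(False) == (1, 1)
assert counter.on_sample(False) == (1, 2)
assert counter.on_sample(True) == (1, 3)
assert counter.on_sample(False) == (2, 1)  # new episode restarts at segment 1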
6 changes: 4 additions & 2 deletions maro/rl/rollout/worker.py
@@ -5,6 +5,7 @@

from maro.rl.distributed import AbsWorker
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
+from maro.utils import Logger

from .env_sampler import AbsEnvSampler

@@ -15,10 +16,11 @@ def __init__(
        idx: int,
        env_sampler_creator: Callable[[], AbsEnvSampler],
        router_host: str,
-        router_port: int = 10001
+        router_port: int = 10001,
+        logger: Logger = None
    ) -> None:
        super(RolloutWorker, self).__init__(
-            idx=idx, router_host=router_host, router_port=router_port
+            idx=idx, router_host=router_host, router_port=router_port, logger=logger
        )
        self._env_sampler = env_sampler_creator()
28 changes: 14 additions & 14 deletions maro/rl/training/algorithms/ac.py
@@ -55,6 +55,7 @@ def extract_ops_params(self) -> Dict[str, object]:
class DiscreteActorCriticOps(AbsTrainOps):
    def __init__(
        self,
+        name: str,
        device: str,
        get_policy_func: Callable[[], DiscretePolicyGradient],
        get_v_critic_net_func: Callable[[], VNet],

@@ -66,6 +67,7 @@ def __init__(
        min_logp: float = None
    ) -> None:
        super(DiscreteActorCriticOps, self).__init__(
+            name=name,
            device=device,
            is_single_scenario=True,
            get_policy_func=get_policy_func

@@ -140,19 +142,15 @@ def update_actor_with_grad(self, grad_dict: dict) -> None:
        self._policy.train()
        self._policy.apply_gradients(grad_dict)

-    def get_state(self, scope: str = "all") -> dict:
-        ret_dict = {}
-        if scope in ("all", "actor"):
-            ret_dict["policy_state"] = self._policy.get_state()
-        if scope in ("all", "critic"):
-            ret_dict["critic_state"] = self._v_critic_net.get_net_state()
-        return ret_dict
+    def get_state(self) -> dict:
+        return {
+            "policy": self._policy.get_state(),
+            "critic": self._v_critic_net.get_state()
+        }

-    def set_state(self, ops_state_dict: dict, scope: str = "all") -> None:
-        if scope in ("all", "actor"):
-            self._policy.set_state(ops_state_dict["policy_state"])
-        if scope in ("all", "critic"):
-            self._v_critic_net.set_net_state(ops_state_dict["critic_state"])
+    def set_state(self, ops_state_dict: dict) -> None:
+        self._policy.set_state(ops_state_dict["policy"])
+        self._v_critic_net.set_state(ops_state_dict["critic"])

    def _preprocess_batch(self, batch: TransitionBatch) -> TransitionBatch:
        assert self._is_valid_transition_batch(batch)

@@ -212,8 +210,10 @@ def record(self, env_idx: int, exp_element: ExpElement) -> None:
        )
        memory.put(transition_batch)

-    def get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
-        return DiscreteActorCriticOps(get_policy_func=self._get_policy_func, **self._params.extract_ops_params())
+    def get_local_ops_by_name(self, name: str) -> AbsTrainOps:
+        return DiscreteActorCriticOps(
+            name=name, get_policy_func=self._get_policy_func, **self._params.extract_ops_params()
+        )

    def _get_batch(self) -> TransitionBatch:
        batch_list = [memory.sample(-1) for memory in self._replay_memory_dict.values()]
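With the scope parameter gone, get_state() on the ops always returns the full {"policy", "critic"} dict and set_state() restores both in one shot, which reduces checkpointing to a single torch.save/torch.load per ops instance. A hedged sketch of how such a checkpoint could be written and restored (the file layout, and the assumption that the ops object exposes its name, are illustrative rather than MARO's actual checkpoint format):

import os
import torch


def save_ops_checkpoint(ops, checkpoint_dir: str) -> None:
    # ops.get_state() now returns {"policy": ..., "critic": ...} unconditionally.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(ops.get_state(), os.path.join(checkpoint_dir, f"{ops.name}.ckpt"))


def load_ops_checkpoint(ops, checkpoint_dir: str) -> None:
    state = torch.load(os.path.join(checkpoint_dir, f"{ops.name}.ckpt"))
    ops.set_state(state)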
36 changes: 17 additions & 19 deletions maro/rl/training/algorithms/ddpg.py
@@ -59,6 +59,7 @@ class DDPGOps(AbsTrainOps):

    def __init__(
        self,
+        name: str,
        device: str,
        get_policy_func: Callable[[], ContinuousRLPolicy],
        get_q_critic_net_func: Callable[[], QNet],

@@ -68,6 +69,7 @@ def __init__(
        soft_update_coef: float = 1.0
    ) -> None:
        super(DDPGOps, self).__init__(
+            name=name,
            device=device,
            is_single_scenario=True,
            get_policy_func=get_policy_func

@@ -145,23 +147,19 @@ def update_actor(self, batch: TransitionBatch) -> None:
        self._policy.train()
        self._policy.step(self._get_actor_loss(batch))

-    def get_state(self, scope: str = "all") -> dict:
-        ret_dict = {}
-        if scope in ("all", "actor"):
-            ret_dict["policy_state"] = self._policy.get_state()
-            ret_dict["target_policy_state"] = self._target_policy.get_state()
-        if scope in ("all", "critic"):
-            ret_dict["critic_state"] = self._q_critic_net.get_net_state()
-            ret_dict["target_critic_state"] = self._target_q_critic_net.get_net_state()
-        return ret_dict
-
-    def set_state(self, ops_state_dict: dict, scope: str = "all") -> None:
-        if scope in ("all", "actor"):
-            self._policy.set_state(ops_state_dict["policy_state"])
-            self._target_policy.set_state(ops_state_dict["target_policy_state"])
-        if scope in ("all", "critic"):
-            self._q_critic_net.set_net_state(ops_state_dict["critic_state"])
-            self._target_q_critic_net.set_net_state(ops_state_dict["target_critic_state"])
+    def get_state(self) -> dict:
+        return {
+            "policy": self._policy.get_state(),
+            "target_policy": self._target_policy.get_state(),
+            "critic": self._q_critic_net.get_state(),
+            "target_critic": self._target_q_critic_net.get_state()
+        }
+
+    def set_state(self, ops_state_dict: dict) -> None:
+        self._policy.set_state(ops_state_dict["policy"])
+        self._target_policy.set_state(ops_state_dict["target_policy"])
+        self._q_critic_net.set_state(ops_state_dict["critic"])
+        self._target_q_critic_net.set_state(ops_state_dict["target_critic"])

    def soft_update_target(self) -> None:
        self._target_policy.soft_update(self._policy, self._soft_update_coef)

@@ -207,8 +205,8 @@ def record(self, env_idx: int, exp_element: ExpElement) -> None:
        )
        self._replay_memory.put(transition_batch)

-    def get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
-        return DDPGOps(get_policy_func=self._get_policy_func, **self._params.extract_ops_params())
+    def get_local_ops_by_name(self, name: str) -> AbsTrainOps:
+        return DDPGOps(name=name, get_policy_func=self._get_policy_func, **self._params.extract_ops_params())

    def _get_batch(self, batch_size: int = None) -> TransitionBatch:
        return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)
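DDPG keeps target copies of both the policy and the critic, which is why its state dict carries four entries and why soft_update_target appears in the context above. For reference, a generic sketch of the soft (Polyak) update that such a call typically performs on torch modules; this illustrates the technique rather than MARO's exact implementation:

import torch


@torch.no_grad()
def soft_update(target: torch.nn.Module, source: torch.nn.Module, tau: float) -> None:
    # target <- tau * source + (1 - tau) * target, parameter by parameter.
    for tgt_param, src_param in zip(target.parameters(), source.parameters()):
        tgt_param.data.copy_(tau * src_param.data + (1.0 - tau) * tgt_param.data)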

