Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine RL workflow & tune RL models under GYM #577

Merged
merged 49 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
6bc7f0b
PPO, SAC, DDPG passed
lihuoran Apr 12, 2022
b3f5aef
Explore in SAC
lihuoran Apr 12, 2022
5dab711
Test GYM on server
lihuoran Apr 22, 2022
211c06f
Sync server changes
lihuoran Jan 17, 2023
f92f7f1
Merge branch 'v0.3' into rl_benchmark_debug
lihuoran Jan 17, 2023
514250a
pre-commit
lihuoran Jan 17, 2023
fc0c02d
Ready to try on server
lihuoran Jan 17, 2023
9fcdf42
.
lihuoran Jan 17, 2023
01b5a94
.
lihuoran Jan 17, 2023
dd27eed
.
lihuoran Jan 17, 2023
1c8f258
.
lihuoran Jan 17, 2023
1aa1085
.
lihuoran Jan 17, 2023
148af38
Performance OK
lihuoran Jan 18, 2023
99ff7b9
Move to tests
lihuoran Jan 18, 2023
65ba1a1
Remove old versions
lihuoran Jan 18, 2023
f4a85b8
PPO done
lihuoran Jan 18, 2023
2349191
Start to test AC
lihuoran Jan 18, 2023
f6f7dae
Start to test SAC
lihuoran Jan 18, 2023
110fec4
SAC test passed
lihuoran Jan 28, 2023
2a1ccd5
Multiple round in evaluation
lihuoran Jan 28, 2023
c371220
Modify config.yml
lihuoran Jan 28, 2023
a65d902
Add Callbacks
lihuoran Jan 28, 2023
aa484f8
[wip] SAC performance not good
lihuoran Jan 30, 2023
84ec6e6
[wip] still not good
lihuoran Jan 30, 2023
0ceaac4
update for some PR comments; Add a MARKDOWN file (#576)
Jinyu-W Jan 31, 2023
aad41d9
Use FullyConnected to replace mlp
lihuoran Jan 31, 2023
8884231
Update action bound
lihuoran Jan 31, 2023
0a01fb1
Merge branch 'rl_benchmark_debug' into rl_workflow_refine
lihuoran Jan 31, 2023
0bd25ca
???
lihuoran Jan 31, 2023
8781dd6
Change gym env wrapper metrics logci
lihuoran Jan 31, 2023
7b9b698
Change gym env wrapper metrics logci
lihuoran Jan 31, 2023
52b4d1d
refine env_sampler.sample under step mode
lihuoran Feb 1, 2023
a3fea0d
Add DDPG. Performance not good...
lihuoran Feb 1, 2023
23f39d1
Add DDPG. Performance not good...
lihuoran Feb 1, 2023
9da8b90
wip
lihuoran Feb 1, 2023
fb11c31
Sounds like sac works
lihuoran Feb 1, 2023
d7d3282
Refactor file structure
lihuoran Feb 1, 2023
ea26275
Refactor file structure
lihuoran Feb 1, 2023
8881a1c
Refactor file structure
lihuoran Feb 1, 2023
b4db842
Pre-commit
lihuoran Feb 6, 2023
8874a65
Merge branch 'rl_benchmark_debug' into rl_workflow_refine
lihuoran Feb 6, 2023
2a7334b
Merge branch 'v0.3' into rl_workflow_refine
lihuoran Feb 6, 2023
eb7ae9b
Pre commit
lihuoran Feb 6, 2023
627b7d1
Minor refinement of CIM RL
lihuoran Feb 8, 2023
8386312
Jinyu/rl workflow refine (#578)
Jinyu-W Feb 8, 2023
b05c849
Resolve PR comments
lihuoran Feb 9, 2023
ab5e675
Compare PPO with spinning up (#579)
lihuoran Feb 9, 2023
e180f10
SAC Test parameters update (#580)
Jinyu-W Feb 13, 2023
9371949
Episode truncation & early stopping (#581)
lihuoran Feb 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@ htmlcov/
.coveragerc
.tmp/
.xmake/
outputs/
tests/rl_log/
22 changes: 16 additions & 6 deletions examples/cim/rl/env_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,21 @@ def post_collect(self, info_list: list, ep: int) -> None:
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")

# print the average env metric
if len(info_list) > 1:
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")

self.metrics.update(avg_metric)

def post_evaluate(self, info_list: list, ep: int) -> None:
self.post_collect(info_list, ep)
# print the env metric from each rollout worker
for info in info_list:
print(f"env summary (episode {ep}): {info['env_metric']}")

# average env metric
metric_keys, num_envs = info_list[0]["env_metric"].keys(), len(info_list)
avg_metric = {key: sum(info["env_metric"][key] for info in info_list) / num_envs for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")

self.metrics.update({"val/" + k: v for k, v in avg_metric.items()})
4 changes: 2 additions & 2 deletions examples/rl/cim.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

job: cim_rl_workflow
scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
log_path: "outputs/cim_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
Expand All @@ -27,7 +27,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
path: "outputs/cim_rl/checkpoints"
interval: 5
logging:
stdout: INFO
Expand Down
4 changes: 2 additions & 2 deletions examples/rl/cim_distributed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

job: cim_rl_workflow
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: runtime error

scenario_path: "examples/cim/rl"
log_path: "log/rl_job/cim.txt"
log_path: "outputs/cim_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
Expand All @@ -35,7 +35,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/cim"
path: "outputs/cim_rl/checkpoints"
interval: 5
proxy:
host: "127.0.0.1"
Expand Down
4 changes: 2 additions & 2 deletions examples/rl/vm_scheduling.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

job: vm_scheduling_rl_workflow
scenario_path: "examples/vm_scheduling/rl"
log_path: "log/rl_job/vm_scheduling.txt"
log_path: "outputs/vm_rl/"
main:
num_episodes: 30 # Number of episodes to run. Each episode is one cycle of roll-out and training.
num_steps: null
Expand All @@ -27,7 +27,7 @@ training:
load_path: null
load_episode: null
checkpointing:
path: "checkpoint/rl_job/vm_scheduling"
path: "outputs/vm_rl/checkpoints"
interval: 5
logging:
stdout: INFO
Expand Down
9 changes: 7 additions & 2 deletions maro/rl/rollout/batch_env_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,13 @@ def sample(
"info": [res["info"][0] for res in results],
}

def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
req = {"type": "eval", "policy_state": policy_state, "index": self._ep} # -1 signals test
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
req = {
"type": "eval",
"policy_state": policy_state,
"index": self._ep,
"num_eval_episodes": num_episodes,
} # -1 signals test
results = self._controller.collect(req, self._eval_parallelism)
return {
"info": [res["info"][0] for res in results],
Expand Down
206 changes: 110 additions & 96 deletions maro/rl/rollout/env_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def __init__(
agent_wrapper_cls: Type[AbsAgentWrapper] = SimpleAgentWrapper,
reward_eval_delay: int = None,
) -> None:
assert learn_env is not test_env, "Please use different envs for training and testing."

self._learn_env = learn_env
self._test_env = test_env

Expand All @@ -267,6 +269,7 @@ def __init__(
self._reward_eval_delay = reward_eval_delay

self._info: dict = {}
self.metrics: dict = {}

assert self._reward_eval_delay is None or self._reward_eval_delay >= 0

Expand Down Expand Up @@ -430,65 +433,71 @@ def sample(
Returns:
A dict that contains the collected experiences and additional information.
"""
# Init the env
self._switch_env(self._learn_env)
if self._end_of_episode:
self._reset()

# Update policy state if necessary
if policy_state is not None:
steps_to_go = num_steps
if policy_state is not None: # Update policy state if necessary
self.set_policy_state(policy_state)
self._switch_env(self._learn_env) # Init the env
self._agent_wrapper.explore() # Collect experience

# Collect experience
self._agent_wrapper.explore()
steps_to_go = float("inf") if num_steps is None else num_steps
while not self._end_of_episode and steps_to_go > 0:
# Get agent actions and translate them to env actions
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)

# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)
if self._end_of_episode:
self._reset()

# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))

if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_step(cache_element)
self._append_cache_element(cache_element)
steps_to_go -= 1
self._append_cache_element(None)

tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
experiences: List[ExpElement] = []
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
# !: Here the reward calculation method requires the given tick is enough and must be used then.
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_step(cache_element)
experiences.append(cache_element.make_exp_element())

self._agent_last_index = {
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
}
total_experiences = []
# If steps_to_go is None, run until the end of episode
# If steps_to_go is not None, run until we collect required number of steps
while (steps_to_go is None and not self._end_of_episode) or (steps_to_go is not None and steps_to_go > 0):
if self._end_of_episode:
self._reset()

while not self._end_of_episode and (steps_to_go is None or steps_to_go > 0):
# Get agent actions and translate them to env actions
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)

# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)

# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))

if self._reward_eval_delay is None:
self._calc_reward(cache_element)
self._post_step(cache_element)
self._append_cache_element(cache_element)
if steps_to_go is not None:
steps_to_go -= 1
self._append_cache_element(None)
lihuoran marked this conversation as resolved.
Show resolved Hide resolved

tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
experiences: List[ExpElement] = []
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
lihuoran marked this conversation as resolved.
Show resolved Hide resolved
cache_element = self._trans_cache.pop(0)
# !: Here the reward calculation method requires the given tick is enough and must be used then.
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_step(cache_element)
experiences.append(cache_element.make_exp_element())

self._agent_last_index = {
k: v - len(experiences) for k, v in self._agent_last_index.items() if v >= len(experiences)
}

total_experiences += experiences

return {
"end_of_episode": self._end_of_episode,
"experiences": [experiences],
"experiences": [total_experiences],
"info": [deepcopy(self._info)], # TODO: may have overhead issues. Leave to future work.
}

Expand All @@ -514,50 +523,55 @@ def load_policy_state(self, path: str) -> List[str]:

return loaded

def eval(self, policy_state: Dict[str, Dict[str, Any]] = None) -> dict:
def eval(self, policy_state: Dict[str, Dict[str, Any]] = None, num_episodes: int = 1) -> dict:
self._switch_env(self._test_env)
self._reset()
if policy_state is not None:
self.set_policy_state(policy_state)

self._agent_wrapper.exploit()
while not self._end_of_episode:
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)

# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)
info_list = []

# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))

if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

self._append_cache_element(cache_element)
self._append_cache_element(None)

tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

return {"info": [self._info]}
for _ in range(num_episodes):
self._reset()
if policy_state is not None:
self.set_policy_state(policy_state)

self._agent_wrapper.exploit()
while not self._end_of_episode:
action_dict = self._agent_wrapper.choose_actions(self._agent_state_dict)
env_action_dict = self._translate_to_env_action(action_dict, self._event)

# Store experiences in the cache
cache_element = CacheElement(
tick=self.env.tick,
event=self._event,
state=self._state,
agent_state_dict=self._select_trainable_agents(self._agent_state_dict),
action_dict=self._select_trainable_agents(action_dict),
env_action_dict=self._select_trainable_agents(env_action_dict),
# The following will be generated later
reward_dict={},
terminal_dict={},
next_state=None,
next_agent_state_dict={},
)

# Update env and get new states (global & agent)
self._step(list(env_action_dict.values()))

if self._reward_eval_delay is None: # TODO: necessary to calculate reward in eval()?
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

self._append_cache_element(cache_element)
self._append_cache_element(None)

tick_bound = self.env.tick - (0 if self._reward_eval_delay is None else self._reward_eval_delay)
while len(self._trans_cache) > 0 and self._trans_cache[0].tick <= tick_bound:
cache_element = self._trans_cache.pop(0)
if self._reward_eval_delay is not None:
self._calc_reward(cache_element)
self._post_eval_step(cache_element)

info_list.append(self._info)

return {"info": info_list}

@abstractmethod
def _post_step(self, cache_element: CacheElement) -> None:
Expand Down
2 changes: 1 addition & 1 deletion maro/rl/rollout/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _compute(self, msg: list) -> None:
result = (
self._env_sampler.sample(policy_state=req["policy_state"], num_steps=req["num_steps"])
if req["type"] == "sample"
else self._env_sampler.eval(policy_state=req["policy_state"])
else self._env_sampler.eval(policy_state=req["policy_state"], num_episodes=req["num_eval_episodes"])
)
self._stream.send(pyobj_to_bytes({"result": result, "index": req["index"]}))
else:
Expand Down
Loading