Commit
1. Call allocate_trainers() at the start of update(); 2. remove sending policy state to trainers after update(), which is superseded by `allocate_trainers()`; 3. refine according to code review.
buptchan committed Aug 3, 2021
1 parent e4fc71b commit 2867583
Showing 3 changed files with 22 additions and 41 deletions.
6 changes: 6 additions & 0 deletions maro/rl/policy/algorithms/dqn.py
@@ -187,6 +187,12 @@ def get_loss(self):
         self._post_step(loss.detach().cpu().numpy(), self.tracker)
         return loss
 
+    def get_grad(self):
+        grad_dict = {}
+        for param_name, param in self.q_net.named_parameters():
+            grad_dict[param_name] = param.grad
+        return grad_dict
+
     def step(self, grad_dict):
         '''Backward step.'''
         # set gradient & optimize
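
For illustration only (not part of this commit): a minimal, self-contained sketch of the get_grad()/step() contract the new method enables. A trainer calls loss.backward() and get_grad() to extract gradients; the policy manager applies them via step(). ToyQNet and ToyPolicy below are hypothetical stand-ins; only the gradient-dict shape (parameter name to gradient tensor) mirrors the diff, and torch is assumed to be installed.

# Illustration only; hypothetical stand-ins, not MARO code.
import torch
import torch.nn as nn


class ToyQNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


class ToyPolicy:
    """Mimics the split this commit relies on: trainers call get_grad(), the manager calls step()."""

    def __init__(self):
        self.q_net = ToyQNet()
        self.optimizer = torch.optim.SGD(self.q_net.parameters(), lr=0.1)

    def get_grad(self):
        # Same shape as the new DQN.get_grad(): parameter name -> gradient tensor.
        return {name: param.grad for name, param in self.q_net.named_parameters()}

    def step(self, grad_dict):
        # Assign externally computed (possibly averaged) gradients, then optimize.
        for name, param in self.q_net.named_parameters():
            param.grad = grad_dict[name]
        self.optimizer.step()
        self.optimizer.zero_grad()


policy = ToyPolicy()
loss = policy.q_net(torch.randn(8, 4)).pow(2).mean()   # dummy loss
loss.backward()                  # trainer side: compute gradients
grads = policy.get_grad()        # trainer side: extract them for the reply message
policy.step(grads)               # manager side: apply them to its copy of the policy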
49 changes: 14 additions & 35 deletions maro/rl/policy/policy_manager.py
@@ -386,9 +386,7 @@ def __init__(
         self._num_experiences_by_policy = defaultdict(int)
         self.num_trainers = num_trainers
 
-        self._logger = Logger("MULTINODE_POLICY_MANAGER", dump_folder=log_dir)
-
-        self.allocate_trainers()
+        self._logger = Logger("MULTINODE_DIST_POLICY_MANAGER", dump_folder=log_dir)
 
     def allocate_strategy(self, num_trainers, num_experiences_by_policy, logger=None):
         policy2trainer = defaultdict(list)
@@ -410,8 +408,8 @@ def allocate_strategy(self, num_trainers, num_experiences_by_policy, logger=None
 
             # allocate trainers according to historical experience numbers.
            else:
-                total_num_policy = sum(num_experiences_by_policy.values())
-                average_payload = total_num_policy / num_trainers
+                total_num_experiences = sum(num_experiences_by_policy.values())
+                average_payload = total_num_experiences / num_trainers
 
                offset = 0
                policy_quota = dict()
@@ -458,29 +456,24 @@ def update(self, exp_by_policy: Dict[str, ExperienceSet]):
                self._num_experiences_by_policy[policy_name] += exp.size
                self._exp_cache[policy_name].extend(exp)
                if (
-                    self._exp_cache[policy_name].size >= self.update_trigger[policy_name] and
-                    self._num_experiences_by_policy[policy_name] >= self.warmup[policy_name]
+                    self._exp_cache[policy_name].size >= self.update_trigger[policy_name]
+                    and self._num_experiences_by_policy[policy_name] >= self.warmup[policy_name]
                ):
                    exp_to_send[policy_name] = self._exp_cache.pop(policy_name)
                    updated.add(policy_name)
 
+        self.allocate_trainers()
+
        # 1. prepare exp data for each trainer node
        msg_body_by_dest = defaultdict(dict)
        for policy_name, exp in exp_to_send.items():
            trainer_id_list = self._policy2trainer[policy_name]
-            if len(trainer_id_list) == 1:  # single node
-                trainer_id = trainer_id_list[0]
+            for i, trainer_id in enumerate(trainer_id_list):
                if MsgKey.EXPERIENCES not in msg_body_by_dest[trainer_id]:
                    msg_body_by_dest[trainer_id][MsgKey.EXPERIENCES] = {}
-                msg_body_by_dest[trainer_id][MsgKey.EXPERIENCES][policy_name] = exp
-                self._logger.info(f'policy {policy_name}, exp.size = {exp.size}')
-            else:
-                for i, trainer_id in enumerate(trainer_id_list):
-                    if MsgKey.EXPERIENCES not in msg_body_by_dest[trainer_id]:
-                        msg_body_by_dest[trainer_id][MsgKey.EXPERIENCES] = {}
-                    sub_exp = exp[i::len(trainer_id_list)]
-                    msg_body_by_dest[trainer_id][MsgKey.EXPERIENCES][policy_name] = sub_exp
-                    self._logger.info(f'policy {policy_name}, sub_exp.size = {sub_exp.size}')
+                sub_exp = exp[i::len(trainer_id_list)]
+                msg_body_by_dest[trainer_id][MsgKey.EXPERIENCES][policy_name] = sub_exp
+                self._logger.info(f'policy {policy_name}, exp.size = {sub_exp.size}')
 
        # 2. scatter data and receive reply of each trainer node
        trackers = []
@@ -499,12 +492,9 @@ def update(self, exp_by_policy: Dict[str, ExperienceSet]):
            # 3. aggregate gradient
            for policy_name, grad_dict in reply.body[MsgKey.GRAD].items():
                trainer_id_list = self._policy2trainer[policy_name]
-                if len(trainer_id_list) == 1:  # single node
-                    manager_grad_dict[policy_name] = grad_dict
-                else:
-                    for param_name in grad_dict:
-                        manager_grad_dict[policy_name][param_name] = manager_grad_dict[policy_name].get(
-                            param_name, 0) + grad_dict[param_name] / len(trainer_id_list)
+                for param_name in grad_dict:
+                    manager_grad_dict[policy_name][param_name] = manager_grad_dict[policy_name].get(
+                        param_name, 0) + grad_dict[param_name] / len(trainer_id_list)
 
        # 4. apply gradient
        for policy_name in manager_grad_dict:
@@ -515,24 +505,13 @@ def update(self, exp_by_policy: Dict[str, ExperienceSet]):
 
            self.policy_dict[policy_name].step(manager_grad_dict[policy_name])
 
-        for trainer_name, policy_names in self._trainer2policies.items():
-            self._proxy.send(
-                SessionMessage(
-                    MsgTag.UPDATE_POLICY_STATE, self._proxy.name, trainer_name,
-                    body={MsgKey.POLICY_STATE: {name: self.policy_dict[name].get_state() for name in policy_names}}
-                )
-            )
-
        if updated:
            self._update_history.append(updated)
            self._logger.info(f"Updated policies {updated}")
 
        if self._post_update:
            self._post_update(trackers)
 
-        # re-allocate
-        self.allocate_trainers()
-
    def exit(self):
        """Tell the remote trainers to exit."""
        self._proxy.ibroadcast("trainer", MsgTag.EXIT, SessionType.NOTIFICATION)
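
For illustration only (not part of this commit): a small, runnable sketch of the two pieces update() now relies on, sharding one policy's experiences round-robin across its assigned trainers and averaging the gradient dicts they send back. The trainer names, the plain list standing in for an ExperienceSet, and the scalar gradients are all hypothetical simplifications.

# Illustration only; plain lists/dicts stand in for ExperienceSet and gradient tensors.
from collections import defaultdict

policy2trainer = {"dqn.0": ["trainer.0", "trainer.1", "trainer.2"]}
exp_to_send = {"dqn.0": list(range(10))}   # 10 experiences waiting to be trained on

# Step 1 in update(): shard each policy's experiences round-robin over its trainers,
# mirroring exp[i::len(trainer_id_list)] in the diff.
msg_body_by_dest = defaultdict(dict)
for policy_name, exp in exp_to_send.items():
    trainer_id_list = policy2trainer[policy_name]
    for i, trainer_id in enumerate(trainer_id_list):
        msg_body_by_dest[trainer_id][policy_name] = exp[i::len(trainer_id_list)]

print(dict(msg_body_by_dest))
# {'trainer.0': {'dqn.0': [0, 3, 6, 9]},
#  'trainer.1': {'dqn.0': [1, 4, 7]},
#  'trainer.2': {'dqn.0': [2, 5, 8]}}

# Step 3 in update(): each trainer replies with a gradient dict; average per parameter.
replies = [
    {"dqn.0": {"fc.weight": 0.3}},
    {"dqn.0": {"fc.weight": 0.6}},
    {"dqn.0": {"fc.weight": 0.9}},
]
manager_grad_dict = defaultdict(dict)
for reply in replies:
    for policy_name, grad_dict in reply.items():
        num_trainers = len(policy2trainer[policy_name])
        for param_name, grad in grad_dict.items():
            manager_grad_dict[policy_name][param_name] = (
                manager_grad_dict[policy_name].get(param_name, 0) + grad / num_trainers
            )

print(dict(manager_grad_dict))   # {'dqn.0': {'fc.weight': 0.6000000000000001}}, the mean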
8 changes: 2 additions & 6 deletions maro/rl/policy/trainer.py
@@ -80,7 +80,7 @@ def trainer_node(
            proxy.close()
            break
 
-        if msg.tag == MsgTag.INIT_POLICY_STATE:
+        elif msg.tag == MsgTag.INIT_POLICY_STATE:
            for name, state in msg.body[MsgKey.POLICY_STATE].items():
                policy_dict[name] = create_policy_func_dict[name]()
                policy_dict[name].set_state(state)
@@ -125,11 +125,7 @@ def trainer_node(
 
                # Collect gradient
                loss.backward()
-                grad_dict = {}
-                for param_name, param in policy_dict[name].q_net.named_parameters():
-                    grad_dict[param_name] = param.grad
-
-                msg_body[MsgKey.GRAD][name] = grad_dict
+                msg_body[MsgKey.GRAD][name] = policy_dict[name].get_grad()
                msg_body[MsgKey.TRACKER][name] = policy_dict[name].tracker
 
                logger.debug(f"single step of get_loss time: {time.time() - t0}")
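
For illustration only (not part of this commit): a tiny sketch of the dispatch pattern behind the if-to-elif change. The trainer loop handles mutually exclusive message tags, so a single if/elif chain reads as one dispatch and avoids re-testing tags after a branch has matched. Only the tag names mirror MsgTag; the loop body is a hypothetical simplification.

# Illustration only; tag names mirror MsgTag, everything else is a stand-in.
from enum import Enum


class MsgTag(Enum):
    EXIT = "exit"
    INIT_POLICY_STATE = "init_policy_state"
    LEARN = "learn"


def trainer_node(messages):
    for tag in messages:
        if tag is MsgTag.EXIT:
            print("close proxy and stop")
            break
        elif tag is MsgTag.INIT_POLICY_STATE:   # one if/elif chain: tags are mutually exclusive
            print("create policies and load their initial state")
        elif tag is MsgTag.LEARN:
            print("compute loss, call loss.backward(), reply with policy.get_grad()")


trainer_node([MsgTag.INIT_POLICY_STATE, MsgTag.LEARN, MsgTag.EXIT])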
