Skip to content

Commit

Permalink
fix(nyz): fix marl nstep td compatibility bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Apr 24, 2024
1 parent 8392206 commit c7c3bac
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions ding/policy/madqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
td_error_per_sample = []
for t in range(self._cfg.collect.unroll_len):
v_data = v_nstep_td_data(
- total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma
+ total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None
)
# calculate v_nstep_td critic_loss
loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
Expand Down Expand Up @@ -231,8 +231,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
cooperation_loss_all = []
for t in range(self._cfg.collect.unroll_len):
v_data = v_nstep_td_data(
- cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t],
- data['weight'], self._gamma
+ cooperation_total_q[t],
+ cooperation_target_total_q[t],
+ data['reward'][t],
+ data['done'][t],
+ data['weight'],
+ None,
)
cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
cooperation_loss_all.append(cooperation_loss)
Expand Down

0 comments on commit c7c3bac

Please sign in to comment.