Commit: Merge pull request #524 from ummavi/td3_policystats
Add policy loss to TD3's logged statistics
muupan committed Aug 20, 2019
2 parents 508ed07 + 10ab945 commit 8bee9da
Showing 1 changed file with 3 additions and 0 deletions.

chainerrl/agents/td3.py: 3 additions, 0 deletions
@@ -162,6 +162,7 @@ def __init__(
self.q2_record = collections.deque(maxlen=1000)
self.q_func1_loss_record = collections.deque(maxlen=100)
self.q_func2_loss_record = collections.deque(maxlen=100)
+ self.policy_loss_record = collections.deque(maxlen=100)

def sync_target_network(self):
"""Synchronize target network with current network."""
@@ -230,6 +231,7 @@ def update_policy(self, batch):
# Since we want to maximize Q, loss is negation of Q
loss = - F.mean(q)

+ self.policy_loss_record.append(float(loss.array))
self.policy_optimizer.update(lambda: loss)

def update(self, experiences, errors_out=None):
@@ -369,6 +371,7 @@ def get_statistics(self):
('average_q2', _mean_or_nan(self.q2_record)),
('average_q_func1_loss', _mean_or_nan(self.q_func1_loss_record)),
('average_q_func2_loss', _mean_or_nan(self.q_func2_loss_record)),
+ ('average_policy_loss', _mean_or_nan(self.policy_loss_record)),
('policy_n_updates', self.policy_optimizer.t),
('q_func_n_updates', self.q_func1_optimizer.t),
]
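
The value being recorded follows TD3's deterministic policy gradient objective: the actor is updated to maximize the critic's Q estimate, so the scalar appended to policy_loss_record is the negated mean Q over the batch. A minimal sketch of the logging pattern, lifted out of the diff context (the class name PolicyLossLogger and the method record_policy_loss are illustrative, not part of chainerrl; Variable.array is Chainer's handle on the underlying NumPy array):

import collections

import chainer.functions as F


class PolicyLossLogger(object):
    """Illustrative sketch of the pattern added in this commit."""

    def __init__(self, maxlen=100):
        # deque(maxlen=...) silently drops the oldest entry once full,
        # so the record stays bounded over arbitrarily long training runs.
        self.policy_loss_record = collections.deque(maxlen=maxlen)

    def record_policy_loss(self, q):
        # q: Chainer Variable of Q-values for the policy's own actions.
        # Since we want to maximize Q, the loss is the negation of Q.
        loss = -F.mean(q)
        # float(loss.array) extracts a plain Python scalar, so the log
        # never keeps a reference to the computational graph.
        self.policy_loss_record.append(float(loss.array))
        return loss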

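With the change applied, the running average appears next to the existing entries in get_statistics(). Before the first policy update the record is empty, and the helper reports NaN rather than a misleading zero. A sketch of that helper, mirroring the _mean_or_nan used by td3.py (reconstructed here for illustration; the agent variable in the usage comment is assumed to be a chainerrl.agents.TD3 instance):

import numpy as np


def _mean_or_nan(xs):
    """Return the mean of a non-empty sequence, numpy.nan otherwise."""
    return np.mean(xs) if xs else np.nan


# Hypothetical usage after some training steps:
#   stats = dict(agent.get_statistics())
#   stats['average_policy_loss']  # mean of the last 100 recorded losses
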