Skip to content

Commit

Permalink
tweak some hyperparameters
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Jan 10, 2023
1 parent 397e735 commit 22911cd
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions all/agents/dqn.py
Expand Up @@ -36,7 +36,7 @@ def __init__(self,
minibatch_size=32,
replay_start_size=5000,
update_frequency=1,
normalized_q_target=0.1,
normalized_q_target=0.05,
lr_discount=1e-4,
):
# objects
Expand Down Expand Up @@ -80,15 +80,15 @@ def _train(self):
# backward pass
self.q.reinforce(loss)
# adjust discount
self._adjust_discount(loss.detach() / targets.var())
self._adjust_discount(loss.detach() / (targets.var() + 1e-2))

def _should_train(self):
self._frames_seen += 1
return (self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0)

def _adjust_discount(self, normalized_q_loss):
grad = (1 - self.discount_factor) * (self.normalized_q_target - normalized_q_loss)
self.discount_factor = min(1, max(0, self.discount_factor + self.lr_discount * grad.item()))
self.discount_factor = min(0.999, max(0, self.discount_factor + self.lr_discount * grad.item()))
self.q._logger.add_info('normalized_q_loss', normalized_q_loss)
self.q._logger.add_info('discount_factor', self.discount_factor)

Expand Down
2 changes: 1 addition & 1 deletion all/presets/atari/dqn.py
Expand Up @@ -113,7 +113,7 @@ def agent(self, logger=DummyLogger(), train_steps=float('inf')):
policy,
replay_buffer,
discount_factor=self.hyperparameters['discount_factor'],
loss=smooth_l1_loss,
# loss=smooth_l1_loss,
minibatch_size=self.hyperparameters['minibatch_size'],
replay_start_size=self.hyperparameters['replay_start_size'],
update_frequency=self.hyperparameters['update_frequency'],
Expand Down

0 comments on commit 22911cd

Please sign in to comment.