Skip to content

Commit

Permalink
initial impl
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Jan 3, 2023
1 parent 8e68f0b commit 2e638a5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
11 changes: 11 additions & 0 deletions all/agents/dqn.py
Expand Up @@ -36,6 +36,8 @@ def __init__(self,
minibatch_size=32,
replay_start_size=5000,
update_frequency=1,
normalized_q_target=0.1,
lr_discount=1e-6,
):
# objects
self.q = q
Expand All @@ -47,6 +49,8 @@ def __init__(self,
self.minibatch_size = minibatch_size
self.replay_start_size = replay_start_size
self.update_frequency = update_frequency
self.normalized_q_target = normalized_q_target
self.lr_discount = lr_discount
# private
self._state = None
self._action = None
Expand Down Expand Up @@ -74,11 +78,18 @@ def _train(self):
loss = self.loss(values, targets)
# backward pass
self.q.reinforce(loss)
# adjust discount
self._adjust_discount(loss.detach() / targets.var())

def _should_train(self):
self._frames_seen += 1
return (self._frames_seen > self.replay_start_size and self._frames_seen % self.update_frequency == 0)

def _adjust_discount(self, normalized_q_loss):
grad = (self.normalized_q_target - normalized_q_loss).item()
self.discount_factor = min(1, max(0, self.discount_factor + self.lr_discount * grad))
self.q._logger.add_info('normalized_q_loss', normalized_q_loss)
self.q._logger.add_info('discount_factor', self.discount_factor)

class DQNTestAgent(Agent):
def __init__(self, policy):
Expand Down
2 changes: 1 addition & 1 deletion all/presets/classic_control/dqn.py
Expand Up @@ -22,7 +22,7 @@
"target_update_frequency": 100,
# Replay buffer settings
"replay_start_size": 1000,
"replay_buffer_size": 10000,
"replay_buffer_size": 100000,
# Explicit exploration
"initial_exploration": 1.,
"final_exploration": 0.,
Expand Down

0 comments on commit 2e638a5

Please sign in to comment.