Use observe_terminal instead of act at terminal states
muupan committed Jul 29, 2016
1 parent efb2aa2 commit 75645c1
Showing 3 changed files with 75 additions and 50 deletions.
60 changes: 43 additions & 17 deletions agents/dqn.py
@@ -253,23 +253,24 @@ def greedy_action():

return action

def act(self, state, reward, is_state_terminal):
def act(self, state, reward):
"""
Observe a current state and a reward, then choose an action.
This function must be called every time step except at terminal states.
"""

self.logger.debug('t:%s r:%s', self.t, reward)

if self.clip_reward:
reward = np.clip(reward, -1, 1)

if not is_state_terminal:
action = self.select_action(state)
self.t += 1
action = self.select_action(state)
self.t += 1

# Update the target network
# Global counter T is used in the original paper, but here we use
# process specific counter instead. So i_target should be set
# x-times smaller, where x is the number of processes
if self.t % self.target_update_frequency == 0:
self.sync_target_network()
# Update the target network
if self.t % self.target_update_frequency == 0:
self.sync_target_network()

if self.last_state is not None:
assert self.last_action is not None
@@ -279,14 +280,10 @@ def act(self, state, reward, is_state_terminal):
action=self.last_action,
reward=reward,
next_state=state,
is_state_terminal=is_state_terminal)
is_state_terminal=False)

if not is_state_terminal:
self.last_state = state
self.last_action = action
else:
self.last_state = None
self.last_action = None
self.last_state = state
self.last_action = action

if len(self.replay_buffer) >= self.replay_start_size and \
self.t % self.update_frequency == 0:
@@ -295,6 +292,35 @@ def act(self, state, reward, is_state_terminal):

return self.last_action

def observe_terminal(self, state, reward):
"""
Observe a terminal state and a reward.
This function must be called once when an episode terminates.
"""

if self.clip_reward:
reward = np.clip(reward, -1, 1)

assert self.last_state is not None
assert self.last_action is not None

# Add a transition to the replay buffer
self.replay_buffer.append(
state=self.last_state,
action=self.last_action,
reward=reward,
next_state=state,
is_state_terminal=True)

self.last_state = None
self.last_action = None

def stop_current_episode(self):
"""
Stop the current episode.
This function must be called once when an episode is stopped before reaching a terminal state (e.g. due to a time limit).
"""
self.last_state = None
self.last_action = None
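
For context, here is a minimal sketch of how the split API is meant to be driven from a training loop. The train function, env, and max_episode_len names are illustrative, mirroring the loop in run_dqn.py below, and are not part of this commit:

def train(agent, env, steps, max_episode_len=None):
    # Illustrative driver for the split act/observe_terminal/stop_current_episode API.
    obs = env.reset()
    reward = 0.0
    episode_len = 0
    for _ in range(steps):
        # act() is called only for non-terminal states and returns an action.
        action = agent.act(obs, reward)
        obs, reward, done, _ = env.step(action)
        episode_len += 1
        if done:
            # A true terminal state: store the final transition with
            # is_state_terminal=True.
            agent.observe_terminal(obs, reward)
        elif episode_len == max_episode_len:
            # Time-limit truncation: drop the pending transition rather than
            # marking the state terminal.
            agent.stop_current_episode()
        if done or episode_len == max_episode_len:
            # Start a new episode.
            obs, reward, episode_len = env.reset(), 0.0, 0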
53 changes: 25 additions & 28 deletions run_dqn.py
@@ -15,21 +15,23 @@
import numpy as np


def eval_performance(env, q_func, phi, n_runs, gpu):
def eval_performance(env, q_func, phi, n_runs, gpu, max_episode_len=None):
assert n_runs > 1, 'Computing stdev requires at least two runs'
scores = []
for i in range(n_runs):
obs = env.reset()
done = False
test_r = 0
while not done:
t = 0
while not (done or t == max_episode_len):
s = np.expand_dims(phi(obs), 0)
if gpu >= 0:
s = chainer.cuda.to_gpu(s)
qout = q_func(chainer.Variable(s), test=True)
a = qout.greedy_actions.data[0]
obs, r, done, info = env.step(a)
test_r += r
t += 1
scores.append(test_r)
print('test_{}:'.format(i), test_r)
mean = statistics.mean(scores)
@@ -57,44 +59,36 @@ def update_best_model(agent, outdir, t, old_max_score, new_max_score):

class Evaluator(object):

def __init__(self, reuse_env, make_env, n_runs, phi, gpu, eval_frequency,
outdir):
def __init__(self, n_runs, phi, gpu, eval_frequency,
outdir, max_episode_len=None):
self.max_score = np.finfo(np.float32).min
self.start_time = time.time()
self.eval_after_this_episode = False
self.reuse_env = reuse_env
self.make_env = make_env
self.n_runs = n_runs
self.phi = phi
self.gpu = gpu
self.eval_frequency = eval_frequency
self.outdir = outdir
self.max_episode_len = max_episode_len
self.prev_eval_t = 0

def evaluate_and_update_max_score(self, env, t, agent):
mean, median, stdev = eval_performance(
env, agent.q_function, self.phi, self.n_runs, self.gpu)
env, agent.q_function, self.phi, self.n_runs, self.gpu,
max_episode_len=self.max_episode_len)
record_stats(self.outdir, t, self.start_time, mean, median, stdev)
if mean > self.max_score:
update_best_model(agent, self.outdir, t, self.max_score, mean)
self.max_score = mean

def step(self, t, done, env, agent):
if self.reuse_env:
if t > 0 and t % self.eval_frequency == 0:
self.eval_after_this_episode = True
if self.eval_after_this_episode and done:
# Eval with the existing env
self.evaluate_and_update_max_score(env, t, agent)
self.eval_after_this_episode = False
else:
if t % self.eval_frequency == 0:
# Eval with a new env
self.evaluate_and_update_max_score(
self.make_env(True), t, agent)
def evaluate_if_necessary(self, t, env, agent):
if t >= self.prev_eval_t + self.eval_frequency:
self.evaluate_and_update_max_score(env, t, agent)
self.prev_eval_t = t


def run_dqn(agent, make_env, phi, steps, eval_n_runs, eval_frequency, gpu,
outdir, reuse_env=False, max_episode_len=None):
outdir, max_episode_len=None):

env = make_env(False)

@@ -113,29 +107,32 @@ def run_dqn(agent, make_env, phi, steps, eval_n_runs, eval_frequency, gpu,

t = 0

evaluator = Evaluator(
reuse_env=reuse_env, make_env=make_env, n_runs=eval_n_runs, phi=phi,
gpu=gpu, eval_frequency=eval_frequency, outdir=outdir)
evaluator = Evaluator(n_runs=eval_n_runs, phi=phi, gpu=gpu,
eval_frequency=eval_frequency, outdir=outdir,
max_episode_len=max_episode_len)

episode_len = 0
while t < steps:
try:
episode_r += r
action = agent.act(obs, r, done)
evaluator.step(t, done, env, agent)

if done or episode_len == max_episode_len:
if done:
agent.observe_terminal(obs, r)
else:
agent.stop_current_episode()
print('{} t:{} episode_idx:{} explorer:{} episode_r:{}'.format(
outdir, t, episode_idx, agent.explorer, episode_r))
if episode_len == max_episode_len:
agent.stop_current_episode()
evaluator.evaluate_if_necessary(t, env, agent)
# Start a new episode
episode_r = 0
episode_idx += 1
episode_len = 0
obs = env.reset()
r = 0
done = False
else:
action = agent.act(obs, r)
obs, r, done, info = env.step(action)
t += 1
episode_len += 1
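
The run_dqn loop now triggers evaluation only at episode boundaries, via evaluate_if_necessary, which fires once at least eval_frequency steps have passed since the previous evaluation; evaluation points can therefore drift slightly past exact multiples of eval_frequency. A small self-contained sketch of that scheduling, with an illustrative interval and episode-end times not taken from the commit:

class PeriodicTrigger(object):
    # Fires at episode boundaries once at least `interval` steps have passed
    # since the last firing, like Evaluator.evaluate_if_necessary above.
    def __init__(self, interval):
        self.interval = interval
        self.prev_t = 0

    def should_fire(self, t):
        if t >= self.prev_t + self.interval:
            self.prev_t = t
            return True
        return False

trigger = PeriodicTrigger(interval=10)
# Suppose episodes happen to end at t = 0, 7, 14, 21 and 28:
fired = [t for t in (0, 7, 14, 21, 28) if trigger.should_fire(t)]
print(fired)  # [14, 28]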
12 changes: 7 additions & 5 deletions tests/test_dqn_like.py
@@ -76,21 +76,23 @@ def _test_abc(self, gpu, discrete=True):
reward = 0.0

# Train
for i in range(5000):
t = 0
while t < 5000:
episode_r += reward
total_r += reward

action = agent.act(obs, reward, done)

if done:
print(('i:{} explorer:{} episode_r:{}'.format(
i, agent.explorer, episode_r)))
agent.observe_terminal(obs, reward)
print(('t:{} explorer:{} episode_r:{}'.format(
t, agent.explorer, episode_r)))
episode_r = 0
obs = env.reset()
done = False
reward = 0.0
else:
action = agent.act(obs, reward)
obs, reward, done, _ = env.step(action)
t += 1

# Test
total_r = 0.0
