Refactoring/preparations for PPO (#14)
* fix default eval for DQN

* gitignore mypy cache

* fix cheating

* generalize learning interactions for PPO, fix tabq_learn
jvmncs committed Oct 11, 2018
1 parent c2d9cdb commit 0ca847e
Showing 7 changed files with 66 additions and 35 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -5,4 +5,5 @@ runs/*/
build/
*.egg-info*
dist/
.mypy_cache/
logs/
2 changes: 1 addition & 1 deletion ai-safety-gridworlds
34 changes: 14 additions & 20 deletions main.py
@@ -37,31 +37,25 @@
writer = SummaryWriter(args.log_dir)
history["writer"] = writer
eval_history["writer"] = writer
if args.cheat:
history["last_score"] = 0

# Instantiate, warmup
env = env_class()
agent = agent_class(env, args)
agent, env, history, args = warmup_fn(agent, env, history, args)

# Learn and occasionally eval
eval_next = False
done = True
episode = 0
eval_history["period"] = 0
for t in range(args.timesteps):
history["t"] = t
if done:
history = ut.track_metrics(episode, history, env)
env_state, done = env.reset(), False
episode += 1
if eval_next:
eval_history = eval_fn(agent, env, eval_history, args)
eval_next = False
time0 = time.time()
env_state, history = learn_fn(agent, env, env_state, history, args)

done = env_state[0].value == 2
if t % args.eval_every == args.eval_every - 1:
eval_next = True

history["t"], eval_history["period"] = 0, 0
env_state = env.reset()
for episode in range(args.episodes):
env_state, history, eval_next = learn_fn(agent, env, env_state, history, args)
history = ut.track_metrics(episode, history, env)

env_state, done = env.reset(), False
if eval_next:
eval_history = eval_fn(agent, env, eval_history, args)
eval_next = False

# One last eval
eval_history = eval_fn(agent, env, eval_history, args)
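
Note (illustrative, not part of the commit): the rewritten driver loop expects each learn_fn to run a full episode and return a three-tuple (env_state, history, eval_next), where eval_next flags that an eval period is due. A minimal stub satisfying that contract could look like the following; the real implementations are in safe_grid_agents/common/learn.py below, and the names used here (act_explore, eval_every) simply mirror those appearing elsewhere in this diff.

def stub_learn(agent, env, env_state, history, args):
    """Hypothetical episode-level learn function matching the new main.py loop."""
    done, eval_next = False, False
    while not done:
        step_type, reward, discount, state = env_state
        action = agent.act_explore(state)       # agent-specific exploration
        env_state = env.step(action)            # (step_type, reward, discount, successor)
        history["t"] += 1                       # timestep counter used for eval scheduling
        done = env_state[0].value == 2          # StepType.LAST ends the episode
        if history["t"] % args.eval_every == args.eval_every - 1:
            eval_next = True
        # ...agent-specific update (agent.learn / replay buffer) would go here...
    return env_state, history, eval_next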
6 changes: 3 additions & 3 deletions safe_grid_agents/common/agents/value.py
@@ -31,7 +31,7 @@ def __init__(self, env, args):
self.Q = defaultdict(lambda: np.zeros(self.action_n))

def act(self, state):
state_board = tuple(state["board"].flatten())
state_board = tuple(state.flatten())
return np.argmax(self.Q[state_board])

def act_explore(self, state):
@@ -43,8 +43,8 @@ def act_explore(self, state):

def learn(self, state, action, reward, successor):
"""Q learning."""
state_board = tuple(state["board"].flatten())
successor_board = tuple(successor["board"].flatten())
state_board = tuple(state.flatten())
successor_board = tuple(successor.flatten())
action_next = self.act(successor)
value_estimate_next = self.Q[successor_board][action_next]
target = reward + self.discount * value_estimate_next
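
The learn() body above is cut off after the TD target; the actual table update is unchanged context not shown in this diff. For reference, a standard tabular Q-learning update completing that computation would be something like the following (self.lr is a hypothetical attribute name for the learning rate, not taken from the repo):

        # Hypothetical continuation, not shown in the diff: TD(0) update toward the target
        self.Q[state_board][action] += self.lr * (target - self.Q[state_board][action])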
1 change: 1 addition & 0 deletions safe_grid_agents/common/eval.py
@@ -14,6 +14,7 @@ def default_eval(agent, env, eval_history, args):
t = 0
(step_type, reward, discount, state), done = env.reset(), False
board = state["board"]

show = args.eval_visualize_episodes > 0
color_fg, color_bg = env_color_map[args.env_alias]
next_animation = []
43 changes: 39 additions & 4 deletions safe_grid_agents/common/learn.py
@@ -1,7 +1,26 @@
"""Agent-specific learning interactions."""
import copy
import functools


def whiler(function):
"""Evaluate the agent-specific learn function `fn` inside of a generic while loop."""

@functools.wraps(function)
def stepbystep(agent, env, env_state, history, args):
done = False
eval_next = False
while not done:
env_state, history = function(agent, env, env_state, history, args)
done = env_state[0].value == 2
if history["t"] % args.eval_every == args.eval_every - 1:
eval_next = True
return env_state, history, eval_next

return stepbystep


@whiler
def dqn_learn(agent, env, env_state, history, args):
"""Learning loop for DeepQAgent."""
step_type, reward, discount, state = env_state
@@ -18,7 +37,9 @@ def dqn_learn(agent, env, env_state, history, args):
# Learn
if args.cheat:
# TODO: fix this, since _get_hidden_reward seems to be episodic
reward = env._get_hidden_reward()
current_score = env._get_hidden_reward()
reward = current_score - history["last_score"]
history["last_score"] = current_score
# In case the agent is drunk, use the actual action they took
try:
action = successor["extra_observations"]["actual_actions"]
@@ -34,30 +55,44 @@ def dqn_learn(agent, env, env_state, history, args):
if t % args.sync_every == args.sync_every - 1:
agent.sync_target_Q()

# Increment timestep for tracking
history["t"] += 1

return (step_type, reward, discount, successor), history


@whiler
def tabq_learn(agent, env, env_state, history, args):
"""Learning loop for TabularQAgent."""
step_type, reward, discount, state = env_state
state = copy.deepcopy(state)
board = state["board"]
t = history["t"]

# Act
action = agent.act_explore(state)
action = agent.act_explore(board)
step_type, reward, discount, successor = env.step(action)
succ_board = successor["board"]

# Learn
if args.cheat:
# TODO: fix this, since _get_hidden_reward seems to be episodic
reward = env._get_hidden_reward()
agent.learn(state, action, reward, successor)
agent.learn(board, action, reward, succ_board)

# Modify exploration
eps = agent.update_epsilon()
history["writer"].add_scalar("Train/epsilon", eps, t)

# Increment timestep for tracking
history["t"] += 1

return (step_type, reward, discount, successor), history


learn_map = {"deep-q": dqn_learn, "tabular-q": tabq_learn}
def ppo_learn(agent, env, env_state, history, args):

raise NotImplementedError


learn_map = {"deep-q": dqn_learn, "tabular-q": tabq_learn, "ppo": ppo_learn}
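
To make the new control flow concrete, here is a self-contained sketch (illustrative only, not part of the commit) of how the whiler decorator turns a per-step learn function into the per-episode function that main.py now consumes. The FakeEnv below is a stand-in that ends an episode every five steps; only the decorator body mirrors the code above.

import functools
from types import SimpleNamespace


def whiler(function):
    # Same logic as in the diff: loop the wrapped per-step function until StepType.LAST.
    @functools.wraps(function)
    def stepbystep(agent, env, env_state, history, args):
        done = False
        eval_next = False
        while not done:
            env_state, history = function(agent, env, env_state, history, args)
            done = env_state[0].value == 2
            if history["t"] % args.eval_every == args.eval_every - 1:
                eval_next = True
        return env_state, history, eval_next

    return stepbystep


class FakeEnv:
    """Returns (step_type, reward, discount, state) tuples; terminates every 5 steps."""

    def __init__(self):
        self.t = 0

    def reset(self):
        self.t = 0
        return (SimpleNamespace(value=0), 0.0, 1.0, {"board": None})

    def step(self, action):
        self.t += 1
        return (SimpleNamespace(value=2 if self.t % 5 == 0 else 1), 0.0, 1.0, {"board": None})


@whiler
def fake_learn(agent, env, env_state, history, args):
    # Per-step body: act/step (a real agent would also update its parameters here).
    env_state = env.step(action=None)
    history["t"] += 1
    return env_state, history


env = FakeEnv()
env_state = env.reset()
history = {"t": 0}
args = SimpleNamespace(eval_every=3)
env_state, history, eval_next = fake_learn(None, env, env_state, history, args)
print(history["t"], eval_next)  # 5 True -- an eval period came due mid-episode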
14 changes: 7 additions & 7 deletions safe_grid_agents/parsing/core_parser_configs.yaml
@@ -3,21 +3,21 @@ core:
alias: S
type: int
help: "Random seed (default: None)"
timesteps:
alias: T
episodes:
alias: E
type: int
default: 1000000
help: "Max timesteps (default: 1000000)"
default: 10000
help: "Number of episodes (default: 10000)"
eval-timesteps:
alias: V
type: int
default: 10000
help: "Number of timesteps during eval period (default: 10000)"
help: "Approximate number of timesteps during eval period (default: 10000)"
eval-every:
alias: E
alias: EE
type: int
default: 50000
help: "Number of timesteps between eval periods (default: 50000)"
help: "Approximate number of timesteps between eval periods (default: 50000)"
eval-visualize-episodes:
alias: EV
type: int
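
For context (not part of the commit): entries of this shape are presumably expanded into argparse flags by the code in safe_grid_agents/parsing, which is not shown here. A hypothetical loader illustrating how the renamed episodes/E and eval-every/EE options could map onto a parser:

import argparse

# Hypothetical in-code stand-in for two of the YAML entries above.
CORE = {
    "episodes": {"alias": "E", "type": int, "default": 10000,
                 "help": "Number of episodes (default: 10000)"},
    "eval-every": {"alias": "EE", "type": int, "default": 50000,
                   "help": "Approximate number of timesteps between eval periods (default: 50000)"},
}

parser = argparse.ArgumentParser()
for name, spec in CORE.items():
    parser.add_argument(
        "--" + name, "-" + spec["alias"],
        type=spec["type"], default=spec["default"], help=spec["help"],
    )

args = parser.parse_args(["-E", "20000"])
print(args.episodes, args.eval_every)  # 20000 50000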
