Extend policy gradient and DQN to handle case where rewards are received on other players' turns.

If this is the case, the agent's step() method must be called on every environment step where there are (or could be) rewards, including steps taken on other players' turns. See the tests for example main loops; a minimal sketch follows below.

Fixes: #68
PiperOrigin-RevId: 271435945
Change-Id: Id3b0f30db2b83ecaf256b361dd96bd3f9b0d7430
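
For example, a main loop supporting this pattern looks roughly as follows (a minimal sketch mirroring the test code below; `env` is assumed to be an `rl_environment.Environment` and `agents` a list with one agent per player):

time_step = env.reset()
while not time_step.last():
  current_player = time_step.observations["current_player"]
  # Step every agent on every environment step so that each one can record
  # rewards that arrive on other players' turns; agents whose turn it is not
  # simply return a None action.
  agent_output = [agent.step(time_step) for agent in agents]
  time_step = env.step([agent_output[current_player].action])

# Also step every agent on the terminal step so final rewards are recorded.
for agent in agents:
  agent.step(time_step)
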
DeepMind Technologies Ltd authored and open_spiel@google.com committed Sep 26, 2019
1 parent 6193a8c commit 71d7663
Showing 5 changed files with 104 additions and 4 deletions.
9 changes: 7 additions & 2 deletions open_spiel/python/algorithms/dqn.py
@@ -214,12 +214,17 @@ def step(self, time_step, is_evaluation=False, add_transition_record=True):
     Returns:
       A `rl_agent.StepOutput` containing the action probs and chosen action.
     """
-    # Act step: don't act at terminal info states.
-    if not time_step.last():
+    # Act step: don't act at terminal info states or if it's not our turn.
+    if (not time_step.last()) and (
+        time_step.is_simultaneous_move() or
+        self.player_id == time_step.current_player()):
       info_state = time_step.observations["info_state"][self.player_id]
       legal_actions = time_step.observations["legal_actions"][self.player_id]
       epsilon = self._get_epsilon(is_evaluation)
       action, probs = self._epsilon_greedy(info_state, legal_actions, epsilon)
+    else:
+      action = None
+      probs = []
 
     # Don't mess up with the state during evaluation.
     if not is_evaluation:
42 changes: 42 additions & 0 deletions open_spiel/python/algorithms/dqn_test.py
@@ -22,6 +22,7 @@

 from open_spiel.python import rl_environment
 from open_spiel.python.algorithms import dqn
+import pyspiel
 
 
 class DQNTest(tf.test.TestCase):
@@ -53,6 +54,47 @@ def test_run_tic_tac_toe(self):
       for agent in agents:
         agent.step(time_step)
 
+  def test_run_hanabi(self):
+    # Hanabi is an optional game, so check we have it before running the test.
+    game = "hanabi"
+    if game not in pyspiel.registered_names():
+      return
+
+    num_players = 3
+    env_configs = {
+        "players": num_players,
+        "max_life_tokens": 1,
+        "colors": 2,
+        "ranks": 3,
+        "hand_size": 2,
+        "max_information_tokens": 3,
+        "discount": 0.
+    }
+    env = rl_environment.Environment(game, **env_configs)
+    state_size = env.observation_spec()["info_state"][0]
+    num_actions = env.action_spec()["num_actions"]
+
+    with self.session() as sess:
+      agents = [
+          dqn.DQN(  # pylint: disable=g-complex-comprehension
+              sess,
+              player_id,
+              state_representation_size=state_size,
+              num_actions=num_actions,
+              hidden_layers_sizes=[16],
+              replay_buffer_capacity=10,
+              batch_size=5) for player_id in range(num_players)
+      ]
+      sess.run(tf.global_variables_initializer())
+      time_step = env.reset()
+      while not time_step.last():
+        current_player = time_step.observations["current_player"]
+        agent_output = [agent.step(time_step) for agent in agents]
+        time_step = env.step([agent_output[current_player].action])
+
+      for agent in agents:
+        agent.step(time_step)
+
 
 class ReplayBufferTest(tf.test.TestCase):
 
9 changes: 7 additions & 2 deletions open_spiel/python/algorithms/policy_gradient.py
@@ -250,11 +250,16 @@ def step(self, time_step, is_evaluation=False):
     Returns:
       A `rl_agent.StepOutput` containing the action probs and chosen action.
     """
-    # Act step: don't act at terminal info states.
-    if not time_step.last():
+    # Act step: don't act at terminal info states or if it's not our turn.
+    if (not time_step.last()) and (
+        time_step.is_simultaneous_move() or
+        self.player_id == time_step.current_player()):
       info_state = time_step.observations["info_state"][self.player_id]
       legal_actions = time_step.observations["legal_actions"][self.player_id]
       action, probs = self._act(info_state, legal_actions)
+    else:
+      action = None
+      probs = []
 
     if not is_evaluation:
       self._step_counter += 1
45 changes: 45 additions & 0 deletions open_spiel/python/algorithms/policy_gradient_test.py
@@ -26,6 +26,7 @@
 from open_spiel.python import rl_environment
 from open_spiel.python.algorithms import policy_gradient
 from open_spiel.python.algorithms.losses import rl_losses
+import pyspiel
 
 
 class PolicyGradientTest(parameterized.TestCase, tf.test.TestCase):
@@ -66,6 +67,50 @@ def test_run_game(self, loss_str, game_name):
       for agent in agents:
         agent.step(time_step)
 
+  def test_run_hanabi(self):
+    # Hanabi is an optional game, so check we have it before running the test.
+    game = "hanabi"
+    if game not in pyspiel.registered_names():
+      return
+
+    num_players = 3
+    env_configs = {
+        "players": num_players,
+        "max_life_tokens": 1,
+        "colors": 2,
+        "ranks": 3,
+        "hand_size": 2,
+        "max_information_tokens": 3,
+        "discount": 0.
+    }
+    env = rl_environment.Environment(game, **env_configs)
+    info_state_size = env.observation_spec()["info_state"][0]
+    num_actions = env.action_spec()["num_actions"]
+
+    with self.session() as sess:
+      agents = [
+          policy_gradient.PolicyGradient(  # pylint: disable=g-complex-comprehension
+              sess,
+              player_id=player_id,
+              info_state_size=info_state_size,
+              num_actions=num_actions,
+              hidden_layers_sizes=[8, 8],
+              batch_size=16,
+              entropy_cost=0.001,
+              critic_learning_rate=0.01,
+              pi_learning_rate=0.01,
+              num_critic_before_pi=4) for player_id in range(num_players)
+      ]
+      sess.run(tf.global_variables_initializer())
+      time_step = env.reset()
+      while not time_step.last():
+        current_player = time_step.observations["current_player"]
+        agent_output = [agent.step(time_step) for agent in agents]
+        time_step = env.step([agent_output[current_player].action])
+
+      for agent in agents:
+        agent.step(time_step)
+
   def test_loss_modes(self):
     loss_dict = {
         "qpg": rl_losses.BatchQPGLoss,
3 changes: 3 additions & 0 deletions open_spiel/python/rl_environment.py
@@ -94,6 +94,9 @@ def last(self):
   def is_simultaneous_move(self):
     return self.observations["current_player"] == SIMULTANEOUS_PLAYER_ID
 
+  def current_player(self):
+    return self.observations["current_player"]
+
 
 class StepType(enum.Enum):
   """Defines the status of a `TimeStep` within a sequence."""