In [None]:
%pip install --upgrade open_spiel


In [None]:
import numpy as np

from open_spiel.python import rl_environment
from open_spiel.python.pytorch import dqn as dqn_pt
from open_spiel.python.jax import dqn
from open_spiel.python.algorithms import random_agent

In [None]:
def eval_against_random_bots(env, trained_agents, random_agents, num_episodes):
  """Scores each trained agent, seat by seat, against the baseline agents.

  For every seat a line-up is built from a copy of `random_agents` in which
  that seat's entry is replaced by the matching trained agent. The line-up then
  plays `num_episodes` episodes with every agent stepped in evaluation mode.

  Args:
    env: environment exposing `reset()`, `step(actions)` and `is_turn_based`.
    trained_agents: agents under evaluation, indexed by seat.
    random_agents: baseline agents, indexed by seat; the list is copied, not
      modified.
    num_episodes: number of episodes played for each seat.

  Returns:
    A NumPy array with one entry per trained agent: the rewards received at
    that seat, summed within each episode and averaged over `num_episodes`.
  """
  totals = np.zeros(len(trained_agents))
  for seat, learner in enumerate(trained_agents):
    # Only the seat under test is occupied by a trained agent.
    lineup = random_agents[:]
    lineup[seat] = learner
    for _ in range(num_episodes):
      totals[seat] += _evaluation_episode_return(env, lineup, seat)
  return totals / num_episodes


def _evaluation_episode_return(env, lineup, seat):
  """Plays one evaluation episode and returns the reward collected at `seat`."""
  step = env.reset()
  earned = 0
  while True:
    if step.last():
      return earned
    mover = step.observations["current_player"]
    if env.is_turn_based:
      # Only the player to move acts.
      actions = [lineup[mover].step(step, is_evaluation=True).action]
    else:
      # Simultaneous game: every agent acts on the same time step.
      outputs = [bot.step(step, is_evaluation=True) for bot in lineup]
      actions = [out.action for out in outputs]
    step = env.step(actions)
    earned += step.rewards[seat]

In [None]:
def pt_main(
  game,
  config,
  checkpoint_dir,
  num_train_episodes,
  eval_every,
  hidden_layers_sizes,
  replay_buffer_capacity,
  batch_size,
  num_eval_episodes=1000
):
  """Trains two PyTorch DQN agents by self-play and records their progress.

  Args:
    game: game name handed to `rl_environment.Environment`.
    config: dict of settings forwarded to the environment as keyword arguments.
    checkpoint_dir: not used by this function; kept so existing callers keep
      working.
    num_train_episodes: number of self-play training episodes to run.
    eval_every: an evaluation is run whenever `ep + 1` is a multiple of this
      value.
    hidden_layers_sizes: layer widths for each agent's network; every entry is
      converted with `int`.
    replay_buffer_capacity: forwarded to each DQN agent.
    batch_size: forwarded to each DQN agent.
    num_eval_episodes: episodes played per seat in every evaluation. The
      default of 1000 is the value that used to be hard-coded.

  Returns:
    A list with one entry per evaluation, each being the array returned by
    `eval_against_random_bots`.
  """
  num_players = 2

  env = rl_environment.Environment(game, **config)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  # random agents for evaluation
  random_agents = [
      random_agent.RandomAgent(player_id=idx, num_actions=num_actions)
      for idx in range(num_players)
  ]

  hidden_layers_sizes = [int(ls) for ls in hidden_layers_sizes]
  # pylint: disable=g-complex-comprehension
  agents = [
      dqn_pt.DQN(
          player_id=idx,
          state_representation_size=info_state_size,
          num_actions=num_actions,
          hidden_layers_sizes=hidden_layers_sizes,
          replay_buffer_capacity=replay_buffer_capacity,
          batch_size=batch_size) for idx in range(num_players)
  ]

  result = []
  for ep in range(num_train_episodes):
    # Evaluation runs before the training episode with this index is played.
    if (ep + 1) % eval_every == 0:
      r_mean = eval_against_random_bots(
          env, agents, random_agents, num_eval_episodes)
      result.append(r_mean)
      print("[%s] Mean episode rewards %s" %(ep + 1, r_mean))

    time_step = env.reset()
    while not time_step.last():
      player_id = time_step.observations["current_player"]
      if env.is_turn_based:
        # Only the player to move steps (and learns from) this time step.
        agent_output = agents[player_id].step(time_step)
        action_list = [agent_output.action]
      else:
        agents_output = [agent.step(time_step) for agent in agents]
        action_list = [agent_output.action for agent_output in agents_output]
      time_step = env.step(action_list)

    # Episode is over, step all agents with final info state.
    for agent in agents:
      agent.step(time_step)
  return result

In [None]:
def jax_main(
  game,
  config,
  checkpoint_dir,
  num_train_episodes,
  eval_every,
  hidden_layers_sizes,
  replay_buffer_capacity,
  batch_size,
  num_eval_episodes=1000
):
  """Trains two JAX DQN agents by self-play and records their progress.

  Args:
    game: game name handed to `rl_environment.Environment`.
    config: dict of settings forwarded to the environment as keyword arguments.
    checkpoint_dir: not used by this function; kept so existing callers keep
      working.
    num_train_episodes: number of self-play training episodes to run.
    eval_every: an evaluation is run whenever `ep + 1` is a multiple of this
      value.
    hidden_layers_sizes: layer widths for each agent's network; every entry is
      converted with `int`.
    replay_buffer_capacity: forwarded to each DQN agent.
    batch_size: forwarded to each DQN agent.
    num_eval_episodes: episodes played per seat in every evaluation. The
      default of 1000 is the value that used to be hard-coded.

  Returns:
    A list with one entry per evaluation, each being the array returned by
    `eval_against_random_bots`.
  """
  num_players = 2

  env = rl_environment.Environment(game, **config)
  info_state_size = env.observation_spec()["info_state"][0]
  num_actions = env.action_spec()["num_actions"]

  # random agents for evaluation
  random_agents = [
      random_agent.RandomAgent(player_id=idx, num_actions=num_actions)
      for idx in range(num_players)
  ]

  hidden_layers_sizes = [int(ls) for ls in hidden_layers_sizes]
  # pylint: disable=g-complex-comprehension
  agents = [
      dqn.DQN(
          player_id=idx,
          state_representation_size=info_state_size,
          num_actions=num_actions,
          hidden_layers_sizes=hidden_layers_sizes,
          replay_buffer_capacity=replay_buffer_capacity,
          batch_size=batch_size
      ) for idx in range(num_players)
  ]

  result_jax = []
  for ep in range(num_train_episodes):
    # Evaluation runs before the training episode with this index is played.
    if (ep + 1) % eval_every == 0:
      r_mean = eval_against_random_bots(
          env, agents, random_agents, num_eval_episodes)
      result_jax.append(r_mean)
      print("[%s] Mean episode rewards %s" %(ep + 1, r_mean))

    time_step = env.reset()
    while not time_step.last():
      player_id = time_step.observations["current_player"]
      if env.is_turn_based:
        # Only the player to move steps (and learns from) this time step.
        agent_output = agents[player_id].step(time_step)
        action_list = [agent_output.action]
      else:
        agents_output = [agent.step(time_step) for agent in agents]
        action_list = [agent_output.action for agent_output in agents_output]
      time_step = env.step(action_list)

    # Episode is over, step all agents with final info state.
    for agent in agents:
      agent.step(time_step)
  return result_jax

In [None]:
# Hyper-parameters shared by the PyTorch and JAX training runs below.
checkpoint_dir = "/tmp/dqn_test"
num_train_episodes, eval_every = 10_000, 100

# Network and replay-buffer settings passed to every DQN agent.
hidden_layers_sizes = [64] * 2
batch_size = 32
replay_buffer_capacity = 100_000

# BREAKTHROUGH

In [None]:
# Game name and the board settings forwarded to the environment.
game, config = "breakthrough", dict(columns=5, rows=5)

In [None]:
# Train the PyTorch agents on the currently selected game.
pt_result = pt_main(
    game=game,
    config=config,
    checkpoint_dir=checkpoint_dir,
    num_train_episodes=num_train_episodes,
    eval_every=eval_every,
    hidden_layers_sizes=hidden_layers_sizes,
    replay_buffer_capacity=replay_buffer_capacity,
    batch_size=batch_size,
)

In [None]:
import matplotlib.pyplot as plt

# One entry of pt_result is recorded every `eval_every` training episodes, so
# turn the evaluation index into the episode count it was logged at; otherwise
# the 'Episode' axis would show 0..N-1 evaluation indices.
ep = [(i + 1) * eval_every for i in range(len(pt_result))]
pt_r_mean0 = [y[0] for y in pt_result]
pt_r_mean1 = [y[1] for y in pt_result]

plt.plot(ep, pt_r_mean0, c='red', label='DQN agent at seat 0')
plt.plot(ep, pt_r_mean1, c='blue', label='DQN agent at seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('PyTorch DQN vs. random agents (%s)' % game)
plt.legend()
plt.show()

In [None]:
# Train the JAX agents with exactly the same settings as the PyTorch run.
result_jax = jax_main(
    game=game,
    config=config,
    checkpoint_dir=checkpoint_dir,
    num_train_episodes=num_train_episodes,
    eval_every=eval_every,
    hidden_layers_sizes=hidden_layers_sizes,
    replay_buffer_capacity=replay_buffer_capacity,
    batch_size=batch_size,
)

In [None]:
# One entry of result_jax is recorded every `eval_every` training episodes, so
# turn the evaluation index into the episode count it was logged at; otherwise
# the 'Episode' axis would show 0..N-1 evaluation indices.
ep = [(i + 1) * eval_every for i in range(len(result_jax))]
jax_r_mean0 = [y[0] for y in result_jax]
jax_r_mean1 = [y[1] for y in result_jax]

plt.plot(ep, jax_r_mean0, c='red', label='DQN agent at seat 0')
plt.plot(ep, jax_r_mean1, c='blue', label='DQN agent at seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('JAX DQN vs. random agents (%s)' % game)
plt.legend()
plt.show()

In [None]:
# Rebuild the x-axis here instead of relying on whatever `ep` an earlier cell
# left behind: evaluations are logged every `eval_every` training episodes.
ep = [(i + 1) * eval_every for i in range(len(pt_r_mean0))]

# Colour encodes the framework, line style encodes the seat.
plt.plot(ep, pt_r_mean0, c='skyblue', label='PyTorch, seat 0')
plt.plot(ep, pt_r_mean1, c='skyblue', linestyle='dashed', label='PyTorch, seat 1')
plt.plot(ep, jax_r_mean0, c='pink', label='JAX, seat 0')
plt.plot(ep, jax_r_mean1, c='pink', linestyle='dashed', label='JAX, seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('PyTorch vs. JAX DQN against random agents (%s)' % game)
plt.legend()
plt.show()

# TIC-TAC-TOE

In [None]:
# Switch to tic-tac-toe: no extra game settings, a longer run, sparser evals.
game, config = "tic_tac_toe", dict()
num_train_episodes, eval_every = 20_000, 1_000

In [None]:
# Train the PyTorch agents on the currently selected game.
pt_result = pt_main(
    game=game,
    config=config,
    checkpoint_dir=checkpoint_dir,
    num_train_episodes=num_train_episodes,
    eval_every=eval_every,
    hidden_layers_sizes=hidden_layers_sizes,
    replay_buffer_capacity=replay_buffer_capacity,
    batch_size=batch_size,
)

In [None]:
import matplotlib.pyplot as plt

# One entry of pt_result is recorded every `eval_every` training episodes, so
# turn the evaluation index into the episode count it was logged at; otherwise
# the 'Episode' axis would show 0..N-1 evaluation indices.
ep = [(i + 1) * eval_every for i in range(len(pt_result))]
pt_r_mean0 = [y[0] for y in pt_result]
pt_r_mean1 = [y[1] for y in pt_result]

plt.plot(ep, pt_r_mean0, c='red', label='DQN agent at seat 0')
plt.plot(ep, pt_r_mean1, c='blue', label='DQN agent at seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('PyTorch DQN vs. random agents (%s)' % game)
plt.legend()
plt.show()

In [None]:
# Train the JAX agents with exactly the same settings as the PyTorch run.
result_jax = jax_main(
    game=game,
    config=config,
    checkpoint_dir=checkpoint_dir,
    num_train_episodes=num_train_episodes,
    eval_every=eval_every,
    hidden_layers_sizes=hidden_layers_sizes,
    replay_buffer_capacity=replay_buffer_capacity,
    batch_size=batch_size,
)

In [None]:
# One entry of result_jax is recorded every `eval_every` training episodes, so
# turn the evaluation index into the episode count it was logged at; otherwise
# the 'Episode' axis would show 0..N-1 evaluation indices.
ep = [(i + 1) * eval_every for i in range(len(result_jax))]
jax_r_mean0 = [y[0] for y in result_jax]
jax_r_mean1 = [y[1] for y in result_jax]

plt.plot(ep, jax_r_mean0, c='red', label='DQN agent at seat 0')
plt.plot(ep, jax_r_mean1, c='blue', label='DQN agent at seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('JAX DQN vs. random agents (%s)' % game)
plt.legend()
plt.show()

In [None]:
# Rebuild the x-axis here instead of relying on whatever `ep` an earlier cell
# left behind: evaluations are logged every `eval_every` training episodes.
ep = [(i + 1) * eval_every for i in range(len(pt_r_mean0))]

# Colour encodes the framework, line style encodes the seat.
plt.plot(ep, pt_r_mean0, c='skyblue', label='PyTorch, seat 0')
plt.plot(ep, pt_r_mean1, c='skyblue', linestyle='dashed', label='PyTorch, seat 1')
plt.plot(ep, jax_r_mean0, c='pink', label='JAX, seat 0')
plt.plot(ep, jax_r_mean1, c='pink', linestyle='dashed', label='JAX, seat 1')
plt.xlabel('Episode')
plt.ylabel('Mean episode rewards')
plt.title('PyTorch vs. JAX DQN against random agents (%s)' % game)
plt.legend()
plt.show()