# Setup

In [1]:
!pip install -q tf-agents

[?25l[K     |▎                               | 10kB 20.4MB/s eta 0:00:01[K     |▋                               | 20kB 27.0MB/s eta 0:00:01[K     |█                               | 30kB 20.3MB/s eta 0:00:01[K     |█▏                              | 40kB 13.3MB/s eta 0:00:01[K     |█▌                              | 51kB 8.3MB/s eta 0:00:01[K     |█▉                              | 61kB 8.6MB/s eta 0:00:01[K     |██                              | 71kB 8.8MB/s eta 0:00:01[K     |██▍                             | 81kB 9.1MB/s eta 0:00:01[K     |██▊                             | 92kB 8.8MB/s eta 0:00:01[K     |███                             | 102kB 7.3MB/s eta 0:00:01[K     |███▎                            | 112kB 7.3MB/s eta 0:00:01[K     |███▋                            | 122kB 7.3MB/s eta 0:00:01[K     |████                            | 133kB 7.3MB/s eta 0:00:01[K     |████▏                           | 143kB 7.3MB/s eta 0:00:01[K     |████▌                  

In [2]:
import numpy as np
import os

try:
  # Newer numba releases expose jitclass from numba.experimental.
  from numba.experimental import jitclass
except ImportError:
  # Older numba releases only provide the top-level name.
  from numba import jitclass
from numba import njit, int32, int64, float32
from tqdm import tqdm
from IPython.display import clear_output

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_episode_driver
from tf_agents.environments import py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import array_spec
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Turn on TensorFlow 2.x behavior through the compat.v1 API before any
# TF-Agents objects are built.
tf.compat.v1.enable_v2_behavior()



# Environment

## Board

In [3]:
@jitclass([('position', int64),
           ('mask', int64),
           ('n_move', int32),
           ('top', int64)])
class Connect4Board(object):
  """Bitboard representation of a 6x7 Connect-4 position.

  Every column occupies 7 consecutive bits (bit `col*7` is the bottom cell);
  the 7th bit of each column is a sentinel row that is never played and is
  collected in `top`.

  Attributes:
    position: bitmask of the stones belonging to the player to move.
    mask: bitmask of every occupied cell (both players).
    n_move: number of moves played so far.
    top: bitmask with the sentinel bit (bit 6) of each of the 7 columns set.
  """
  def __init__(self, position=0, mask=0, n_move=0, top=None):
    self.position = position
    self.mask = mask
    self.n_move = n_move

    if top is None:
      # Sentinel bits: bit 6 + 7*col for every column.
      self.top = np.sum(1 << 6+7*np.arange(7))
    else:
      self.top = top
  
  def copy(self):
    """Return a new board with the same position, mask, move count and top."""
    return Connect4Board(self.position, self.mask, self.n_move, self.top)
  
  def reset(self):
    """Return a new empty board (this instance is left untouched)."""
    return Connect4Board(0, 0, 0, self.top)

  def _reset(self):
    """Clear this board in place and return it."""
    self.position = 0
    self.mask = 0
    self.n_move = 0

    return self

  def is_valid_move(self, col):
    """Return True when dropping a stone in `col` does not reach the sentinel row."""
    # Adding the column's bottom bit to the mask carries up to the first free
    # cell of that column; if that cell is the sentinel bit the column is full.
    return (self.top & (self.mask + (1 << (col*7))) == 0)
  
  def available_moves(self):
    """Return the playable columns in ascending order."""
    return [col for col in range(7) if self.is_valid_move(col)]
  
  def generate_moves(self):
    """Return the playable columns in a random order."""
    return [col for col in np.random.choice(7, size=7, replace=False) if self.is_valid_move(col)]

  def make_move(self, col):
    """Return a new board with a stone dropped in `col`; `self` is unchanged.

    No validity check is performed here.
    """
    # position ^ mask switches `position` to the other player's stones, so that
    # it again describes the player to move after this move.
    return Connect4Board(self.position ^ self.mask,
                         self.mask | (self.mask + (1 << (col*7))),
                         self.n_move + 1,
                         self.top)

  def _make_move(self, col):
    """In-place variant of `make_move`; returns `self`."""
    self.position = self.position ^ self.mask
    self.mask = self.mask | (self.mask + (1 << (col*7)))
    self.n_move += 1
    return self
  
  def is_win(self):
    """Return True if the stones NOT owned by the player to move contain four in a row.

    Because `position` holds the stones of the player to move, this checks the
    player who made the most recent move.
    """
    opposition = self.position ^ self.mask
    # Horizontal check
    m = opposition & (opposition >> 7)
    if m & (m >> 14):
        return True
    # Diagonal \
    m = opposition & (opposition >> 6)
    if m & (m >> 12):
        return True
    # Diagonal /
    m = opposition & (opposition >> 8)
    if m & (m >> 16):
        return True
    # Vertical
    m = opposition & (opposition >> 1)
    if m & (m >> 2):
        return True
    # Nothing found
    return False
  
  def is_draw(self):
    """Return True once 42 moves have been played (does not itself check for a win)."""
    return self.n_move == 42
  
  def is_terminal(self):
    """Return True if the last mover has won or 42 moves have been played."""
    return self.is_win() or self.is_draw()
  
  def to_array(self):
    """Return a (6, 7) int32 array: 1 for the first player's stones, -1 for the
    second player's, 0 for empty cells. Row index 0 is the bottom row.
    """
    board = np.zeros((6,7), dtype=np.int32)

    opponent = self.position ^ self.mask
    # `position` belongs to the player to move; with an odd move count that is
    # the second player, so the two bitmasks are swapped.
    p0, p1 = ((opponent, self.position) if self.n_move % 2 else 
              (self.position, opponent))
    
    for j in range(7):
      m = np.int64(1) << j*7
      for i in range(6):
        if p0 & m:
          board[i,j] = 1
        elif p1 & m:
          board[i,j] = -1
        else:
          # first empty cell of this column: nothing can be above it
          break
        m <<= 1
    
    return board
  
  def hash(self):
    """Return a hash of the (position, mask) pair."""
    return hash((self.position, self.mask))

## Python Environment

In [4]:
class Connect4Env(py_environment.PyEnvironment):
  """TF-Agents Python environment wrapping a `Connect4Board`.

  Actions are column indices 0..6. Observations are float32 arrays of shape
  (6, 7, 1) built from `Connect4Board.to_array()`. Rewards: 1.0 for the move
  that completes four in a row, 0.0 for a draw or a non-terminal move, and
  -2.0 (with termination) for a move into a full column.
  """

  def __init__(self):
    self._board = Connect4Board()
    self._episode_ended = False
    self._observation_spec = array_spec.BoundedArraySpec(
        shape=(6,7,1), dtype=np.float32, minimum=-1, maximum=1, name='observation')
    self._action_spec = array_spec.BoundedArraySpec(
        shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')

  def action_spec(self):
    return self._action_spec

  def observation_spec(self):
    return self._observation_spec

  def _board_observation(self):
    """Current board as a float32 array with a trailing channel axis."""
    grid = self._board.to_array()
    return grid[:,:,np.newaxis].astype(np.float32)

  def _end_episode(self, reward):
    """Flag the episode as finished and build the terminal time step."""
    self._episode_ended = True
    return ts.termination(self._board_observation(), reward=reward)

  def _reset(self):
    self._board._reset()
    self._episode_ended = False
    return ts.restart(self._board_observation())

  def _step(self, action):
    # A step after the terminal one is ignored and starts a fresh episode.
    if self._episode_ended:
      return self.reset()

    if action < 0 or action > 6:
      raise ValueError('`action` should be between 0 to 6.')

    # Dropping into a full column ends the episode with a penalty.
    if not self._board.is_valid_move(action):
      return self._end_episode(-2.0)

    self._board._make_move(action)

    if self._board.is_win():
      return self._end_episode(1.0)
    if self._board.is_draw():
      return self._end_episode(0.0)
    return ts.transition(self._board_observation(), reward=0.0, discount=1.0)

# Negamax

## Evaluator

In [5]:
@njit
def bitwise_or_reduce(xs):
  """Fold the elements of `xs` together with bitwise OR, starting from 0."""
  acc = 0
  for item in xs:
    acc = acc | item
  return acc

@jitclass([('i_bottom', int32[:]),
           ('bottom', int64),
           ('cols', int64[:]),
           ('d', int32[:]),
           ('weights', float32[:,::1])])
class Connect4Evaluator(object):
  """Heuristic evaluator for `Connect4Board` positions.

  Attributes:
    i_bottom: bit index of the bottom cell of each column (0, 7, ..., 42).
    bottom: bitmask with the bottom cell of every column set.
    cols: one 7-bit mask per column.
    d: the four shift distances 7, 6, 8, 1 -- the same shifts
      `Connect4Board.is_win` uses for horizontal, the two diagonals and vertical.
    weights: row 0 weights the counts of the player to move, row 1 the counts
      of the other player (see `evaluate`).
  """
  def __init__(self):
    self.i_bottom = np.int32(7) * np.arange(7, dtype=np.int32)
    self.bottom = np.sum(1 << self.i_bottom)
    self.cols = ((np.int64(1) << 7) - 1) << self.i_bottom
    self.d = np.array([7,6,8,1], dtype=np.int32)
    self.weights = np.array([[0.001,0.009,0.09,0.9],
                             [0.001,0.009,0.09,0.4]], dtype=np.float32)

  def evaluate(self, board: Connect4Board, color=1):
    """Score `board` from the point of view of the player to move.

    `color` is accepted but not used by this implementation.
    """
    # if last player won then current player lost
    if board.is_win():
      return (board.n_move + 1) // 2 - 22.0
    # game drawn
    elif board.is_draw():
      return 0.0
    # intermediate state
    else:
      # Sentinel bits are treated as occupied so patterns cannot extend into them.
      mask = board.mask | board.top
      cur_n = self._evaluate(board.position, mask)
      opp_n = self._evaluate(board.position ^ board.mask, mask)
      # NOTE(review): max_score is assigned but never read below.
      max_score = 21.0 - board.n_move // 2
      cur_value = np.dot(cur_n, self.weights[0])
      # A weighted count >= 1.0 for the player to move is scored as a win.
      if cur_value >= 1.0:
        return 21.0 - board.n_move // 2
      opp_value = np.dot(opp_n, self.weights[1])
      # Likewise a weighted count >= 1.0 for the other player is scored as a loss.
      if opp_value >= 1.0:
        return (board.n_move+1) // 2 - 21.0
      return (21.0 - (board.n_move+2) // 2) * (cur_value - opp_value)
  
  def _evaluate(self, pos, mask):
    """Return a float32 vector of 4 column counts for the stones in `pos`.

    Entry k is the number of columns in which the k-th pattern mask
    (c1, c2, c3, c3t below) has at least one bit set. All shifts are applied
    element-wise with the four distances in `self.d`, so each intermediate is
    an array of four bitmasks, one per direction.
    """
    n_mask = ~mask
    # presumably marks the next playable cell of each column -- TODO confirm
    nxt = mask + self.bottom

    # lij : at least i consecutive elements starting within j places on left
    l11 = pos >> self.d
    l12 = l11 | (l11 & n_mask) >> self.d
    l13 = l12 | (l12 & n_mask) >> self.d

    l21 = l11 & (l11 >> self.d)
    l22 = l21 | (l21 & n_mask) >> self.d

    l31 = l21 & (l21 >> self.d)

    # rij : at least i consecutive elements starting within j places on right
    r11 = pos << self.d
    r12 = r11 | (r11 & n_mask) << self.d
    r13 = r12 | (r12 & n_mask) << self.d

    r21 = r11 & (r11 << self.d)
    r22 = r21 | (r21 & n_mask) << self.d

    r31 = r21 & (r21 << self.d)

    # ci : at least i elements within a 4 element frame that contains it
    # (restricted to unoccupied cells and OR-ed over the four directions)
    c1 = bitwise_or_reduce((l13 | r13) & n_mask)
    c2 = bitwise_or_reduce((l22 | r22 | (l11 & r12) | (l12 & r11)) & n_mask)
    c3 = bitwise_or_reduce((l31 | r31 | (l21 & r11) | (l11 & r21)) & n_mask)
    # c3t : adding one more here will make it win
    c3t = c3 & nxt

    # (4, 1) column vector so it broadcasts against the 7 column masks
    c = np.array([[c1],[c2],[c3],[c3t]])

    # count, per pattern, the columns that contain at least one marked cell
    n = np.sum((c & self.cols) != 0, axis=1, dtype=np.float32)

    return n

## Searcher

In [6]:
class TranspositionTableEntry(object):
  """Cached search result for one node: its value, bound flag and search depth."""
  def __init__(self, value, flag, depth):
    self.depth = depth
    self.flag = flag
    self.value = value

class NegamaxSearcher(object):
  """Depth-limited negamax search with alpha-beta pruning and a transposition table.

  The node type must provide `generate_moves()`, `make_move(move)`,
  `is_terminal()` and `hash()`; the evaluator must provide
  `evaluate(node, color)` returning the value for the player to move at `node`.
  """
  # Transposition-table flags describing how a stored value bounds the true value.
  EXACT = 0
  LOWERBOUND = -1
  UPPERBOUND = 1

  def __init__(self, evaluator):
    # Vectorized so an array of child nodes can be scored with a single call.
    self.evaluate = np.vectorize(evaluator.evaluate)

  def __call__(self, node, depth) -> int:
    """Return the move chosen for the player to move at `node`.

    With `depth <= 0` the first generated move is returned without any
    evaluation; with `depth == 1` children are ranked by static evaluation only.
    """
    # every search starts with an empty transposition table
    self.t_table = {}

    moves = np.array(node.generate_moves())
    if depth <= 0:
      return moves[0]

    children = np.array([node.make_move(m) for m in moves])
    child_scores = self.evaluate(children, -1)
    if depth == 1:
      # child scores are from the opponent's point of view, hence the negation
      return moves[np.argmax(-child_scores)]

    # Search the most promising children (lowest opponent score) first.
    ranking = child_scores.argsort()

    best_score, best_move = -np.inf, None
    alpha, beta = -np.inf, np.inf
    for move, child in zip(moves[ranking], children[ranking]):
      score = -self._negamax(child, depth-1, -beta, -alpha, -1)
      if score > best_score:
        best_score, best_move = score, move
      alpha = max(alpha, best_score)
      if alpha >= beta:
        break

    return best_move

  def _negamax(self, node, depth, a, b, color) -> float:
    """Return the negamax value of `node` within the window (a, b)."""
    alpha_at_entry = a
    key = node.hash()

    # Transposition table lookup: reuse a result searched at least this deep.
    cached = self.t_table.get(key)
    if cached is not None and cached.depth >= depth:
      if cached.flag == self.EXACT:
        return cached.value
      if cached.flag == self.LOWERBOUND:
        a = max(a, cached.value)
      elif cached.flag == self.UPPERBOUND:
        b = min(b, cached.value)

      if a >= b:
        return cached.value

    if depth < 1 or node.is_terminal():
      return self.evaluate(node, color)

    children = np.array([node.make_move(m) for m in node.generate_moves()])
    child_scores = self.evaluate(children, -color)
    if depth == 1:
      # one ply left: the static scores of the children decide
      return np.max(-child_scores)

    best = -np.inf
    for child in children[child_scores.argsort()]:
      best = max(best, -self._negamax(child, depth-1, -b, -a, -color))
      a = max(a, best)
      if a >= b:
        break

    # Transposition table store: classify the result relative to the window.
    if best <= alpha_at_entry:
      flag = self.UPPERBOUND
    elif best >= b:
      flag = self.LOWERBOUND
    else:
      flag = self.EXACT

    self.t_table[key] = TranspositionTableEntry(best, flag, depth)

    return best

# Custom Drivers

## Play Against Negamax

In [7]:
class NegamaxEpisodeDriver(object):
  """Plays episodes of `policy` against a `NegamaxSearcher` opponent.

  Every transition (from either side) is converted to a trajectory and passed
  to each observer. With `mirroring=True` the finished game is replayed with
  every column reflected (`6 - action`) and reported to the observers again.
  """
  def __init__(self, env, policy, observers=(), num_episodes=10, depth=2, mirroring=False):
    self.env = env
    self.policy = policy
    self.observers = observers
    self.num_episodes = num_episodes
    self.depth = depth
    self.mirror = mirroring
    # Shadow board kept in sync with the environment for the negamax search.
    self.board = Connect4Board()
    self.negamax = NegamaxSearcher(Connect4Evaluator())

  def _notify(self, time_step, action_step, next_time_step):
    """Build the trajectory for one transition and hand it to every observer."""
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    for observer in self.observers:
      observer(traj)

  def _replay_mirrored(self, actions):
    """Replay `actions` left-right reflected in a fresh episode."""
    time_step = self.env.reset()
    for tf_action in actions:
      mirrored = 6-tf_action
      mirrored_step = policy_step.PolicyStep(mirrored)
      next_time_step = self.env.step(mirrored)
      self._notify(time_step, mirrored_step, next_time_step)
      time_step = next_time_step

  def run(self, first):
    """Play `num_episodes` games; the policy moves first iff `first == 0`."""
    for _ in range(self.num_episodes):
      policy_to_move = (first == 0)
      self.board._reset()
      time_step = self.env.reset()
      played = []
      while not time_step.is_last():
        if policy_to_move:
          action_step = self.policy.action(time_step)
          tf_action = action_step.action
          py_action = tf_action.numpy()[0]
        else:
          py_action = self.negamax(self.board, self.depth)
          tf_action = tf.constant([py_action])
          action_step = policy_step.PolicyStep(tf_action)
        self.board._make_move(py_action)
        next_time_step = self.env.step(tf_action)
        played.append(tf_action)

        # Report the transition (e.g. add it to the replay buffer).
        self._notify(time_step, action_step, next_time_step)

        time_step = next_time_step
        policy_to_move = not policy_to_move

      if self.mirror:
        self._replay_mirrored(played)

## Play Against Each Other

In [8]:
class NashEpisodeDriver(object):
  """Plays episodes in which the given policies take turns in round-robin order.

  Every transition is converted to a trajectory and passed to each observer.
  With `mirroring=True` the finished game is replayed with every column
  reflected (`6 - action`) and reported to the observers again.
  """
  def __init__(self, env, policies, observers=(), num_episodes=10, mirroring=False):
    self.env = env
    self.policies = policies
    self.n_policies = len(policies)
    self.observers = observers
    self.num_episodes = num_episodes
    self.mirror = mirroring

  def _notify(self, time_step, action_step, next_time_step):
    """Build the trajectory for one transition and hand it to every observer."""
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    for observer in self.observers:
      observer(traj)

  def _replay_mirrored(self, actions):
    """Replay `actions` left-right reflected in a fresh episode."""
    time_step = self.env.reset()
    for original in actions:
      mirrored = 6-original
      mirrored_step = policy_step.PolicyStep(mirrored)
      next_time_step = self.env.step(mirrored)
      self._notify(time_step, mirrored_step, next_time_step)
      time_step = next_time_step

  def run(self, first):
    """Play `num_episodes` games; `policies[first % n_policies]` moves first."""
    for _ in range(self.num_episodes):
      seat = first % self.n_policies
      time_step = self.env.reset()
      played = []
      while not time_step.is_last():
        action_step = self.policies[seat].action(time_step)
        next_time_step = self.env.step(action_step.action)
        played.append(action_step.action)

        # Report the transition (e.g. add it to the replay buffer).
        self._notify(time_step, action_step, next_time_step)

        time_step = next_time_step
        seat = (seat + 1) % self.n_policies

      if self.mirror:
        self._replay_mirrored(played)

In [9]:
class NegamaxOnlyEpisodeDriver(object):
  """Plays episodes in which both sides are driven by a `NegamaxSearcher`.

  Every transition is converted to a trajectory and passed to each observer.
  With `mirroring=True` the finished game is replayed with every column
  reflected (`6 - action`) and reported to the observers again.
  """
  def __init__(self, env, observers=(), num_episodes=10, mirroring=False):
    self.env = env
    self.observers = observers
    self.num_episodes = num_episodes
    self.mirror = mirroring
    # Shadow board kept in sync with the environment for the negamax search.
    self.board = Connect4Board()
    self.negamax = NegamaxSearcher(Connect4Evaluator())

  def _notify(self, time_step, action_step, next_time_step):
    """Build the trajectory for one transition and hand it to every observer."""
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    for observer in self.observers:
      observer(traj)

  def _replay_mirrored(self, actions):
    """Replay `actions` left-right reflected in a fresh episode."""
    time_step = self.env.reset()
    for tf_action in actions:
      mirrored = 6-tf_action
      mirrored_step = policy_step.PolicyStep(mirrored)
      next_time_step = self.env.step(mirrored)
      self._notify(time_step, mirrored_step, next_time_step)
      time_step = next_time_step

  def run(self, d0, d1):
    """Play `num_episodes` games; the first mover searches to depth `d0`, the second to `d1`."""
    depths = [d0, d1]
    for _ in range(self.num_episodes):
      side = 0
      self.board._reset()
      time_step = self.env.reset()
      played = []
      while not time_step.is_last():
        py_action = self.negamax(self.board, depths[side])
        tf_action = tf.constant([py_action])
        action_step = policy_step.PolicyStep(tf_action)
        self.board._make_move(py_action)
        next_time_step = self.env.step(tf_action)
        played.append(tf_action)

        # Report the transition (e.g. add it to the replay buffer).
        self._notify(time_step, action_step, next_time_step)

        time_step = next_time_step
        side = 1 - side

      if self.mirror:
        self._replay_mirrored(played)

# Custom Metric

In [10]:
class DiscountedReturnMetric(object):
  """Observer that tracks the mean discounted return of the most recent episodes.

  Call the instance with each trajectory of an episode. Rewards are
  accumulated as `sum(gamma**t * reward_t)`; when a trajectory reports
  `is_last()` the episode return is written into a circular buffer of
  `buffer_size` entries.

  Args:
    buffer_size: number of most recent episode returns to average over.
    gamma: per-step discount factor.
  """
  def __init__(self, buffer_size=10, gamma=1.0):
    self.buffer = np.zeros(buffer_size, dtype=np.float64)
    self.index = 0
    # Number of valid entries in `buffer`, so that unfilled slots are not
    # averaged in as zero returns.
    self.count = 0
    self.buffer_size = buffer_size
    self.gamma = gamma
    # Initialize the accumulators here so a trajectory stream that does not
    # start with a FIRST step cannot hit missing attributes.
    self.init_variables()

  def init_variables(self):
    """Reset the running return and discount factor for a new episode."""
    self.current_return = 0.0
    self.current_factor = 1.0

  def call(self, trajectory):
    """Accumulate one trajectory step.

    `trajectory.reward.numpy()` must hold exactly one value (batch size 1);
    `ndarray.item()` raises ValueError otherwise.
    """
    if trajectory.is_first():
      self.init_variables()
    reward = np.asarray(trajectory.reward.numpy(), dtype=np.float64).item()
    self.current_return += self.current_factor * reward
    self.current_factor *= self.gamma
    if trajectory.is_last():
      self.buffer[self.index] = self.current_return
      self.index = (self.index + 1) % self.buffer_size
      self.count = min(self.count + 1, self.buffer_size)

  def result(self):
    """Return the mean of the recorded episode returns (0.0 if none yet)."""
    if self.count == 0:
      return 0.0
    # Slots are filled in order from index 0, so the first `count` are valid.
    return np.mean(self.buffer[:self.count])

  def __call__(self, trajectory):
    self.call(trajectory)

# DQN Agent

## Hyperparameters

In [11]:
# Number of gradient updates performed by the training loops below.
num_iterations = 1000000

# Episodes played per collect-driver run, train steps between collection
# runs, and the replay buffer's max_length.
num_collect_episodes = 2
collect_interval = 50
replay_buffer_capacity = 40000

# Q-network layout; presumably (filters, kernel_size, stride) per conv layer
# and units per dense layer -- confirm against the QNetwork documentation.
conv_layer_params = ((128,(4,4),1),)
fc_layer_params = (64,64,)

# NOTE(review): the discount is negative. Consecutive time steps belong to
# opposing players, so this presumably flips the sign of the bootstrapped
# next-state value (negamax style) -- confirm this is intended.
gamma = -0.9

batch_size = 64
learning_rate = 1e-3
# Train steps between loss log entries.
log_interval = 100

# Episodes per evaluation run and train steps between evaluations.
num_eval_episodes = 5
eval_interval = 500

# Train steps between checkpoint saves, between switching collect drivers
# (train) and between clearing the cell output (train_v2 / train_one_iteration).
checkpoint_interval = 100000
switch_interval = 1000
clear_interval = 1000

## Environment

In [12]:
# Separate environment instances for data collection and for evaluation.
train_py_env = Connect4Env()
eval_py_env = Connect4Env()

In [13]:
# Wrap the Python environments so they can be driven with TensorFlow tensors.
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

## Agent

In [14]:
# Q-network mapping board observations to one Q-value per column.
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    conv_layer_params=conv_layer_params,
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

# Shared step counter; also stored in and restored from the checkpoint below.
global_step = tf.compat.v1.train.get_or_create_global_step()

# DQN agent trained with a squared TD-error loss and the discount `gamma`.
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=global_step)
agent.initialize()

## Data Collection

In [15]:
# Uniform replay buffer sized to the training environment's batch dimension.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

In [16]:
# Collects games of the agent's collect policy against a depth-1 negamax
# opponent and stores every transition in the replay buffer.
negamax_collect_driver = NegamaxEpisodeDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_episodes=num_collect_episodes,
    depth=1)

In [17]:
# Uniformly random TF policy (used by the evaluation / test drivers below).
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

# With depth=0 the NegamaxSearcher returns the first of its randomly ordered
# valid moves, so this driver pits the collect policy against a random
# opponent that only plays legal moves.
random_collect_driver = NegamaxEpisodeDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_episodes=num_collect_episodes,
    depth=0)

In [18]:
# Self-play: the agent's greedy policy and its collect policy alternate moves.
selfplay_collect_driver = NashEpisodeDriver(
    train_env,
    [agent.policy, agent.collect_policy],
    observers=[replay_buffer.add_batch],
    num_episodes=num_collect_episodes)

In [19]:
# Initial data collection: seed the replay buffer with games against the
# random opponent, first with the opponent moving first (first=1), then with
# the collect policy moving first (first=0).
for first in (1, 0):
  random_collect_driver.run(first=first)

In [20]:
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
# num_steps=2 makes every sample a pair of adjacent steps.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

# Shared iterator consumed by all training loops below.
iterator = iter(dataset)

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


## Evaluation Metrics

In [21]:
# NOTE(review): the metric buffer holds a single episode (buffer_size=1), while
# random_eval_driver plays num_eval_episodes games per run, so result() after
# that driver reflects only the last game, not an average over all of them.
# Both drivers also share this one metric instance.
eval_metrics = [DiscountedReturnMetric(buffer_size=1, gamma=gamma)]
# Greedy agent policy against a depth-1 negamax opponent.
negamax_eval_driver = NegamaxEpisodeDriver(
    eval_env,
    agent.policy,
    observers=eval_metrics,
    num_episodes=1,
    depth=1)

# Greedy agent policy against the uniformly random policy.
random_eval_driver = NashEpisodeDriver(
    eval_env,
    [agent.policy, random_policy],
    observers=eval_metrics,
    num_episodes=num_eval_episodes)

## Checkpoint

In [22]:
# Colab-only: mount Google Drive so checkpoints persist across sessions.
# NOTE(review): this import lives outside the notebook's import cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
# NOTE(review): hard-coded, user-specific Drive path; adjust before re-running.
checkpoint_dir = os.path.join('/content/drive/MyDrive/CSE BUET/Level-4, Term-2/CSE 472/Project', 'connect4_dqn_checkpoint2')
# Checkpoints the agent, its greedy policy, the replay buffer and the global
# step, keeping only the most recent checkpoint.
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

In [24]:
# Restore the latest checkpoint from checkpoint_dir if one exists, then
# re-read the global step.
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

In [25]:
 # Display the agent's current train step counter.
 agent.train_step_counter.numpy()

0

## Train

In [26]:
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
# NOTE(review): re-running this cell wraps the already wrapped function again.
agent.train = common.function(agent.train)

def train(collect_drivers):
  """Train the agent for `num_iterations` steps, rotating through `collect_drivers`.

  Every `switch_interval` steps the next driver in `collect_drivers` becomes
  the active collector; every `collect_interval` steps the active driver is
  run once with `first=0` and once with `first=1`. Losses are recorded every
  `log_interval` steps, evaluation against the random and negamax opponents
  runs every `eval_interval` steps, and a checkpoint is saved every
  `checkpoint_interval` steps.

  Args:
    collect_drivers: non-empty sequence of drivers exposing `run(first=...)`.

  Returns:
    `(losses, (random_returns, negamax_returns))`, where `losses` is a list of
    `(step, loss)` and each returns list holds
    `(step, return_when_first_is_0, return_when_first_is_1)`.

  Raises:
    ValueError: if `collect_drivers` is empty.
  """
  if len(collect_drivers) == 0:
    raise ValueError('`collect_drivers` must contain at least one driver.')

  turn = 0
  n_drivers = len(collect_drivers)
  losses = []
  random_returns = []
  negamax_returns = []
  step = agent.train_step_counter.numpy()

  # Bind a driver before the loop: when training resumes from a checkpoint
  # whose step is not a multiple of switch_interval, a collection can be due
  # before the first switch and the driver would otherwise be unbound.
  collect_driver = collect_drivers[turn]

  for _ in tqdm(range(num_iterations)):

    if step % switch_interval == 0:
      collect_driver = collect_drivers[turn]
      turn = (turn + 1) % n_drivers

    # Collect a few steps using collect_policy and save to the replay buffer.
    if step % collect_interval == 0:
      collect_driver.run(first=0)
      collect_driver.run(first=1)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step += 1

    if step % log_interval == 0:
      losses.append((step, train_loss))

    if step % eval_interval == 0:
      random_eval_driver.run(first=0)
      avg_return0 = eval_metrics[0].result()
      random_eval_driver.run(first=1)
      avg_return1 = eval_metrics[0].result()
      random_returns.append((step, avg_return0, avg_return1))

      negamax_eval_driver.run(first=0)
      avg_return0 = eval_metrics[0].result()
      negamax_eval_driver.run(first=1)
      avg_return1 = eval_metrics[0].result()
      negamax_returns.append((step, avg_return0, avg_return1))

    if step % checkpoint_interval == 0:
      train_checkpointer.save(global_step)

  return losses, (random_returns, negamax_returns)

In [None]:
# Train while alternating between self-play and random-opponent collection.
losses, returns = train([selfplay_collect_driver, random_collect_driver])

  0%|          | 0/1000000 [00:00<?, ?it/s]

Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
100%|██████████| 1000000/1000000 [6:00:33<00:00, 46.23it/s]


In [None]:
def train_v2(num_iterations, collect_driver, eval_driver, eval_metrics):
  """Train the agent for `num_iterations` steps with a single collect driver.

  Progress is printed rather than returned: the loss every `log_interval`
  steps and the evaluation return every `eval_interval` steps. The cell
  output is cleared every `clear_interval` steps and a checkpoint is saved
  every `checkpoint_interval` steps.

  Args:
    num_iterations: number of train steps to run (shadows the global).
    collect_driver: driver exposing a no-argument `run()`, called every
      `collect_interval` steps.
    eval_driver: driver exposing a no-argument `run()` used for evaluation.
    eval_metrics: sequence whose first element exposes `result()`
      (shadows the global).
  """
  step = agent.train_step_counter.numpy()

  for _ in tqdm(range(num_iterations)):

    # Collect a few steps using collect_policy and save to the replay buffer.
    if step % collect_interval == 0:
      collect_driver.run()

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step += 1

    if step % clear_interval == 0:
      clear_output()

    if step % log_interval == 0:
      print('loss', train_loss.numpy())

    if step % eval_interval == 0:
      eval_driver.run()
      avg_return = eval_metrics[0].result()
      print('return', avg_return)

    if step % checkpoint_interval == 0:
      train_checkpointer.save(global_step)

In [27]:
# Stock TF-Agents episode driver: the collect policy chooses every move of the
# episode, and all transitions go to the replay buffer.
collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_episodes=num_collect_episodes)

In [28]:
# Negamax-vs-negamax games written to the replay buffer.
negamax_only_driver = NegamaxOnlyEpisodeDriver(
    train_env,
    observers=[replay_buffer.add_batch],
    num_episodes=1)

In [29]:
# NOTE(review): this rebinds `eval_metrics` to a new list. Drivers created in
# earlier cells still hold the previous metric object, so `eval_metrics[0]`
# no longer reflects what they record.
eval_metrics = [DiscountedReturnMetric(buffer_size=1, gamma=gamma)]
# Evaluation driver in which the greedy policy chooses every move.
eval_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    eval_env,
    agent.policy,
    observers=eval_metrics,
    num_episodes=1)

In [None]:
# Train with the stock episode driver for collection and evaluation.
train_v2(num_iterations, collect_driver, eval_driver, eval_metrics)

loss 25.028185
retrun -100.0


  7%|▋         | 74053/1000000 [17:24<4:27:40, 57.65it/s]

In [31]:
def train_one_iteration():
  """Run one gradient update and the periodic clear / log / eval / checkpoint hooks.

  The hooks are keyed on the agent's train step counter read after the update.
  """
  # Sample a batch of data from the buffer and update the agent's network.
  batch, _ = next(iterator)
  train_loss = agent.train(batch).loss

  step = agent.train_step_counter.numpy()

  def due(interval):
    # True when the current step is a multiple of `interval`.
    return step % interval == 0

  if due(clear_interval):
    clear_output()

  if due(log_interval):
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if due(eval_interval):
    eval_driver.run()
    print('step = {0}: Average Return = {1}'.format(step, eval_metrics[0].result()))

  if due(checkpoint_interval):
    train_checkpointer.save(global_step)

def train_v3(num_rounds=500, train_steps_per_collect=50, selfplay_collects=8):
  """Interleave collection from three sources with training.

  Each round, for each value of `first` in (0, 1): collect once against the
  depth-1 negamax opponent, once against the random opponent, and
  `selfplay_collects` times with the stock collect driver. Every collection
  is followed by `train_steps_per_collect` calls to `train_one_iteration`.

  Args:
    num_rounds: number of outer rounds.
    train_steps_per_collect: gradient updates after each collection.
    selfplay_collects: collect-driver runs per value of `first`.
  """
  def train_steps():
    for _ in range(train_steps_per_collect):
      train_one_iteration()

  for _ in tqdm(range(num_rounds)):
    for first in range(2):
      negamax_collect_driver.run(first)
      train_steps()

      random_collect_driver.run(first)
      train_steps()

      for _ in range(selfplay_collects):
        collect_driver.run()
        train_steps()

In [33]:
# Run the mixed-opponent training schedule.
train_v3()

step = 1000000: loss = 0.032823145389556885
step = 1000000: Average Return = 0.07976644486188889


100%|██████████| 500/500 [1:40:46<00:00, 12.09s/it]


In [34]:
# Persist the final agent, policy and replay buffer state.
train_checkpointer.save(global_step)

In [66]:
# Evaluate the greedy policy against depth-1 negamax, once moving first
# (first=0) and once moving second (first=1).
# NOTE(review): the execution count of this cell is out of sequence with its
# neighbours.
negamax_test_driver = NegamaxEpisodeDriver(
    eval_env,
    agent.policy,
    observers=eval_metrics,
    num_episodes=1,
    depth=1)
negamax_test_driver.run(first=0)
avg_return0 = eval_metrics[0].result()
negamax_test_driver.run(first=1)
avg_return1 = eval_metrics[0].result()
print(avg_return0, avg_return1)

-0.3874204754829407 0.5314409732818604


In [None]:
# Depth-1 negamax against itself; rebinds `negamax_test_driver` to a different
# driver type than the cell above.
negamax_test_driver = NegamaxOnlyEpisodeDriver(
    eval_env,
    observers=eval_metrics,
    num_episodes=1)
negamax_test_driver.run(1,1)
avg_return = eval_metrics[0].result()
print(avg_return)

1.0


# Test

In [36]:
def print_board(a):
  """Print a 2-D board array with its last row on top.

  Cells are rendered as '-' (0), 'O' (1) and 'X' (-1), each followed by a
  space; a blank line is printed after the board.
  """
  symbols = {0:'-', 1:'O', -1:'X'}
  for row in np.flipud(a):
    print(''.join(symbols[cell] + ' ' for cell in row))
  print()

In [63]:
# Fresh board, searcher and search depth read as globals by `run` below.
board = Connect4Board()
negamax = NegamaxSearcher(Connect4Evaluator())
depth = 4

# NOTE(review): this redefines print_board, which an earlier cell already
# defines with the same behavior.
def print_board(a):
  """Print a 2-D board array with its last row on top.

  Cells are rendered as '-' (0), 'O' (1) and 'X' (-1), each followed by a
  space; a blank line is printed after the board.
  """
  symbols = {0:'-', 1:'O', -1:'X'}
  for row in np.flipud(a):
    print(''.join(symbols[cell] + ' ' for cell in row))
  print()

def run(board, first):
  """Play one printed game between the greedy DQN policy and negamax.

  Negamax moves on the turns where `turn == first` (so `first=0`/`False`
  means negamax opens); the DQN policy moves on the other turns. The board
  is printed after every move and the outcome at the end.

  Args:
    board: a `Connect4Board` that is mutated in place; it should be empty,
      because the environment is reset at the start.
    first: 0 or 1 (or a bool), the turn index on which negamax plays.
  """
  env = eval_env
  policy = agent.policy

  print_board(board.to_array())

  turn = 0
  time_step = env.reset()
  while not time_step.is_last():
    if turn != first:
      tf_action = policy.action(time_step).action
      py_action = tf_action.numpy()[0]
      print('dqn', end=' ')
    else:
      py_action = negamax(board, depth)
      tf_action = tf.constant([py_action])
      print('negamax', end=' ')
    if not board.is_valid_move(py_action):
      # An illegal move forfeits the game to the other player.
      print('player', 2-turn, 'won')
      return
    board._make_move(py_action)
    next_time_step = env.step(tf_action)

    print('move:', py_action)
    print_board(board.to_array())
    turn = 1 - turn
    time_step = next_time_step

  # `turn` was flipped after the final move, so it names the side that did
  # not make the last move.
  if board.is_win():
    if turn == first:
      print('dqn won')
    else:
      print('negamax won')
  else:
    print('drawn')

# first=False compares equal to turn 0, so negamax makes the opening move.
run(board, False)

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 

negamax move: 3
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - O - - - 

dqn move: 1
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - O - - - 

negamax move: 4
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - O O - - 

dqn move: 1
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - - - - - 
- X - O O - - 

negamax move: 5
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - - - - - 
- X - O O O - 

dqn move: 0
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - - - - - 
X X - O O O - 

negamax move: 2
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- X - - - - - 
X X O O O O - 

negamax won


In [38]:
def print_trajectory(trajectory):
  """Observer that prints a trajectory's board (last row on top), action and reward."""
  grid = trajectory.observation.numpy().reshape((6,7))
  print(np.flipud(grid))
  for field in (trajectory.action, trajectory.reward):
    print(field.numpy())

# Drivers that print every transition instead of storing it.

# Greedy policy against depth-1 negamax.
negamax_test_driver = NegamaxEpisodeDriver(
    eval_env,
    agent.policy,
    observers=[print_trajectory],
    num_episodes=1,
    depth=1)

# Greedy policy against the uniformly random policy.
random_test_driver = NashEpisodeDriver(
    eval_env,
    [agent.policy, random_policy],
    observers=[print_trajectory],
    num_episodes=1)

# negamax_collect_driver.policy is the agent's collect policy, so this pits
# the collect policy against the greedy policy.
agent_test_driver = NashEpisodeDriver(
    eval_env,
    [negamax_collect_driver.policy,agent.policy],
    observers=[print_trajectory],
    num_episodes=1)

In [39]:
# Print one game of the greedy policy against depth-1 negamax.
negamax_test_driver.run(0) # agent is 1st player
# random_test_driver.run(first=1)  # agent is 2nd player

[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
[0]
[0.]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0.]]
[3]
[0.]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0. -1.  0.  0.  0.]]
[0]
[0.]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0. -1.  0.  0.  0.]]
[4]
[0.]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0. -1. -1.  0.  0.]]
[0]
[0.]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 1.  0.  