<a href="https://colab.research.google.com/github/lunathanael/chessnn/blob/main/policy_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [130]:
import collections
import math
import typing
from typing import Dict, List, Optional
import numpy
import tensorflow as tf
from typing import List
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, Dense

In [131]:
MAXIMUM_FLOAT_VALUE = float('inf')

KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])

In [132]:
class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Optional[KnownBounds]):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value

In [133]:
class Config(object):


  def __init__(self,
               action_space_size: int,
               max_moves: int,
               discount: float,
               dirichlet_alpha: float,
               num_simulations: int,
               batch_size: int,
               td_steps: int,
               num_actors: int,
               lr_init: float,
               lr_decay_steps: float,
               visit_softmax_temperature_fn,
               known_bounds: Optional[KnownBounds] = None):
    ### Self-Play
    self.action_space_size = action_space_size
    self.num_actors = num_actors

    self.visit_softmax_temperature_fn = visit_softmax_temperature_fn
    self.max_moves = max_moves
    self.num_simulations = num_simulations
    self.discount = discount

    # Root prior exploration noise.
    self.root_dirichlet_alpha = dirichlet_alpha
    self.root_exploration_fraction = 0.25

    # UCB formula
    self.pb_c_base = 19652
    self.pb_c_init = 1.25

    # If we already have some information about which values occur in the
    # environment, we can use them to initialize the rescaling.
    # This is not strictly necessary, but establishes identical behaviour to
    # AlphaZero in board games.
    self.known_bounds = known_bounds

    ### Training
    self.training_steps = int(1000e3)
    self.checkpoint_interval = int(1e3)
    self.window_size = int(1e6)
    self.batch_size = batch_size
    self.num_unroll_steps = 5
    self.td_steps = td_steps

    self.weight_decay = 1e-4
    self.momentum = 0.9

    # Exponential learning rate schedule
    self.lr_init = lr_init
    self.lr_decay_rate = 0.1
    self.lr_decay_steps = lr_decay_steps

  def new_game(self):
    return Game(self.action_space_size, self.discount)

In [134]:
class ChessNetwork(tf.keras.Model):
    def __init__(self, config: Config, num_res_blocks=16, num_filters=256):
        super(ChessNetwork, self).__init__()
        self.action_size = config.action_space_size
        self.num_res_blocks = num_res_blocks
        self.num_filters = num_filters

        # Representation Function
        self.representation_conv = tf.keras.layers.Conv2D(num_filters, 3, padding='same', activation='relu')
        self.representation_res_blocks = [self._build_residual_block() for _ in range(num_res_blocks)]

        # Dynamics Function
        self.dynamics_action_conv = tf.keras.layers.Conv2D(num_filters, 3, padding='same', activation='relu')
        self.dynamics_res_blocks = [self._build_residual_block() for _ in range(num_res_blocks)]
        self.dynamics_reward = tf.keras.layers.Dense(1, activation='tanh')

        # Prediction Function
        self.prediction_policy = tf.keras.layers.Dense(self.action_size)  # policy logits
        self.prediction_value = tf.keras.layers.Dense(1, activation='tanh')  # value

    def _build_residual_block(self):
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(self.num_filters, 3, padding='same', activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(self.num_filters, 3, padding='same', activation='relu'),
            tf.keras.layers.BatchNormalization()
        ])

    def call(self, inputs, action=None):
      # Process inputs through the representation function
      if action is None:
        hidden_state = self.representation_conv(inputs)
        for block in self.representation_res_blocks:
            hidden_state = block(hidden_state)
        reward = tf.constant(0.0)
      else:
        hidden_state = tf.concat([inputs, action], axis=-1)
        hidden_state = self.dynamics_action_conv(hidden_state)
        for block in self.dynamics_res_blocks:
            hidden_state = block(hidden_state)

      reward = self.dynamics_reward(hidden_state)
      policy_logits = self.prediction_policy(hidden_state)
      value = self.prediction_value(hidden_state)

      return NetworkOutput(value, reward, policy_logits, hidden_state)


class UniformChessNetwork(tf.keras.Model):
    def __init__(self, config: Config, num_res_blocks=16, num_filters=256):
        super(UniformChessNetwork, self).__init__()
        self.action_size = config.action_space_size
        self.num_res_blocks = num_res_blocks
        self.num_filters = num_filters

        # Representation Function
        self.representation_conv = tf.keras.layers.Conv2D(num_filters, 3, padding='same', activation='relu')
        self.representation_res_blocks = [self._build_residual_block() for _ in range(num_res_blocks)]

        # Dynamics Function
        self.dynamics_action_conv = tf.keras.layers.Conv2D(num_filters, 3, padding='same', activation='relu')
        self.dynamics_res_blocks = [self._build_residual_block() for _ in range(num_res_blocks)]

    def _build_residual_block(self):
        return tf.keras.Sequential([
            tf.keras.layers.Conv2D(self.num_filters, 3, padding='same', activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Conv2D(self.num_filters, 3, padding='same', activation='relu'),
            tf.keras.layers.BatchNormalization()
        ])

    def call(self, inputs):
        # Implement call method to process inputs through representation and dynamics functions
        # Process inputs through the representation function
        hidden_state = self.representation_conv(inputs)
        for block in self.representation_res_blocks:
            hidden_state = block(hidden_state)

        # Uniform policy logits and fixed values for reward and value
        uniform_policy_logits = tf.fill([self.action_size], tf.math.log(1.0 / self.action_size))
        value = tf.constant(0.0)
        reward = tf.constant(0.0)

        return NetworkOutput(value, reward, uniform_policy_logits, hidden_state)



In [135]:
def make_board_game_config(action_space_size: int, max_moves: int,
                           dirichlet_alpha: float,
                           lr_init: float) -> Config:

  def visit_softmax_temperature(num_moves, training_steps):
    if num_moves < 30:
      return 1.0
    else:
      return 0.0  # Play according to the max.

  return Config(
      action_space_size=action_space_size,
      max_moves=max_moves,
      discount=1.0,
      dirichlet_alpha=dirichlet_alpha,
      num_simulations=800,
      batch_size=2048,
      td_steps=max_moves,  # Always use Monte Carlo return.
      num_actors=3000,
      lr_init=lr_init,
      lr_decay_steps=400e3,
      visit_softmax_temperature_fn=visit_softmax_temperature,
      known_bounds=KnownBounds(-1, 1))

In [136]:
def make_chess_config() -> Config:
  return make_board_game_config(
      action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1)

In [137]:
class Action(object):

  def __init__(self, index: int):
    self.index = index

  def __hash__(self):
    return self.index

  def __eq__(self, other):
    return self.index == other.index

  def __gt__(self, other):
    return self.index > other.index

class Player(object):
  pass

Helpers

In [138]:
class Node(object):

  def __init__(self, prior: float):
    self.visit_count = 0
    self.to_play = -1
    self.prior = prior
    self.value_sum = 0
    self.children = {}
    self.hidden_state = None
    self.reward = 0

  def expanded(self) -> bool:
    return len(self.children) > 0

  def value(self) -> float:
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count

In [139]:
class ActionHistory(object):
  """Simple history container used inside the search.

  Only used to keep track of the actions executed.
  """

  def __init__(self, history: List[Action], action_space_size: int):
    self.history = list(history)
    self.action_space_size = action_space_size

  def clone(self):
    return ActionHistory(self.history, self.action_space_size)

  def add_action(self, action: Action):
    self.history.append(action)

  def last_action(self) -> Action:
    return self.history[-1]

  def action_space(self) -> List[Action]:
    return [Action(i) for i in range(self.action_space_size)]

  def to_play(self):
    return len(self.history) % 2

In [140]:
class Environment(object):
  """The environment that interacting with."""

  def step(self, action):
    pass

  def is_terminal(self):
    pass

  def get_legal_actions(self):
    pass

  def get_image(self, state_index: int):
    pass

In [141]:
class Game(object):
  """A single episode of interaction with the environment."""

  def __init__(self, action_space_size: int, discount: float):
    self.environment = Environment()  # Game specific environment.
    self.history = []
    self.rewards = []
    self.child_visits = []
    self.root_values = []
    self.action_space_size = action_space_size
    self.discount = discount

  def terminal(self) -> bool:
    # Game specific termination rules.
    return self.environment.is_terminal()

  def legal_actions(self) -> List[Action]:
    # Game specific calculation of legal actions.
    return self.environment.get_legal_actions()

  def apply(self, action: Action):
    reward = self.environment.step(action)
    self.rewards.append(reward)
    self.history.append(action)

  def store_search_statistics(self, root: Node):
    sum_visits = sum(child.visit_count for child in root.children.values())
    action_space = (Action(index) for index in range(self.action_space_size))
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in action_space
    ])
    self.root_values.append(root.value())

  def make_image(self, state_index: int):
    # Game specific feature planes.
    return enviroment.get_image(state_index)

  def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int,
                  to_play):
    # The value target is the discounted root value of the search tree N steps
    # into the future, plus the discounted sum of all rewards until then.
    targets = []
    for current_index in range(state_index, state_index + num_unroll_steps + 1):
      bootstrap_index = current_index + td_steps
      if bootstrap_index < len(self.root_values):
        value = self.root_values[bootstrap_index] * self.discount**td_steps
      else:
        value = 0

      for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
        value += reward * self.discount**i  # pytype: disable=unsupported-operands

      if current_index < len(self.root_values):
        targets.append((value, self.rewards[current_index],
                        self.child_visits[current_index]))
      else:
        # States past the end of games are treated as absorbing states.
        targets.append((0, 0, []))
    return targets

  def to_play(self):
    return len(self.history) % 2

  def action_history(self) -> ActionHistory:
    return ActionHistory(self.history, self.action_space_size)

In [142]:
class ReplayBuffer(object):

  def __init__(self, config: Config):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):
    if len(self.buffer) > self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)

  def sample_batch(self, num_unroll_steps: int, td_steps: int):
    games = [self.sample_game() for _ in range(self.batch_size)]
    game_pos = [(g, self.sample_position(g)) for g in games]
    return [(g.make_image(i), g.history[i:i + num_unroll_steps],
             g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
            for (g, i) in game_pos]

  def sample_game(self) -> Game:
    # Sample game from buffer either uniformly or according to some priority.
    return self.buffer[0]

  def sample_position(self, game) -> int:
    # Sample position from game either uniformly or according to some priority.
    return -1

In [143]:
class NetworkOutput(typing.NamedTuple):
  value: float
  reward: float
  policy_logits: Dict[Action, float]
  hidden_state: List[float]

In [144]:
class Network(object):

  def __init__(self, config: Config, uniform_network: bool=False):
    self.steps = 0
    if(uniform_network):
      self.model=UniformChessNetwork(config)
    else:
      self.model=ChessNetwork(config)

  def initial_inference(self, image) -> NetworkOutput:
    # representation + prediction function

    # Use the prediction function to obtain policy logits, value
    return self.model(image)


  def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
    # dynamics + prediction function
    # Apply the dynamics function to get the next hidden state and reward
    return self.model(hidden_state, action)


  def get_weights(self):
    # Returns the weights of this network.
      return self.model.get_weights()

  def training_steps(self) -> int:
    # How many steps / batches the network has been trained for.
    return self.steps

In [145]:
class SharedStorage(object):

  def __init__(self, config: Config):
    self._networks = {}
    self._config = config

  def latest_network(self) -> Network:
    if self._networks:
      return self._networks[max(self._networks.keys())]
    else:
      # policy -> uniform, value -> 0, reward -> 0
      return Network(config, True)

  def save_network(self, step: int, network: Network):
    self._networks[step] = network

In [146]:
#split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def zero(config: Config):
  storage = SharedStorage(config)
  replay_buffer = ReplayBuffer(config)

  for _ in range(config.num_actors):
    launch_job(run_selfplay, config, storage, replay_buffer)

  train_network(config, storage, replay_buffer)

  return storage.latest_network()

In [147]:
##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: Config, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  while True:
    network = storage.latest_network()
    game = play_game(config, network)
    replay_buffer.save_game(game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: Config, network: Network) -> Game:
  game = config.new_game()

  while not game.terminal() and len(game.history) < config.max_moves:
    # At the root of the search tree we use the representation function to
    # obtain a hidden state given the current observation.
    root = Node(0)
    current_observation = game.make_image(-1)
    expand_node(root, game.to_play(), game.legal_actions(),
                network.initial_inference(current_observation))
    add_exploration_noise(config, root)

    # We then run a Monte Carlo Tree Search using only action sequences and the
    # model learned by the network.
    run_mcts(config, root, game.action_history(), network)
    action = select_action(config, len(game.history), root, network)
    game.apply(action)
    game.store_search_statistics(root)
  return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: Config, root: Node, action_history: ActionHistory,
             network: Network):
  min_max_stats = MinMaxStats(config.known_bounds)

  for _ in range(config.num_simulations):
    history = action_history.clone()
    node = root
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node, min_max_stats)
      history.add_action(action)
      search_path.append(node)

    # Inside the search tree we use the dynamics function to obtain the next
    # hidden state given an action and the previous hidden state.
    parent = search_path[-2]
    network_output = network.recurrent_inference(parent.hidden_state,
                                                 history.last_action())
    expand_node(node, history.to_play(), history.action_space(), network_output)

    backpropagate(search_path, network_output.value, history.to_play(),
                  config.discount, min_max_stats)



def softmax_sample(visit_counts, temperature=10.0):
    counts, actions = zip(*visit_counts)
    # Apply softmax with temperature
    counts = np.array(counts)
    counts = counts / temperature  # Apply temperature scaling
    softmax_probs = np.exp(counts) / sum(np.exp(counts))
    # Sample an action based on the softmax probabilities
    action = np.random.choice(actions, p=softmax_probs)
    return softmax_probs, action


def select_action(config: Config, num_moves: int, node: Node,
                  network: Network):
  visit_counts = [
      (child.visit_count, action) for action, child in node.children.items()
  ]
  t = config.visit_softmax_temperature_fn(
      num_moves=num_moves, training_steps=network.training_steps())
  _, action = softmax_sample(visit_counts, t)
  return action


# Select the child with the highest UCB score.
def select_child(config: Config, node: Node,
                 min_max_stats: MinMaxStats):
  _, action, child = max(
      (ucb_score(config, node, child, min_max_stats), action,
       child) for action, child in node.children.items())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: Config, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_c * child.prior
  value_score = min_max_stats.normalize(child.value())
  return prior_score + value_score

# We expand a node using the value, reward and policy prediction obtained from
# the neural network.
def expand_node(node: Node, to_play, actions: List[Action],
                network_output: NetworkOutput):
  node.to_play = to_play
  node.hidden_state = network_output.hidden_state
  node.reward = network_output.reward
  policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
  policy_sum = sum(policy.values())
  for action, p in policy.items():
    node.children[action] = Node(p / policy_sum)


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play,
                  discount: float, min_max_stats: MinMaxStats):
  for node in search_path:
    node.value_sum += value if node.to_play == to_play else (1-value)
    node.visit_count += 1
    min_max_stats.update(node.value())

    value = node.reward + discount * value


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: Config, node: Node):
  actions = list(node.children.keys())
  noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########


def train_network(config: Config, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  network = Network(config)
  learning_rate = config.lr_init * config.lr_decay_rate**(
      tf.train.get_global_step() / config.lr_decay_steps)
  optimizer = tf.train.SGD(learning_rate, config.momentum)

  for i in range(config.training_steps):
    if i % config.checkpoint_interval == 0:
      storage.save_network(i, network)
    batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
    update_weights(optimizer, network, batch, config.weight_decay)
    network.steps += 1
  storage.save_network(config.training_steps, network)

def update_weights(optimizer: tf.keras.optimizers, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(
        image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(
          hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

      hidden_state = tf.scale_gradient(hidden_state, 0.5)

    for prediction, target in zip(predictions, targets):
      gradient_scale, value, reward, policy_logits = prediction
      target_value, target_reward, target_policy = target

      l = (
          scalar_loss(value, target_value) +
          scalar_loss(reward, target_reward) +
          tf.nn.softmax_cross_entropy_with_logits(
              logits=policy_logits, labels=target_policy))

      loss += tf.scale_gradient(l, gradient_scale)

  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)


######### End Training ###########
##################################

def scalar_loss(prediction, target) -> float:
    squared_error = (prediction - target) ** 2
    mse = squared_error.mean()
    return mse

def launch_job(f, *args):
  f(*args)




In [148]:
config = make_chess_config()
network_1 = zero(config)

ValueError: Exception encountered when calling layer 'uniform_chess_network_6' (type UniformChessNetwork).

Layer "conv2d_396" expects 1 input(s), but it received 0 input tensors. Inputs received: []

Call arguments received by layer 'uniform_chess_network_6' (type UniformChessNetwork):
  • inputs=[]