<a href="https://colab.research.google.com/github/lunathanael/chessnn/blob/main/zero_policy_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [328]:
!pip install chess
import chess
import math
import numpy as np
import tensorflow as tf
from typing import List
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, Dense, Flatten
from tensorflow.keras.models import Model



In [329]:
def residual_block(x, filters):
    """Build a two-convolution residual block on top of ``x``.

    The input passes through Conv2D -> BatchNormalization -> ReLU ->
    Conv2D -> BatchNormalization, is added back to ``x`` (skip connection)
    and is finally passed through a ReLU.

    Args:
        x: Input tensor; it is added to the convolution output, so its
            channel count has to be compatible with ``filters``.
        filters: Number of filters for both 3x3 convolutions.

    Returns:
        The output tensor of the block.
    """
    shortcut = x

    # First 3x3 convolution stage.
    out = Conv2D(filters, kernel_size=3, padding='same')(x)
    out = BatchNormalization()(out)
    out = ReLU()(out)

    # Second 3x3 convolution stage; its activation is applied only after
    # the skip connection has been added.
    out = Conv2D(filters, kernel_size=3, padding='same')(out)
    out = BatchNormalization()(out)

    merged = Add()([out, shortcut])
    return ReLU()(merged)

def make_network():
    """Build the residual policy/value network.

    The model takes an ``(8, 8, 20)`` input, applies a 256-filter 3x3
    convolution with batch normalization, runs 19 residual blocks and then
    splits into a policy head and a value head.

    Returns:
        A ``tf.keras.Model`` whose outputs are ``[value, policy]``: the
        value head ends in a single ``tanh`` unit, the policy head ends in
        a 1x1 convolution with 73 output channels.
    """
    # 8x8 grid with 20 features per cell.
    board_input = Input(shape=(8, 8, 20))

    # Shared trunk: stem convolution followed by the residual tower.
    trunk = Conv2D(256, kernel_size=3, padding='same', activation='relu')(board_input)
    trunk = BatchNormalization()(trunk)
    num_residual_blocks = 19
    for _ in range(num_residual_blocks):
        # Each block holds two 3x3 convolutions with 256 filters.
        trunk = residual_block(trunk, 256)

    # Policy head: 3x3 convolution, batch norm, then 73 output planes.
    policy = Conv2D(256, kernel_size=3, padding='same', activation='relu')(trunk)
    policy = BatchNormalization()(policy)
    policy = Conv2D(73, kernel_size=1)(policy)

    # Value head: squeeze to one plane, flatten, then two dense layers.
    value = Conv2D(1, kernel_size=1, activation='relu')(trunk)
    value = BatchNormalization()(value)
    value = Flatten()(value)
    value = Dense(256, activation='relu')(value)
    value = Dense(1, activation='tanh')(value)

    return tf.keras.Model(inputs=board_input, outputs=[value, policy])


def ConstantPolicyHead(shape):
    """Return a Keras layer that ignores its input and emits a constant.

    Args:
        shape: Shape of the constant tensor produced by the layer.

    Returns:
        A layer instance whose ``call`` returns a tensor of the given
        ``shape`` filled with ``1/73``.
    """
    class CPHead(tf.keras.layers.Layer):
        def call(self, inputs):
            # The input is deliberately unused; every entry is 1/73.
            fill_value = 1/73
            return tf.constant(fill_value, shape=shape)

    head = CPHead()
    return head

def ConstantValueHead():
    """Return a Keras layer that ignores its input and emits ``0.5``.

    Returns:
        A layer instance whose ``call`` returns a constant tensor of shape
        ``(1,)`` holding ``0.5``.
    """
    class CVHead(tf.keras.layers.Layer):
        def call(self, inputs):
            # The input is deliberately unused.
            output_shape = (1,)
            return tf.constant(0.5, shape=output_shape)

    head = CVHead()
    return head

def make_uniform_network():
    """Build a network whose two heads produce constant outputs.

    The body mirrors the residual tower (stem convolution, batch norm,
    ReLU, 19 residual blocks), but the heads are replaced by constant
    layers: the policy head emits ``1/73`` everywhere with shape
    ``(8, 8, 73)`` and the value head emits ``0.5``.

    Returns:
        A Keras ``Model`` with outputs ``[value, policy]``.
    """
    board_input = Input(shape=(8, 8, 20))

    # Body: stem followed by the residual tower.
    trunk = Conv2D(256, kernel_size=3, padding='same')(board_input)
    trunk = BatchNormalization()(trunk)
    trunk = ReLU()(trunk)
    num_residual_blocks = 19
    for _ in range(num_residual_blocks):
        trunk = residual_block(trunk, 256)

    # Constant heads: uniform policy and fixed value.
    uniform_policy = ConstantPolicyHead((8, 8, 73))(trunk)
    constant_value = ConstantValueHead()(trunk)

    return Model(inputs=board_input, outputs=[constant_value, uniform_policy])


Helpers

In [330]:
class Config(object):
  """Hyper-parameters for self-play, tree search and training."""

  def __init__(self):
    # ---- Self-play ----
    self.num_actors = 5000

    self.num_sampling_moves = 30
    # 512 for chess and shogi, 722 for Go.
    self.max_moves = 512
    self.num_simulations = 800

    # Root prior exploration noise:
    # alpha is 0.3 for chess, 0.03 for Go and 0.15 for shogi.
    self.root_dirichlet_alpha = 0.3
    self.root_exploration_fraction = 0.25

    # Constants of the UCB formula.
    self.pb_c_base = 19652
    self.pb_c_init = 1.25

    # ---- Training ----
    self.training_steps = 700_000
    self.checkpoint_interval = 1_000
    self.window_size = 1_000_000
    self.batch_size = 4096

    self.weight_decay = 1e-4
    self.momentum = 0.9
    # Step -> learning rate. Schedule for chess and shogi; Go starts at
    # 2e-2 immediately.
    schedule = {}
    schedule[0] = 2e-1
    schedule[100e3] = 2e-2
    schedule[300e3] = 2e-3
    schedule[500e3] = 2e-4
    self.learning_rate_schedule = schedule

In [331]:
class Node(object):
  """A search-tree node holding visit statistics and child nodes."""

  def __init__(self, prior: float):
    self.prior = prior
    self.to_play = -1  # -1 until the node is assigned a player
    self.visit_count = 0
    self.value_sum = 0
    self.children = {}  # action -> Node

  def expanded(self):
    """Return True once the node has at least one child."""
    return len(self.children) != 0

  def value(self):
    """Return the mean of the accumulated values, or 0 if never visited."""
    visits = self.visit_count
    return self.value_sum / visits if visits != 0 else 0


In [332]:
def fen_to_repr(fen, repeats):
    """Encode a FEN string as an ``(8, 8, 20)`` float array.

    Plane layout (last axis):
        0-11   one-hot piece planes in the order ``p n b r q k P N B R Q K``
        12-15  castling rights ``K``, ``Q``, ``k``, ``q`` (constant planes)
        16     side to move (1 for black, 0 for white)
        17     ``repeats`` (constant plane)
        18     halfmove clock (constant plane)
        19     fullmove number (constant plane)

    Args:
        fen: FEN string with at least six space-separated fields.
        repeats: Repetition count written to plane 17.

    Returns:
        A numpy array of shape ``(8, 8, 20)``; row 0 is the first rank
        listed in the FEN board field.
    """
    fields = fen.split(' ')
    board_fen, player, castling, _, halfmove, fullmove = fields[:6]

    # 12 pieces, 4 castling flags, colour, repetitions, halfmove, fullmove.
    num_planes = 12 + 4 + 1 + 1 + 1 + 1
    board = np.zeros((8, 8, num_planes), dtype=float)

    # Piece symbol -> plane index.
    piece_map = {symbol: plane for plane, symbol in enumerate('pnbrqkPNBRQK')}

    # Place the pieces rank by rank; digits skip that many empty files.
    for row, rank in enumerate(board_fen.split('/')):
        col = 0
        for char in rank:
            if char.isdigit():
                col += int(char)
                continue
            board[row, col, piece_map[char]] = 1
            col += 1

    # Castling rights: white kingside, white queenside, black kingside,
    # black queenside.
    for offset, flag in enumerate('KQkq'):
        board[:, :, 12 + offset] = 1 if flag in castling else 0

    # Player colour (1 for black, 0 for white).
    board[:, :, 16] = 1 if player == 'b' else 0

    # Position repetitions.
    board[:, :, 17] = repeats

    # Halfmove clock and move number, stored as real values.
    board[:, :, 18] = float(halfmove)
    board[:, :, 19] = float(fullmove)

    return board

In [333]:
def square_to_coord(square):
  """Split a 0-63 square index into a ``(row, column)`` pair."""
  row, col = divmod(square, 8)
  return (row, col)

# Piece symbol -> plane index used by Environment.encode_action.
# Uppercase symbols take indices 0-5 and lowercase symbols 6-11.
# NOTE(review): the local piece_map inside fen_to_repr uses the opposite
# ordering (lowercase first) — confirm whether the two encodings are meant
# to differ.
piece_map = {
        'P' : 0, 'N' : 1, 'B' : 2, 'R' : 3, 'Q' : 4, 'K' : 5,
        'p' : 6, 'n' : 7, 'b' : 8, 'r' : 9, 'q' : 10, 'k' : 11,
}


class Environment(object):
  """Chess game state backed by a python-chess ``Board``.

  Instance fields:
    board: the ``chess.Board`` holding the position.
    board_history: list of ``(fen, repeats)`` tuples appended by
      ``update_history``.
    repetitions: dict mapping a FEN string without its two trailing counter
      fields to the number of times it has been recorded.

  NOTE(review): no method of this class pushes a move onto ``board``, so
  the position only changes if a caller mutates ``board`` directly and then
  calls ``update_history``.
  """

  def __init__(self, env=None):
    """Start from the initial position, or copy the state of ``env``."""
    if env is None:
      self.board = chess.Board(chess.STARTING_FEN)
      self.board_history = []
      self.repetitions = {}

      self.update_history()
    else:
      # Boards are duplicated with Board.copy(); the chess module itself
      # has no top-level copy() function.
      self.board = env.board.copy()
      self.board_history = env.board_history.copy()
      self.repetitions = env.repetitions.copy()

  def is_terminal(self):
    """Return True when python-chess reports the game as over."""
    return self.board.is_game_over()

  def terminal_value(self, to_play):
    """Return the game result from the point of view of ``to_play``.

    Args:
      to_play: 1 (or True) for black, 0 (or False) for white.

    Returns:
      1.0 if the position is checkmate and ``to_play`` is the winner, 0.0
      if it is checkmate and ``to_play`` lost, and 0.5 otherwise.
    """
    if self.board.is_checkmate():
      # Board.outcome is a method; it must be called to obtain the Outcome.
      winner = self.board.outcome().winner
      black_won = (winner == chess.BLACK)
      return float(black_won == to_play)
    return 0.5

  def generate_legal_moves(self):
    """Return the board's legal-move generator."""
    return self.board.generate_legal_moves()

  def encode_action(self, move): #12 piece types, 4 flags, 2 coords
    """Encode ``move`` as an ``(8, 8, 73)`` integer array.

    The plane of the moving piece (index taken from the module-level
    ``piece_map``) is set to 1 everywhere, plane 12 marks the origin
    square, plane 13 marks the destination square and, for promotions,
    plane ``move.promotion + 12`` is set to 1 everywhere.
    """
    action = np.zeros(shape=(8, 8, 73), dtype=int)

    piece = self.board.piece_at(move.from_square)
    piece_plane = piece_map[str(piece)]
    fr, fc = square_to_coord(move.from_square)
    tr, tc = square_to_coord(move.to_square)

    action[:, :, piece_plane] = 1

    action[fr, fc, 12] = 1
    action[tr, tc, 13] = 1

    if move.promotion:
      promotion_plane = move.promotion + 12 # 5,4,3,2
      action[:, :, promotion_plane] = 1

    return action

  def generate_legal_actions(self):
    """Return the encodings of all legal moves stacked into one array."""
    moves = self.generate_legal_moves()
    actions = [self.encode_action(move) for move in moves]
    return np.array(actions)

  def update_history(self):
    """Record the current position and bump its repetition counter."""
    fen = self.board.fen()
    # Drop the halfmove clock and fullmove number so that the same position
    # reached at different move counts shares one key.
    position_key = ' '.join(fen.split(' ')[:-2])

    repeats = self.repetitions.get(position_key, 0) + 1
    self.repetitions[position_key] = repeats

    self.board_history.append((fen, repeats))

  def make_image(self, state_index=-1):
    """Return the feature planes for the recorded state at ``state_index``."""
    fen, repeats = self.board_history[state_index]
    return fen_to_repr(fen, repeats)

In [334]:
class Game(object):
  """A single self-play game: action history plus search statistics.

  Instance fields:
    history: list of actions applied so far.
    child_visits: one visit-count distribution per stored search.
    num_actions: size of the action space (4672 for chess).
    environment: the ``Environment`` holding the board state.
  """

  def __init__(self, history=None, environment=None):
    self.history = history or []
    self.child_visits = []
    self.num_actions = 4672  # action space size for chess
    self.environment = Environment(environment)

  def terminal(self):
    """Return True when the environment reports the game as over."""
    # Game specific termination rules.
    return self.environment.is_terminal()

  def terminal_value(self, to_play):
    """Return the environment's terminal value for player ``to_play``."""
    # Game specific value. The player must be forwarded: the environment's
    # terminal_value requires it.
    return self.environment.terminal_value(to_play)

  def legal_actions(self): #TODO
    """Return the environment's encoded legal actions."""
    # Game specific calculation of legal actions.
    return self.environment.generate_legal_actions()

  def clone(self):
    """Return a new Game with a copied history and a copied environment."""
    return Game(list(self.history), self.environment)

  def apply(self, action):
    """Append ``action`` to the history.

    NOTE(review): only the history is updated here; the environment's
    board is not advanced by this method.
    """
    self.history.append(action)

  def store_search_statistics(self, root):
    """Append the normalized child visit counts of ``root``.

    One entry is stored per index in ``range(self.num_actions)``; indices
    that are not keys of ``root.children`` get 0.
    """
    sum_visits = sum(child.visit_count for child in root.children.values())
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in range(self.num_actions)
    ])

  def make_image(self, state_index: int): #TODO
    """Return the environment's feature planes for ``state_index``."""
    # Game specific feature planes.
    return self.environment.make_image(state_index)

  def make_target(self, state_index: int):
    """Return the ``(value, policy)`` training target for ``state_index``."""
    return (self.terminal_value(state_index % 2),
            self.child_visits[state_index])

  def to_play(self):
    """Return 0 or 1 depending on the parity of the history length."""
    return len(self.history) % 2


In [335]:
class ReplayBuffer(object):
  """Bounded first-in-first-out store of finished games.

  Instance fields:
    window_size: maximum number of games kept.
    batch_size: number of positions returned by ``sample_batch``.
    buffer: list of stored games, oldest first.
  """

  def __init__(self, config: Config):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):
    """Append ``game``, evicting the oldest game when the buffer is full."""
    # Evict *before* appending and compare with >= so that the buffer never
    # holds more than window_size games.
    if self.buffer and len(self.buffer) >= self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)

  def sample_batch(self):
    """Return ``batch_size`` ``(image, target)`` pairs.

    Games are drawn with probability proportional to their history length
    and a position is then drawn uniformly inside each game, which samples
    uniformly across positions.
    """
    # Sample uniformly across positions.
    move_sum = float(sum(len(g.history) for g in self.buffer))
    # Draw indices rather than the game objects themselves so numpy never
    # has to coerce the games into an array.
    indices = np.random.choice(
        len(self.buffer),
        size=self.batch_size,
        p=[len(g.history) / move_sum for g in self.buffer])
    games = [self.buffer[i] for i in indices]
    game_pos = [(g, np.random.randint(len(g.history))) for g in games]
    return [(g.make_image(i), g.make_target(i)) for (g, i) in game_pos]

In [343]:
class Network(object):
  """Thin wrapper around the Keras model used for inference."""

  def __init__(self, uniform_model: bool=False):
    """Build the constant-output model if ``uniform_model``, else the real one."""
    builder = make_uniform_network if uniform_model else make_network
    self.model = builder()

  def inference(self, image):
    """Run the model on a single image.

    The image is wrapped into a batch of one before calling ``predict``.

    Returns:
      A ``(value, policy_logits)`` pair: the first element of the value
      output, and the policy output converted to a numpy array.
    """
    batch = np.array([image])
    outputs = self.model.predict(batch)
    value_out, policy_out = outputs
    return value_out[0], np.array(policy_out)

  def get_weights(self): #TODO
    """Return the weights of the underlying model."""
    model = self.model
    return model.get_weights()

In [337]:
class SharedStorage(object):
  """Keeps network checkpoints keyed by training step."""

  def __init__(self):
    self._networks = {}  # training step -> Network

  def latest_network(self) -> Network:
    """Return the checkpoint with the highest step.

    When nothing has been saved yet, a fresh ``Network(True)`` is returned
    instead (uniform policy, value 0.5).
    """
    if self._networks:
      # Iterating a dict yields its keys; iterkeys() no longer exists in
      # Python 3.
      return self._networks[max(self._networks)]
    else:
      return Network(True)  # policy -> uniform, value -> 0.5

  def save_network(self, step: int, network: Network):
    """Store ``network`` as the checkpoint for ``step``."""
    self._networks[step] = network


In [338]:
# AlphaZero training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def zero(config: Config):
  """Run self-play and training, then return the most recent network.

  Creates the shared storage and the replay buffer, hands ``run_selfplay``
  to ``launch_job`` once per configured actor, calls ``train_network`` and
  finally asks the storage for its latest network.
  """
  shared_storage = SharedStorage()
  games = ReplayBuffer(config)

  for _ in range(config.num_actors):
    launch_job(run_selfplay, config, shared_storage, games)

  train_network(config, shared_storage, games)

  latest = shared_storage.latest_network()
  return latest



In [339]:
##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: Config, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  """Endlessly play games with the latest network and store them.

  Each iteration fetches the newest network from ``storage``, plays one
  game with it and saves that game to ``replay_buffer``. The loop has no
  exit condition.
  """
  while True:
    current_network = storage.latest_network()
    finished_game = play_game(config, current_network)
    replay_buffer.save_game(finished_game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: Config, network: Network):
  """Play one game, choosing every move with a tree search.

  The loop stops when the game reports itself terminal or when the history
  reaches ``config.max_moves`` entries.

  Returns:
    The ``Game`` with its action history and stored search statistics.
  """
  game = Game()
  while True:
    if game.terminal() or len(game.history) >= config.max_moves:
      break
    chosen_action, search_root = run_mcts(config, game, network)
    game.apply(chosen_action)
    game.store_search_statistics(search_root)
  return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: Config, game: Game, network: Network):
  """Run ``config.num_simulations`` simulations and pick an action.

  The root is expanded and perturbed with exploration noise first. Every
  simulation then walks down from the root, choosing children by UCB score
  until it reaches an unexpanded node, evaluates that node with the
  network and backs the value up along the visited path.

  Returns:
    A ``(action, root)`` pair: the selected action and the root node
    carrying the search statistics.
  """
  # The configuration class in this file is `Config`; annotating with an
  # undefined name would raise NameError as soon as the def is executed.
  root = Node(0)
  evaluate(root, game, network)
  add_exploration_noise(config, root)

  for _ in range(config.num_simulations):
    node = root
    scratch_game = game.clone()
    search_path = [node]

    # Descend until a node without children is reached.
    while node.expanded():
      action, node = select_child(config, node)
      scratch_game.apply(action)
      search_path.append(node)

    value = evaluate(node, scratch_game, network)
    backpropagate(search_path, value, scratch_game.to_play())
  return select_action(config, game, root), root



def softmax_sample(visit_counts, temperature=10.0):
    """Sample an action with probability given by a softmax over counts.

    Args:
        visit_counts: Non-empty iterable of ``(count, action)`` pairs.
        temperature: Divisor applied to the counts before the softmax;
            larger values flatten the distribution.

    Returns:
        A ``(softmax_probs, action)`` pair: the numpy array of
        probabilities (in input order) and the sampled action, returned as
        the original object from ``visit_counts``.
    """
    counts, actions = zip(*visit_counts)
    scaled = np.asarray(counts, dtype=float) / temperature
    # Subtracting the maximum leaves the softmax unchanged mathematically
    # but keeps exp() from overflowing to inf (and the result from becoming
    # NaN) when counts are large.
    exp_scaled = np.exp(scaled - scaled.max())
    softmax_probs = exp_scaled / exp_scaled.sum()
    # Draw an index instead of the action itself: np.random.choice only
    # accepts 1-D inputs, which rules out tuple or array actions.
    index = np.random.choice(len(actions), p=softmax_probs)
    return softmax_probs, actions[index]


def select_action(config: Config, game: Game, root: Node):
  """Choose the move to play from the root's visit counts.

  During the first ``config.num_sampling_moves`` moves the action is
  sampled through ``softmax_sample``; afterwards the action with the
  largest ``(visit_count, action)`` pair is returned.
  """
  # dict.items() replaces the Python 2 iteritems().
  visit_counts = [(child.visit_count, action)
                  for action, child in root.children.items()]
  if len(game.history) < config.num_sampling_moves:
    _, action = softmax_sample(visit_counts)
  else:
    _, action = max(visit_counts)
  return action


# Select the child with the highest UCB score.
def select_child(config: Config, node: Node):
  """Return the ``(action, child)`` pair with the highest UCB score.

  Ties on the score are broken by comparing the actions.
  """
  # dict.items() replaces the Python 2 iteritems().
  _, action, child = max((ucb_score(config, node, child), action, child)
                         for action, child in node.children.items())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: Config, parent: Node, child: Node):
  """Return the UCB score of ``child`` as seen from ``parent``.

  The score is ``child.value()`` plus an exploration bonus
  ``pb_c * child.prior``, where ``pb_c`` grows with the parent's visit
  count and shrinks with the child's visit count.
  """
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_c * child.prior
  value_score = child.value()
  return prior_score + value_score


# We use the neural network to obtain a value and policy prediction.
def evaluate(node: Node, game: Game, network: Network):
  """Expand ``node`` using the network's prediction for the latest state.

  Sets ``node.to_play`` and creates one child per legal action, with priors
  obtained by exponentiating the action's logit and normalizing over the
  legal actions.

  Returns:
    The value predicted by the network.
  """
  value, policy_logits = network.inference(game.make_image(-1))

  # Expand the node.
  node.to_play = game.to_play()

  # NOTE(review): this indexing needs each action to be hashable and to
  # select a single scalar from policy_logits. Environment.encode_action
  # produces (8, 8, 73) arrays, which satisfy neither — confirm the
  # intended action encoding.
  policy = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}
  # dict.values()/items() replace the Python 2 itervalues()/iteritems().
  policy_sum = sum(policy.values())
  for action, p in policy.items():
    node.children[action] = Node(p / policy_sum)
  return value


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play):
  """Add the evaluation to every node on ``search_path``.

  Nodes whose ``to_play`` equals ``to_play`` accumulate ``value``; all
  other nodes accumulate ``1 - value``. Every node's visit count is
  incremented by one.
  """
  for visited in search_path:
    if visited.to_play == to_play:
      contribution = value
    else:
      contribution = 1 - value
    visited.value_sum += contribution
    visited.visit_count += 1


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: Config, node: Node):
  """Blend gamma-distributed noise into the priors of ``node``'s children.

  Each child's prior becomes ``prior * (1 - frac) + noise * frac`` with
  ``frac = config.root_exploration_fraction`` and one noise sample per
  child drawn from a gamma distribution with shape
  ``config.root_dirichlet_alpha`` and scale 1.
  """
  actions = list(node.children)
  # numpy is imported as `np` in this file; the bare name `numpy` is not
  # bound and would raise NameError.
  noise = np.random.gamma(config.root_dirichlet_alpha, 1, len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########


def train_network(config: Config, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  """Train a fresh network and checkpoint it into ``storage``.

  The network is saved every ``config.checkpoint_interval`` steps and once
  more after the final step. Each step samples a batch from
  ``replay_buffer`` and applies one weight update.
  """
  network = Network()

  # The optimizer cannot take the {step: rate} dict directly, so it is
  # converted into a piecewise-constant schedule. The first entry is the
  # initial rate; every later key is a boundary at which the rate changes.
  schedule = sorted(config.learning_rate_schedule.items())
  boundaries = [int(step) for step, _ in schedule[1:]]
  rates = [rate for _, rate in schedule]
  if boundaries:
    learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, rates)
  else:
    learning_rate = rates[0]
  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate,
                                      momentum=config.momentum)

  for i in range(config.training_steps):
    if i % config.checkpoint_interval == 0:
      storage.save_network(i, network)
    batch = replay_buffer.sample_batch()
    update_weights(optimizer, network, batch, config.weight_decay)
  storage.save_network(config.training_steps, network)

def update_weights(optimizer: tf.keras.optimizers.Optimizer, network: Network,
                   batch, weight_decay: float):
  """Apply one gradient step to ``network.model`` for ``batch``.

  The loss is the sum over the batch of the squared value error and the
  softmax cross-entropy between the policy logits and the target policy,
  plus ``weight_decay`` times the L2 loss of every trainable variable.

  Args:
    optimizer: Keras optimizer used to apply the gradients.
    network: Wrapper whose ``model`` attribute is the Keras model.
    batch: Iterable of ``(image, (target_value, target_policy))`` tuples.
    weight_decay: Coefficient of the L2 regularization term.
  """
  images = np.array([image for image, _ in batch], dtype=np.float32)
  target_values = np.array(
      [float(target_value) for _, (target_value, _) in batch],
      dtype=np.float32)
  # Assumes each target_policy has one entry per flattened policy logit —
  # TODO confirm against Game.make_target.
  target_policies = np.array(
      [target_policy for _, (_, target_policy) in batch], dtype=np.float32)

  model = network.model
  variables = model.trainable_variables
  # The forward pass has to run under a GradientTape on the model itself:
  # model.predict() returns plain numpy arrays, which carry no gradient.
  with tf.GradientTape() as tape:
    values, policy_logits = model(images, training=True)
    values = tf.reshape(values, [-1])
    flat_logits = tf.reshape(policy_logits, [len(images), -1])

    loss = tf.reduce_sum(tf.square(values - target_values))
    loss += tf.reduce_sum(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=flat_logits, labels=target_policies))

    for weights in variables:
      loss += weight_decay * tf.nn.l2_loss(weights)

  gradients = tape.gradient(loss, variables)
  # Skip variables that did not take part in the loss (gradient is None).
  optimizer.apply_gradients(
      (gradient, variable)
      for gradient, variable in zip(gradients, variables)
      if gradient is not None)


######### End Training ###########
##################################


def launch_job(f, *args):
  """Call ``f(*args)`` directly and discard its result.

  NOTE(review): the call is synchronous, so a callable that never returns
  (such as a ``while True`` loop) keeps the caller from continuing.
  """
  f(*args)

In [344]:
# Driver cell: build the default configuration and run the whole pipeline.
# The network returned by zero() is kept in `network_1`.
config = Config()
network_1 = zero(config)



TypeError: only size-1 arrays can be converted to Python scalars

In [345]:
  # Debug cell: run the constant-output network on the latest recorded
  # position of a fresh game.
  gs = Game()
  nn = Network(True)

  value, policy_logits = nn.inference(gs.make_image(-1))



In [352]:
# Debug cell: print the shape of the policy output for every legal action.
# NOTE(review): `policy_logits` is a numpy array, so the call syntax below
# raises the TypeError recorded in this cell's output; indexing was
# presumably intended.
for a in gs.legal_actions():
  print((policy_logits(a)).shape)

TypeError: 'numpy.ndarray' object is not callable