<a href="https://colab.research.google.com/github/ericzhangxii/AlphaZero-Chess-Implementation/blob/main/AlphaZero_for_Real.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install tensorflow
!pip install chess



In [2]:
import chess
import tensorflow as tf
import copy
import math
import numpy
from typing import List
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, BatchNormalization, Add, Activation
from tensorflow.keras.models import Model

# Unmodified Helpers

In [5]:

##########################
####### Helpers ##########


class AlphaZeroConfig(object):

  def __init__(self):
    ### Self-Play
    self.num_actors = 5000

    self.num_sampling_moves = 30
    self.max_moves = 512  # for chess and shogi, 722 for Go.
    self.num_simulations = 800

    # Root prior exploration noise.
    self.root_dirichlet_alpha = 0.3  # for chess, 0.03 for Go and 0.15 for shogi.
    self.root_exploration_fraction = 0.25

    # UCB formula
    self.pb_c_base = 19652
    self.pb_c_init = 1.25

    ### Training
    self.training_steps = int(700e3)
    self.checkpoint_interval = int(1e3)
    self.window_size = int(1e6)
    self.batch_size = 4096

    self.weight_decay = 1e-4
    self.momentum = 0.9
    # Schedule for chess and shogi, Go starts at 2e-2 immediately.
    self.learning_rate_schedule = {
        0: 2e-1,
        100e3: 2e-2,
        300e3: 2e-3,
        500e3: 2e-4
    }


class Node(object):

  def __init__(self, prior: float):
    self.visit_count = 0
    self.to_play = -1
    self.prior = prior
    self.value_sum = 0
    self.children = {}

  def expanded(self):
    return len(self.children) > 0

  def value(self):
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count

class Game(object):
  pass

class ReplayBuffer(object):

  def __init__(self, config: AlphaZeroConfig):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):
    if len(self.buffer) > self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)

  def sample_batch(self):
    # Sample uniformly across positions.
    move_sum = float(sum(len(g.history) for g in self.buffer))
    games = numpy.random.choice(
        self.buffer,
        size=self.batch_size,
        p=[len(g.history) / move_sum for g in self.buffer])
    game_pos = [(g, numpy.random.randint(len(g.history))) for g in games]
    return [(g.make_image(i), g.make_target(i)) for (g, i) in game_pos] # returns image, [target value, target policy] for a random game position

class Network(object):
  pass
  
class SharedStorage(object):
  def __init__(self):
    self._networks = {}

  def latest_network(self) -> Network:
    if self._networks:
      return self._networks[max(self._networks.iterkeys())]
    else:
      return make_uniform_network()  # policy -> uniform, value -> 0.5

  def save_network(self, step: int, network: Network):
    self._networks[step] = network


##### End Helpers ########
##########################

# Modified Game Class

In [None]:
class Game(object):

  def __init__(self, history=None):
    self.history = history or []
    self.child_visits = []
    self.num_actions = 4672  # action space size for chess; 11259 for shogi, 362 for Go
    self.board = chess.Board()

  def terminal(self):
    # Game specific termination rules.
    return self.board.is_game_over()

  def terminal_value(self, to_play): #to_play here represents the player in the turn evaluated
    # Game specific value.
    if(self.terminal()):
      if(self.board.is_checkmate()):
        if(to_play == self.to_play): #If the player evaluated is the same as the player in checkmate, return -1
          return -1
        else:
          return 1
      else:
        return 0
    else:
      return None

  def legal_actions(self):
    # Game specific calculation of legal actions.
    return list(self.board.legal_moves)

  def clone(self):
    return copy.deepcopy(self)

  def apply(self, action):
    scratch_board = copy.deepcopy(self.board)
    scratch_board.push(action)
    self.history.append(scratch_board)
    self.board = copy.deepcopy(scratch_board)


  def store_search_statistics(self, root):
    sum_visits = sum(child.visit_count for child in root.children.itervalues())
    self.child_visits.append([
        root.children[a].visit_count / sum_visits if a in root.children else 0
        for a in range(self.num_actions)
    ])

  def make_image(self, state_index: int):
    # Game specific feature planes. Used in two contexts: for evaluating during MCTS and during training for a specific time step. All in Network.inference()
    # The list it's looking for is evidently "8 time steps", counting back from the state index. 
    if(len(self.history)==0):
      return board_to_planes(self.board)
    board = self.history[state_index]
    return board_to_planes(board)

  def make_target(self, state_index: int):
    #returns target weight and target policy, the policy is apparently based on child visits?
    return (self.terminal_value(state_index % 2),
            self.child_visits[state_index])

  def to_play(self):
    return len(self.history) % 2 #white if 0, black if 1

# New Helpers

In [None]:
import chess
import numpy as np

def board_to_planes(board):
    """
    Convert a chess.Board() object into a set of planes representing the board state.
    """
    # Create a tensor with shape (8, 8, 12) to store the piece information
    planes = np.zeros((8, 8, 12), dtype=np.uint8)

    # Iterate through the board and set the corresponding plane values
    for i in range(64):
        piece = board.piece_at(i)
        if piece:
            color = int(piece.color)
            piece_type = piece.piece_type - 1
            row, col = divmod(i, 8)
            planes[row, col, color * 6 + piece_type] = 1

    return planes

def history_to_tensor(history, state_index=-1):
    """
    Convert a history of chess board positions into a tensor.
    """
    T = 8
    M = 12
    L = 7
    N = 8

    # If state_index is -1, use the last 8 board positions
    if state_index == -1:
        state_index = max(0, len(history) - T)

    # Initialize the tensor with shape (8, 8, MT + L)
    tensor = np.zeros((N, N, M * T + L), dtype=np.uint8)

    # Fill in the tensor with the board positions
    for t in range(T):
        if t < len(history) - state_index:
            board = history[state_index + t]
            tensor[:, :, M * t:M * (t + 1)] = board_to_planes(board)

    # Fill in the constant-valued input planes
    last_board = history[-1]
    tensor[:, :, -L] = int(last_board.turn)
    tensor[:, :, -L + 1] = last_board.fullmove_number
    tensor[:, :, -L + 2] = int(last_board.has_kingside_castling_rights(chess.WHITE))
    tensor[:, :, -L + 3] = int(last_board.has_kingside_castling_rights(chess.BLACK))
    tensor[:, :, -L + 4] = int(last_board.has_queenside_castling_rights(chess.WHITE))
    tensor[:, :, -L + 5] = int(last_board.has_queenside_castling_rights(chess.BLACK))
    tensor[:, :, -L + 6] = last_board.halfmove_clock

    return tensor


# Modified Network


In [7]:
class chessNetwork(Network):
    def __init__(self):
        self.model = self.build_chess_model()
        self.model.compile(optimizer='adam', loss=['mean_squared_error', 'categorical_crossentropy'])

    def build_chess_model(self):
        input_shape = (8, 8, 12)
        n_actions = 4672

        input_board = Input(shape=input_shape)

        def residual_block(x):
            res = x
            x = Conv2D(256, 3, padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Conv2D(256, 3, padding='same')(x)
            x = BatchNormalization()(x)
            x = Add()([x, res])
            x = Activation('relu')(x)
            return x

        x = Conv2D(256, 3, padding='same')(input_board)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

        for _ in range(5):
            x = residual_block(x)

        value = Conv2D(1, 1, padding='same')(x)
        value = BatchNormalization()(value)
        value = Activation('relu')(value)
        value = Flatten()(value)
        value = Dense(256, activation='relu')(value)
        value = Dense(1, activation='tanh')(value)

        policy = Conv2D(73, 1, padding='same')(x)
        policy = BatchNormalization()(policy)
        policy = Activation('relu')(policy)
        policy = Flatten()(policy)
        policy = Dense(n_actions, activation='softmax')(policy)

        model = Model(inputs=input_board, outputs=[value, policy])

        return model

    def inference(self, image):
        input_data = np.expand_dims(image, axis=0)
        value, policy = self.model.predict(input_data)
        policy /= policy.sum()
        move_prob_dict = {}
        return value[0], policy[0]

    def get_weights(self):
        return self.model.get_weights()


# Self Play and Training

In [None]:



# AlphaZero training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def alphazero(config: AlphaZeroConfig):
  storage = SharedStorage()
  replay_buffer = ReplayBuffer(config)

  for i in range(config.num_actors):
    launch_job(run_selfplay, config, storage, replay_buffer)

  train_network(config, storage, replay_buffer)

  return storage.latest_network()


##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: AlphaZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  while True:
    network = storage.latest_network()
    game = play_game(config, network)
    replay_buffer.save_game(game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: AlphaZeroConfig, network: Network):
  game = Game()
  while not game.terminal() and len(game.history) < config.max_moves:
    action, root = run_mcts(config, game, network)
    game.apply(action)
    game.store_search_statistics(root)
  return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: AlphaZeroConfig, game: Game, network: Network):
  root = Node(0)
  evaluate(root, game, network)
  add_exploration_noise(config, root)

  for _ in range(config.num_simulations):
    node = root
    scratch_game = game.clone()
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node)
      scratch_game.apply(action)
      search_path.append(node)

    value = evaluate(node, scratch_game, network)
    backpropagate(search_path, value, scratch_game.to_play())
  return select_action(config, game, root), root


def select_action(config: AlphaZeroConfig, game: Game, root: Node):
  visit_counts = [(child.visit_count, action)
                  for action, child in root.children.iteritems()]
  if len(game.history) < config.num_sampling_moves:
    _, action = softmax_sample(visit_counts)
  else:
    _, action = max(visit_counts)
  return action


# Select the child with the highest UCB score.
def select_child(config: AlphaZeroConfig, node: Node):
  _, action, child = max((ucb_score(config, node, child), action, child)
                         for action, child in node.children.iteritems())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: AlphaZeroConfig, parent: Node, child: Node):
  pb_c = math.log((parent.visit_count + config.pb_c_base + 1) /
                  config.pb_c_base) + config.pb_c_init
  pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_c * child.prior
  value_score = child.value()
  return prior_score + value_score


# We use the neural network to obtain a value and policy prediction.
def evaluate(node: Node, game: Game, network: Network):
  value, policy_logits = network.inference(game.make_image(-1))

  # Expand the node.
  node.to_play = game.to_play()
  policy = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}
  policy_sum = sum(policy.itervalues())
  for action, p in policy.iteritems():
    node.children[action] = Node(p / policy_sum)
  return value


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play):
  for node in search_path:
    node.value_sum += value if node.to_play == to_play else (1 - value)
    node.visit_count += 1


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: AlphaZeroConfig, node: Node):
  actions = node.children.keys()
  noise = numpy.random.gamma(config.root_dirichlet_alpha, 1, len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########


def train_network(config: AlphaZeroConfig, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  network = Network()
  optimizer = tf.train.MomentumOptimizer(config.learning_rate_schedule,
                                         config.momentum)
  for i in range(config.training_steps):
    if i % config.checkpoint_interval == 0:
      storage.save_network(i, network)
    batch = replay_buffer.sample_batch()
    update_weights(optimizer, network, batch, config.weight_decay)
  storage.save_network(config.training_steps, network)


def update_weights(optimizer, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, (target_value, target_policy) in batch:
    value, policy_logits = network.inference(image)
    loss += (
        tf.losses.mean_squared_error(value, target_value) +
        tf.nn.softmax_cross_entropy_with_logits(
            logits=policy_logits, labels=target_policy))

  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)


######### End Training ###########
##################################

################################################################################
############################# End of pseudocode ################################
################################################################################


# Stubs to make the typechecker happy, should not be included in pseudocode
# for the paper.
def softmax_sample(d):
  return 0, 0


def launch_job(f, *args):
  f(*args)


def make_uniform_network():
  return Network()


# Testing


In [None]:
game = Game()
network = Network()
root = Node(0)
value, policy_logits = network.inference(game.make_image(-1))
print(value)
print(len(policy_logits))
evaluate(root, game, network)

[0.0224242]
4672


IndexError: ignored

In [None]:
myGame = Game()
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
myGame.apply(myGame.legal_actions()[0])
print(len(myGame.history))
index = -1
for i in range(index,index-8,-1):
  print(i)



9
-1
-2
-3
-4
-5
-6
-7
-8


In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt

# Create an optimizer with the desired parameters.
opt = tf.compat.v1.train.MomentumOptimizer(learning_rate=0.05,momentum=0.9)
var1 = 42
# `loss` is a callable that takes no argument and returns the value
# to minimize.
loss = lambda: var1**2 - 5*var1 + 10

# Initialize a list to store loss values
loss_values = []

# Call minimize to update the list of variables.
for i in range(100):
    opt.minimize(loss)
    # Store the current loss value in the list
    loss_values.append(loss().numpy())

# Plot the loss values
plt.plot(loss_values)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Loss vs. Iteration')
plt.show()


TypeError: ignored

In [None]:
print(loss_values)

[1831.3125, 1484.0754, 1202.8137, 974.9916, 790.4557, 640.9817, 519.90765, 421.8377, 342.40103, 278.0573, 225.93889, 183.72299, 149.52812, 121.830284, 99.39502, 81.222466, 66.5027, 54.579693, 44.92205, 37.09936, 30.762985, 25.63052, 21.47322, 18.105808, 15.378206, 13.168848, 11.379267, 9.929707, 8.755562, 7.8045063, 7.034151, 6.41016, 5.90473, 5.495331, 5.163719, 4.895111, 4.67754, 4.5013094, 4.3585596, 4.2429333, 4.149276, 4.073414, 4.011965, 3.9621916, 3.921875, 3.8892183, 3.8627672, 3.841342, 3.8239865, 3.8099294, 3.798543, 3.7893195, 3.7818484, 3.7757978, 3.7708964, 3.7669253, 3.76371, 3.761105, 3.758995, 3.757286, 3.7559013, 3.7547808, 3.753872, 3.7531366, 3.75254, 3.752058, 3.7516665, 3.7513504, 3.7510934, 3.750886, 3.7507172, 3.7505808, 3.7504706, 3.7503815, 3.7503095, 3.7502499, 3.7502031, 3.7501645, 3.750133, 3.7501078, 3.7500868, 3.750071, 3.7500572, 3.7500463, 3.7500372, 3.7500305, 3.7500248, 3.7500205, 3.7500162, 3.7500134, 3.7500105, 3.750008, 3.7500072, 3.7500052, 3.75000

In [None]:
print(var1, var2)

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.000104857536> <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.012093236>


In [None]:
tensor = numpy.zeros((8,8,73),dtype=numpy.float32)

In [3]:
board = chess.Board()

print(board.)