<a href="https://colab.research.google.com/github/lunathanael/chessnn/blob/main/zero_policy_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install chess
import chess
import math
import numpy as np
import tensorflow as tf
from typing import List
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, Add, Dense, Flatten
from tensorflow.keras.models import Model


from threading import Thread, Event

!pip install cairosvg
import chess.svg
import IPython
from IPython.display import SVG, display

import cairosvg
from PIL import Image

import pickle
from google.colab import files



In [23]:
def residual_block(x, filters):
    """Create a residual block."""
    y = Conv2D(filters, kernel_size=3, padding='same')(x)
    y = BatchNormalization()(y)
    y = ReLU()(y)
    y = Conv2D(filters, kernel_size=3, padding='same')(y)
    y = BatchNormalization()(y)
    y = Add()([y, x])
    y = ReLU()(y)
    return y

def make_network():
    # Input layer
    input_layer = Input(shape=(8, 8, 20))  # 8x8 grid with 20 features per cell

    # Body
    x = Conv2D(256, kernel_size=3, padding='same', activation='relu')(input_layer)
    x = BatchNormalization()(x)
    # 19 residual blocks
    for _ in range(19):
        x = residual_block(x, 256)  # Each block has two convolutional layers with 256 filters, kernel size 3x3

    # Policy Head
    policy_head = Conv2D(256, kernel_size=3, padding='same', activation='relu')(x)
    policy_head = BatchNormalization()(policy_head)
    policy_head = Conv2D(73, kernel_size=1)(policy_head)  # Output 73 policies

    # Value Head
    value_head = Conv2D(1, kernel_size=1, activation='relu')(x)
    value_head = BatchNormalization()(value_head)
    value_head = Flatten()(value_head)
    value_head = Dense(256, activation='relu')(value_head)
    value_head = Dense(1, activation='tanh')(value_head)  # Output single value

    # Create the model
    model = tf.keras.Model(inputs=input_layer, outputs=[value_head, policy_head])
    return model



class make_uniform_network():
  def __init__(self):
    self._sval =  np.full((1,1,), 0.5)
    self._spolicy = array2 = np.full((1, 8, 8, 73), 1/73)

    self._val =  np.full((1,1,), 0.5)
    self._policy = array2 = np.full((1, 8, 8, 73), 1/73)


  def predict(self, input, verbose=0):
    repeat_times = input.shape[0]
    if repeat_times != 1:
      # Repeat _val and _policy
      repeated_val = np.repeat(self._val, repeat_times, axis=0)
      repeated_policy = np.repeat(self._policy, repeat_times, axis=0)

      return repeated_val, repeated_policy
    return self._sval, self._spolicy

Helpers

In [4]:
class Config(object):

  def __init__(self):
    ### Self-Play
    self.num_actors = 5 # not enough ram <- 5000

    self.num_sampling_moves = 30
    self.max_moves = 512  # for chess and shogi, 722 for Go.

    self.num_simulations = 100 #was 800, testing with 31

    self.random_action = False # pick random node to explore
    self.uniform_num_simulations = 800 # avg number of legal moves is 31
    self.uniform_num_sampling_moves = 30 # all are sampled from softmax
    self.uniform_softmax_temperature = 10

    # Root prior exploration noise.
    self.root_dirichlet_alpha = 0.3  # for chess, 0.03 for Go and 0.15 for shogi.
    self.root_exploration_fraction = 0.25

    # Softmax function
    self.softmax_temperature = 10 # reduce later on when fine-tuning, choose moves that it believes are more likely to be successful

    # UCB formula
    self.pb_c_base = 19652
    self.pb_c_init = 1.25

    ### Training
    self.training_steps = int(700e3) # 700,000 take too long?
    self.checkpoint_interval = int(1e3)
    self.window_size = int(1e6)
    self.batch_size = 32 # 4096 -> 32 had best performance for me

    self.weight_decay = 1e-4
    self.momentum = 0.9
    # Schedule for chess and shogi, Go starts at 2e-2 immediately.
    self.learning_rate_schedule = {
        0: 2e-1,
        100e3: 2e-2,
        300e3: 2e-3,
        500e3: 2e-4
    }

In [5]:
class Node(object):

  def __init__(self, prior: float):
    self.visit_count = 0
    self.to_play = -1
    self.prior = prior
    self.value_sum = 0
    self.children = {}

  def expanded(self):
    return len(self.children) > 0

  def value(self):
    if self.visit_count == 0:
      return 0
    return self.value_sum / self.visit_count


In [6]:
def fen_to_repr(fen, repeats):
    # Split the FEN string to get the relevant parts
    parts = fen.split(' ')
    board_fen, player, castling, _, halfmove, fullmove = parts[:6]

    # 12 pieces, 4 cf, col, rep, half, total
    board = np.zeros((8, 8, 12 + 4 + 1 + 1 + 1 + 1), dtype=float)

    # Define piece order and mapping to layers
    piece_map = {'p': 0, 'n': 1, 'b': 2, 'r': 3, 'q': 4, 'k': 5,
                 'P': 6, 'N': 7, 'B': 8, 'R': 9, 'Q': 10, 'K': 11}

    # Fill the board with pieces
    row = 0
    col = 0
    for char in board_fen:
        if char.isdigit():
            col += int(char)
        elif char == '/':
            row += 1
            col = 0
        else:
            board[row, col, piece_map[char]] = 1
            col += 1

    # Castling rights encoded in four binary planes
    board[:, :, 12] = 1 if 'K' in castling else 0  # White kingside
    board[:, :, 13] = 1 if 'Q' in castling else 0  # White queenside
    board[:, :, 14] = 1 if 'k' in castling else 0  # Black kingside
    board[:, :, 15] = 1 if 'q' in castling else 0  # Black queenside

    # Player color (1 for black, 0 for white)
    board[:, :, 16] = 1 if player == 'b' else 0

    # Position repetitions
    board[:, :, 17] = repeats

    board[:, :, 18] = float(halfmove)

    # Move number (as a real value)
    board[:, :, 19] = float(fullmove)

    return board

In [7]:
def square_to_coord(square):
  #       r           c
  return (square // 8, square % 8)

knight_moves = [
    (1, 2), (1, -2),
    (2, 1), (2, -1),
    (-1, 2), (-1, -2),
    (-2, 1), (-2, -1)
]

promotion_pieces = "nbr"
rows = "abcdefgh"

queen_moves = [
    (1, 1), (1, -1), (1, 0),
    (-1, 1), (-1, -1), (-1, 0),
    (0, 1), (0, -1)
]

def coord_to_uci(fr, fc, tr, tc):
  return rows[fc] + str(fr + 1) + rows[tc] + str(tr + 1)

def action_tuple_to_index(action):
  fr, fc, plane = action
  dim2, dim3 = 8, 73
  return (fr * dim2 * dim3) + (fc * dim3) + plane

def decode_action(action):

  #piece = self.board.piece_at(move.from_square)
  #fr, fc = square_to_coord(move.from_square)
  #tr, tc = square_to_coord(move.to_square)
  fr = action[0]
  fc = action[1]

  if action[2] < 9:
    piece_idx = action[2] // 3
    move_idx = action[2] % 3
    if fr == 1: # black promo
      tr = 0
    else:
      tr = 7
    if move_idx == 0:
      tc = fc
    elif move_idx == 1:
      tc = fc - 1
    else:
      tc = fc + 1
    return coord_to_uci(fr, fc, tr, tc) + promotion_pieces[piece_idx]

  if action[2] < 17:
    tr = fr + knight_moves[action[2] - 9][0]
    tc = fc + knight_moves[action[2] - 9][1]
  else:
    dist = (action[2] - 17) // 8 + 1
    status = (action[2] - 17) % 8
    tr = fr + (queen_moves[status][0] * dist)
    tc = fc + (queen_moves[status][1] * dist)

  return coord_to_uci(fr, fc, tr, tc)


class Environment(object):

  def __init__(self, env=None):
    if env == None:
      self.board = chess.Board(chess.STARTING_FEN)
      self.board_history = []
      self.repetitions = {}

      self.update_history()
    else:
      self.board = chess.Board.copy(env.board)
      self.board_history = env.board_history.copy()
      self.repetitions = env.repetitions.copy()

  def is_terminal(self):
    return self.board.is_game_over(claim_draw=True) # claim draws

  def terminal_value(self, to_play: int) -> int: # THIS FUNCTION COULD BE VERY INCORRECT, ASSUMING CLASSIC MCTS
    # outcome = self.board.outcome(claim_draw=True)
    # if not outcome:
    #   return 0 # ERROR, SHOULD NEVER BE CALLED LIKE SO.
    winner = self.board.outcome(claim_draw=True).winner

    if winner != None:      # -> white node  ->white node
      if winner == to_play: # white_won(True) and black(1) to play or other way,
        return 1
      return 0

    return 0.5

  def generate_legal_moves(self):
    return self.board.generate_legal_moves()

  def encode_action(self, move): #12 piece types, 4 flags, 2 coords

    piece = self.board.piece_at(move.from_square)
    fr, fc = square_to_coord(move.from_square)
    tr, tc = square_to_coord(move.to_square)

    if move.promotion and move.promotion != 5: # 9 planes
      # move.promotion -> 2 : N, 3 : B, 4 : R, 5 : Q
      piece_idx = move.promotion - 2
      if fc == tc:
        move_idx = 0
      elif tc < fc: # left from white perspective
        move_idx = 1
      else:
        move_idx = 2
      plane = piece_idx * 3 + move_idx

    elif str(piece).lower() == 'n': # 8 planes
      for idx in range(len(knight_moves)):
        if fr + knight_moves[idx][0] == tr and fc + knight_moves[idx][1] == tc:
          plane = idx + 9
          break

    else: # 7 * 8 planes
      if fr < tr: # moving forward
        row_status = 0
      elif fr > tr: # back
        row_status = 1
      else: # horizontal movement
        row_status = 2

      if fc < tc: # moving right
        col_status = 0
      elif fc > tc: # left
        col_status = 1
      else: # vertical movement
        col_status = 2

      # row status and col status should never both be 2, means something stinky

      dist = max(abs(fr - tr), abs(fc - tc)) - 1
      plane = dist * 8 + (row_status * 3 + col_status) + 17

    # if decode_action((fr, fc, plane)) != str(move):
    #   print("found diff")
    #   print("action: ", fr, fc, plane)
    #   print("decoded: ", decode_action((fr, fc, plane)))
    #   print("Coords: ", fr, fc, tr, tc)
    #   print(move)
    #   print(str(move) == coord_to_uci(fr, fc, tr, tc))
    #   raise Exception("found differing in uci")

    return (fr, fc, plane)

  def apply(self, action):
    uci = decode_action(action)
    if action[0] == 1 and len(uci) == 4 and str(self.board.piece_at(action[0]*8+action[1])) == 'p':
      uci += 'q'
    if action[0] == 6 and len(uci) == 4 and str(self.board.piece_at(action[0]*8+action[1])) == 'P':
      uci += 'q'
    self.board.push_uci(uci)
    self.update_history()
    #print(self.board)

  def generate_legal_actions(self):
    moves = self.generate_legal_moves()
    actions = [self.encode_action(move) for move in moves]
    return actions

  def update_history(self):
    fen = self.board.fen()
    s_fen = fen.split(' ')
    cs_fen = ' '.join(s_fen[:-2])

    if cs_fen in self.repetitions:
      self.repetitions[cs_fen] += 1
      repeats = self.repetitions[cs_fen]
    else:
      self.repetitions[cs_fen] = 1
      repeats = 1

    halfmove_clock = self.board.halfmove_clock
    self.board_history.append((fen, repeats))

  def make_image(self, state_index=-1):
    fen, repeats= self.board_history[state_index]

    repr = fen_to_repr(fen, repeats)
    return repr

In [8]:
class Game(object):

  def __init__(self, history=None, environment=None):
    self.history = history or []
    self.child_visits = []
    self.num_actions = 4672  # action space size for chess
    self.environment = Environment(environment)

  def terminal(self):
    # Game specific termination rules.
    return self.environment.is_terminal()

  def terminal_value(self, to_play):
    # Game specific value.
    return self.environment.terminal_value(to_play)

  def legal_actions(self):
    # Game specific calculation of legal actions.
    return self.environment.generate_legal_actions()

  def clone(self):
    return Game(list(self.history), self.environment)

  def apply(self, action):
    self.history.append(action)
    self.environment.apply(action)

  def store_search_statistics(self, root):
    action_scores = np.zeros((8, 8, 73))
    sum_visits = sum(child.visit_count for child in root.children.values())

    for action, child in root.children.items():
      action_scores[action] = child.visit_count / sum_visits

    self.child_visits.append(action_scores)

    # self.child_visits.append([
    #     root.children[a].visit_count / sum_visits if a in root.children else 0
    #     for a in range(self.num_actions)
    # ])

  def make_image(self, state_index: int):
    # Game specific feature planes.
    return self.environment.make_image(state_index)

  def make_target(self, state_index: int):
    return (self.terminal_value(state_index % 2),
            self.child_visits[state_index])

  def to_play(self): # 0: White, 1: Black
    return len(self.history) % 2


In [9]:
class ReplayBuffer(object):

  def __init__(self, config: Config):
    self.window_size = config.window_size
    self.batch_size = config.batch_size
    self.buffer = []

  def save_game(self, game):
    if len(self.buffer) > self.window_size:
      self.buffer.pop(0)
    self.buffer.append(game)

  def sample_batch(self):
    # Sample uniformly across positions.
    #print(f"Sampling {self.batch_size} from: {len(self.buffer)}")
    move_sum = float(sum(len(g.history) for g in self.buffer))
    games = np.random.choice(
        self.buffer,
        size=self.batch_size,
        p=[len(g.history) / move_sum for g in self.buffer])
    game_pos = [(g, np.random.randint(len(g.history))) for g in games]
    return [(g.make_image(i), g.make_target(i)) for (g, i) in game_pos]

  def size(self):
    return len(self.buffer)

In [10]:
class Network(object):

  def __init__(self, uniform_model: bool=False):
    if uniform_model:
      self.model = make_uniform_network()
    else:
      self.model = make_network()
      self.trainable_variables = self.model.trainable_variables

  def inference(self, image): # inference for SINGLE IMAGE
      # Run the neural network model to get predictions
      image = np.expand_dims(image, axis=0)
      value, policy_logits = self.model.predict(image, verbose=0)

      value = value[0] # The value output is a scalar representing the predicted game outcome
      policy_logits = policy_logits[0]

      return value, np.array(policy_logits)

  def grad_inference(self, image):
      # Preprocess the image
      image = np.expand_dims(image, axis=0)

      # Use the `call` method for gradient-aware operations
      with tf.GradientTape() as tape:
          value, policy_logits = self.model(image, training=False) #predictable, inference-like manner, even though this operation is part of the larger training process
      # Post-process the outputs if necessary
      value = value[0]  # The value output is a scalar representing the predicted game outcome
      policy_logits = policy_logits[0]

      return value, policy_logits


  def batch_inference(self, images):
    # Run the neural network model to get predictions
    values, policy_logits = self.model.predict(images, verbose=0)
    return values, np.array(policy_logits)


  def batch_grad_inference(self, images):

      # Use the `call` method for gradient-aware operations
      with tf.GradientTape() as tape:
          values, policy_logits = self.model(images, training=False) #predictable, inference-like manner, even though this operation is part of the larger training process

      return values, np.array(policy_logits)



  def get_weights(self):
    # Returns the weights of this network.
    return self.model.get_weights()

In [11]:
class SharedStorage(object):

  def __init__(self):
    self._networks = {}

  def latest_network(self) -> Network:
    if self._networks:
      return self._networks[max(self._networks.keys())]
    else:
      return None  # policy -> uniform, value -> 0.5

  def save_network(self, step: int, network: Network):
    with open(f'networks/network_0th_{len(self._networks)}.nn', 'wb') as nn:
        pickle.dump(network, nn)
    files.download(f'networks/network_0th_{len(self._networks)}.nn')

    self._networks[step] = network


In [12]:
# training is split into two independent parts: Network training and
# self-play data generation.
# These two parts only communicate by transferring the latest network checkpoint
# from the training to the self-play, and the finished games from the self-play
# to the training.
def zero(config: Config):
  storage = SharedStorage()
  replay_buffer = ReplayBuffer(config)
  run_selfplay(config, storage, replay_buffer)
  threads = [Thread(target=run_selfplay, args=(config, storage, replay_buffer))
            for _ in range(config.num_actors)]

  for t in threads:
    t.start()

  print("Self-play data generation launched.")

  train_network(config, storage, replay_buffer)

  return storage.latest_network()



In [50]:
##################################
####### Part 1: Self-Play ########


# Each self-play job is independent of all others; it takes the latest network
# snapshot, produces a game and makes it available to the training job by
# writing it to a shared replay buffer.
def run_selfplay(config: Config, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
  while True:
    network = storage.latest_network()
    game = play_game(config, network)
    replay_buffer.save_game(game)


# Each game is produced by starting at the initial board position, then
# repeatedly executing a Monte Carlo Tree Search to generate moves until the end
# of the game is reached.
def play_game(config: Config, network: Network=None):
  num_simulations = config.num_simulations
  num_sampling_moves = config.num_sampling_moves
  softmax_temperature = config.softmax_temperature
  if not network:
    network = Network(True)
    num_simulations = config.uniform_num_simulations
    num_sampling_moves = config.uniform_num_sampling_moves
    softmax_temperature = config.uniform_softmax_temperature

  game = Game()
  while not game.terminal() and len(game.history) < config.max_moves:
    action, root = run_mcts(config, game, network, num_simulations, num_sampling_moves, softmax_temperature)
    game.apply(action)
    game.store_search_statistics(root)
  return game


# Core Monte Carlo Tree Search algorithm.
# To decide on an action, we run N simulations, always starting at the root of
# the search tree and traversing the tree according to the UCB formula until we
# reach a leaf node.
def run_mcts(config: Config, game: Game, network: Network, num_simulations: int=800, num_sampling_moves=30, softmax_temperature=10):
  root = Node(0)
  evaluate(root, game, network)
  add_exploration_noise(config, root)

  for _ in range(num_simulations):
    node = root
    scratch_game = game.clone()
    search_path = [node]

    while node.expanded():
      action, node = select_child(config, node)
      scratch_game.apply(action)
      search_path.append(node)

    value = evaluate(node, scratch_game, network)
    backpropagate(search_path, value, scratch_game.to_play())

  return select_action(config, game, root, num_sampling_moves, softmax_temperature), root



def softmax_sample(visit_counts, temperature=10.0):
  counts, actions = zip(*visit_counts)

  if temperature == -1: # pure random
    rnd_idx = np.random.choice(len(actions))
    return None, actions[rnd_idx]

  # Apply softmax with temperature
  counts = np.array(counts)
  counts = counts / temperature  # Apply temperature scaling
  softmax_probs = np.exp(counts) / sum(np.exp(counts))
  # Sample an action based on the softmax probabilities
  rnd_idx = np.random.choice(len(actions), p=softmax_probs)
  action = actions[rnd_idx]
  return softmax_probs, action


def select_action(config: Config, game: Game, root: Node, num_sampling_moves=30, softmax_temperature=10):
  visit_counts = [(child.visit_count, action)
                  for action, child in root.children.items()]
  if len(game.history) < num_sampling_moves:
    _, action = softmax_sample(visit_counts, softmax_temperature)
  else:
    _, action = max(visit_counts)
  return action


# Select the child with the highest UCB score.
def select_child(config: Config, node: Node):
  pb_c = math.log((node.visit_count + config.pb_c_base + 1) /
                config.pb_c_base) + config.pb_c_init
  _, action, child = max((ucb_score(config, node, child, pb_c), action, child)
                         for action, child in node.children.items())
  return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: Config, parent: Node, child: Node, pb_c):
  pb_C = pb_c * math.sqrt(parent.visit_count) / (child.visit_count + 1)

  prior_score = pb_C * child.prior
  value_score = child.value()
  return prior_score + value_score


# We use the neural network to obtain a value and policy prediction.
def evaluate(node: Node, game: Game, network: Network):
  value, policy_logits = network.inference(game.make_image(-1))

  # Expand the node.
  node.to_play = game.to_play()

  policy = {a: math.exp(policy_logits[a]) for a in game.legal_actions()}
  policy_sum = sum(policy.values())
  for action, p in policy.items():
    node.children[action] = Node(p / policy_sum)
  return value


# At the end of a simulation, we propagate the evaluation all the way up the
# tree to the root.
def backpropagate(search_path: List[Node], value: float, to_play):
  for node in search_path:
    node.value_sum += value if node.to_play == to_play else (1 - value)
    node.visit_count += 1


# At the start of each search, we add dirichlet noise to the prior of the root
# to encourage the search to explore new actions.
def add_exploration_noise(config: Config, node: Node):
  actions = node.children.keys()
  noise = np.random.gamma(config.root_dirichlet_alpha, 1, len(actions))
  frac = config.root_exploration_fraction
  for a, n in zip(actions, noise):
    node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac


######### End Self-Play ##########
##################################

##################################
####### Part 2: Training #########

class ZeroLearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, lr_schedule):
        super(ZeroLearningRateSchedule, self).__init__()
        # Convert keys to integers
        self.lr_schedule = {int(k): v for k, v in lr_schedule.items()}
        self.lr_schedule_keys = sorted(self.lr_schedule)

    def __call__(self, step):
        learning_rate = self.lr_schedule[self.lr_schedule_keys[0]]
        for i in range(1, len(self.lr_schedule_keys)):
            if step < self.lr_schedule_keys[i]:
                break
            learning_rate = self.lr_schedule[self.lr_schedule_keys[i]]
        return learning_rate

    def get_config(self):
        return {'lr_schedule': self.lr_schedule}


def train_network(config: Config, storage: SharedStorage,
                  replay_buffer: ReplayBuffer):
  network = Network()
  learning_rate_schedule = ZeroLearningRateSchedule(config.learning_rate_schedule)
  optimizer = tf.keras.optimizers.SGD(
      learning_rate_schedule,
      config.momentum,
      nesterov=True # testin nestrov
      )
  print(f"Training network with {config.training_steps} steps and batch size {config.batch_size}.")
  print(f"Optimizer configuration: ")
  print(optimizer.get_config())
  for i in range(config.training_steps):
    batch = replay_buffer.sample_batch()
    update_weights(optimizer, network, batch, config.weight_decay)
    if i % config.checkpoint_interval == config.checkpoint_interval-1:
      print(f"Checkpoint at training step: {i}")
      storage.save_network(i, network)
  storage.save_network(config.training_steps, network)


def update_weights(optimizer: tf.keras.optimizers, network: Network, batch, weight_decay: float):
    images, targets = zip(*batch)
    target_values, target_policies = zip(*targets)

    images = np.stack(images)

    target_values = np.stack(target_values)
    target_policies = np.stack(target_policies)

    with tf.GradientTape() as tape:
        values, policy_logits = network.batch_grad_inference(images)

        reshaped_target_policies = tf.reshape(target_policies, [-1, 4672]) # 8 * 8 * 73 = 4672 action space
        reshaped_policy_logits = tf.reshape(policy_logits, [-1, 4672])


        # Calculate the losses for the entire batch
        value_loss = tf.losses.mean_squared_error(target_values, values)
        policy_loss = tf.nn.softmax_cross_entropy_with_logits(labels=reshaped_target_policies, logits=reshaped_policy_logits)
        loss = tf.reduce_mean(value_loss + policy_loss)  # Average over the batch

        # Regularization (if applicable)
        for weights in network.trainable_variables:
            loss += weight_decay * tf.nn.l2_loss(weights)

        print(f"Loss: {loss}")

    gradients = tape.gradient(loss, network.trainable_variables)

    optimizer.apply_gradients(zip(gradients, network.trainable_variables))




######### End Training ###########
##################################


def launch_job(f, *args):
  f(*args)

In [14]:
!mkdir tmp
!mkdir gifs
!mkdir uci_pgn

def svgs_to_pngs(svgs_data):
  for i in range(len(svgs_data)):
    cairosvg.svg2png(bytestring=svgs_data[i], write_to=f"tmp/board_img_{i}.png")


def svgs_to_gif(svgs, game_idx):
  svgs_to_pngs(svgs)

  images = []
  for i in range(len(svgs)):  # Assumes all PNGs are in the current directory
      images.append(Image.open(f"tmp/board_img_{i}.png"))

  images[0].save(f'gifs/game_{game_idx}_moves_{len(svgs)}.gif', save_all=True, append_images=images[1:], duration=200, loop=0)

  return f'gifs/game_{game_idx}_moves_{len(svgs)}.gif'


def display_svg(board: chess.Board, svgs, show_svg: bool=False):
  boardsvg = chess.svg.board(board, size=350)
  svgs.append(boardsvg)
  if show_svg:
    display(IPython.display.HTML(boardsvg))

def visualize_game(config: Config, network: Network=None, show_svg: bool=True, store_gif: bool=False, game_idx: int=0):
  num_simulations = config.num_simulations
  num_sampling_moves = config.num_sampling_moves
  softmax_temperature = config.softmax_temperature
  if not network:
    network = Network(True)
    num_simulations = config.uniform_num_simulations
    num_sampling_moves = config.uniform_num_sampling_moves
    softmax_temperature = config.uniform_softmax_temperature

  game = Game()
  svgs = []

  # initial board
  display_svg(game.environment.board, svgs, show_svg)

  while not game.terminal() and len(game.history) < config.max_moves:
    action, root = run_mcts(config, game, network, num_simulations, num_sampling_moves, softmax_temperature)
    game.apply(action)
    display_svg(game.environment.board, svgs, show_svg)
    game.store_search_statistics(root)

  print("Game Ended with ply: ", game.environment.board.ply())

  uci_pgn = ""
  count = 2
  for move in game.environment.board.move_stack:
    if count % 10 == 0:
      uci_pgn += "\n"
    if count % 2 == 0:
      uci_pgn += f"{count//2}."
    uci_pgn += move.uci() + " "
    count += 1

  with open(f"uci_pgn/game_{game_idx}_moves_{len(svgs)}.txt", "w") as text_file:
      text_file.write(uci_pgn)

  if store_gif:
    file_name = svgs_to_gif(svgs, game_idx)
    print("Gif created at: ", file_name)

  return game

mkdir: cannot create directory ‘tmp’: File exists
mkdir: cannot create directory ‘gifs’: File exists
mkdir: cannot create directory ‘uci_pgn’: File exists


In [15]:
!mkdir games

def dump_game(game, name):
  with open(f'games/{name}.game', 'wb') as game_file:
    pickle.dump(game, game_file)
  return f'games/{name}.game'

mkdir: cannot create directory ‘games’: File exists


In [41]:
def test_visualize_game(config:Config=None, network:Network=None, SVG=True, GIF=True):
  if not config:
    config = Config()

  for i in range(1000):
    game = visualize_game(config, network=network, show_svg=SVG, store_gif=GIF, game_idx=i)
    name = dump_game(game, f"uniform_{i}")
    print("Game stored at: ", name)

In [17]:
#test_visualize_game()

In [18]:
!mkdir pklbuf

class PickledBuffer(object):

  def __init__(self, name: str = "pickled_buffer", max_buffer_size:int=128):
    self.max_buffer_size = max_buffer_size
    self.idx = 0
    self.name = name
    self.buffer = []

  def offload(self):
    print(f"Offloading pickled buffer '{self.name}' with {len(self.buffer)} games at 'pklbuf/{self.name}_{self.idx}.pb'.")
    if self.buffer:
      with open(f'pklbuf/{self.name}_{self.idx}.pb', 'wb') as pickled_buffer:
        pickle.dump(self.buffer, pickled_buffer)
      self.buffer.clear()
      files.download(f'pklbuf/{self.name}_{self.idx}.pb')
      self.idx += 1

  def save_game(self, game):
    if len(self.buffer) >= self.max_buffer_size:
      self.offload()
    self.buffer.append(game)

  def size(self):
    return len(self.buffer)

  def pickles(self):
    return self.idx



def merge_pickled_buffers(config: Config, pickles:int = None, name:str ="pickled_buffer") -> ReplayBuffer:
  replay_buffer = ReplayBuffer(config)
  if pickles:
    for i in range(pickles):
      with open(f'pklbuf/{name}_{i}.pb', 'rb') as pickled_buffer:
        pickled_buffer = pickle.load(pickled_buffer)
      replay_buffer.buffer += pickled_buffer
  else:
    idx = 0
    while True:
      try:
        with open(f'pklbuf/{name}_{idx}.pb', 'rb') as pickled_buffer:
          pickled_buffer = pickle.load(pickled_buffer)
          replay_buffer.extend(pickled_buffer)
          idx += 1
      except:
        break
  return replay_buffer

mkdir: cannot create directory ‘pklbuf’: File exists


In [19]:
!mkdir buffer
def fill_buffer(replay_buffer, config: Config = None, storage: SharedStorage=None, num_games: int=None):
  if not config:
    config = Config()
  if not storage:
    storage = SharedStorage()

  if not num_games:
    num_games = config.window_size

  for i in range(num_games):

    network = storage.latest_network()
    game = play_game(config, network)
    replay_buffer.save_game(game)
    if type(replay_buffer) == ReplayBuffer and i % 128 == 127:
      with open(f'buffer/uniform_step_0_sim.buffer', 'wb') as buffer:
        pickle.dump(replay_buffer, buffer)
      print("Checkpoint: ", i // 128)

mkdir: cannot create directory ‘buffer’: File exists


In [44]:
config = Config()
config.num_simulations=10
pickled_buffer = PickledBuffer(name="pickled_buffer", max_buffer_size=256)
storage = SharedStorage()
config.uniform_num_simulations=1
fill_buffer(config=config, replay_buffer=pickled_buffer, storage=storage, num_games=int(1e5))

KeyboardInterrupt: 

In [None]:
# with open(f'buffer/uniform_step_0_sim.buffer', 'rb') as buffer:
#   replay_buffer = pickle.load(buffer)

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [35]:
pickled_buffer.offload() # clear buffer
pickles = pickled_buffer.pickles()

replay_buffer = merge_pickled_buffers(config, pickles=pickles)
print(replay_buffer.size())

Offloading pickled buffer 'pickled_buffer' with 62 games at 'pklbuf/pickled_buffer_0.pb'.


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

62


In [None]:
!mkdir networks
replay_buffer.batch_size=32
config.training_steps=1000
train_network(config, storage, replay_buffer)

mkdir: cannot create directory ‘networks’: File exists
Training network with 1000 steps and batch size 32.
Optimizer configuration: 
{'name': 'SGD', 'weight_decay': None, 'clipnorm': None, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'jit_compile': False, 'is_legacy_optimizer': False, 'learning_rate': {'module': None, 'class_name': 'ZeroLearningRateSchedule', 'config': {'lr_schedule': {0: 0.2, 100000: 0.02, 300000: 0.002, 500000: 0.0002}}, 'registered_name': 'ZeroLearningRateSchedule'}, 'momentum': 0.9, 'nesterov': True}
Sampling 32 from: 62
Loss: 47.891963958740234
Sampling 32 from: 62
Loss: 1.1820210913866364e+20
Sampling 32 from: 62
Loss: 3.8027931202610864e+24
Sampling 32 from: 62
Loss: 7.940208845242005e+27
Sampling 32 from: 62
Loss: 2.0516471544975988e+30
Sampling 32 from: 62
Loss: 3.447212321864232e+32


In [42]:
config = Config()
network = storage.latest_network()
test_visualize_game(config, network, SVG=True, GIF=True)

KeyboardInterrupt: 

In [None]:
game = Game()
image = game.make_image()