In [1]:
import os

import chess
import chess.svg
from jax import random as jrandom
import numpy as np

In [2]:
from searchless_chess.src import tokenizer
from searchless_chess.src import training_utils
from searchless_chess.src import transformer
from searchless_chess.src import utils
from searchless_chess.src.engines import engine
from searchless_chess.src.engines import neural_engines
import searchless_chess.src.engines.constants as constants

In [3]:
# @title Create the predictor.

policy = 'action_value'
num_return_buckets = 128

# Pick the size of the model's output layer from the selected policy: the two
# value policies use one output per return bucket, behavioral cloning uses one
# output per action. Any other policy name is rejected.
if policy in ('action_value', 'state_value'):
  output_size = num_return_buckets
elif policy == 'behavioral_cloning':
  output_size = utils.NUM_ACTIONS
else:
  raise ValueError(f'Unknown policy: {policy}')

predictor_config = transformer.TransformerConfig(
    # Input / output sizes.
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    # NOTE(review): the `+ 2` presumably leaves room for two extra tokens on
    # top of the tokenized board — confirm against the tokenizer.
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    # Model dimensions.
    embedding_dim=256,
    num_layers=8,
    num_heads=8,
    # Architecture switches.
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

predictor = transformer.build_transformer_predictor(config=predictor_config)

In [4]:
# @title Load the predictor parameters

# The checkpoint directory is resolved relative to the current working
# directory, so this cell expects the notebook to be run from a directory that
# has `checkpoints/9M` next to it (one level up).
checkpoint_dir = os.path.join(os.getcwd(), '../checkpoints/9M')

# Freshly initialised parameters; they are only handed to `load_parameters`
# below (presumably as a structure template for the checkpoint — TODO confirm).
dummy_params = predictor.initial_params(
    targets=np.zeros((1, 1), dtype=np.uint32),
    rng=jrandom.PRNGKey(6400000),
)

# NOTE(review): `step=-1` presumably selects the most recent checkpoint —
# confirm against `training_utils.load_parameters`.
params = training_utils.load_parameters(
    step=-1,
    use_ema_params=True,
    params=dummy_params,
    checkpoint_dir=checkpoint_dir,
)

Expected keys: dict_keys(['embed', 'embed_1', 'layer_norm', 'multi_head_dot_product_attention/linear', 'multi_head_dot_product_attention/linear_1', 'multi_head_dot_product_attention/linear_2', 'multi_head_dot_product_attention/linear_3', 'layer_norm_1', 'linear', 'linear_1', 'linear_2', 'layer_norm_2', 'multi_head_dot_product_attention_1/linear', 'multi_head_dot_product_attention_1/linear_1', 'multi_head_dot_product_attention_1/linear_2', 'multi_head_dot_product_attention_1/linear_3', 'layer_norm_3', 'linear_3', 'linear_4', 'linear_5', 'layer_norm_4', 'multi_head_dot_product_attention_2/linear', 'multi_head_dot_product_attention_2/linear_1', 'multi_head_dot_product_attention_2/linear_2', 'multi_head_dot_product_attention_2/linear_3', 'layer_norm_5', 'linear_6', 'linear_7', 'linear_8', 'layer_norm_6', 'multi_head_dot_product_attention_3/linear', 'multi_head_dot_product_attention_3/linear_1', 'multi_head_dot_product_attention_3/linear_2', 'multi_head_dot_product_attention_3/linear_3', 'l

In [5]:
# @title Create the engine

# `get_uniform_buckets_edges_values` yields a pair; only the second element
# (the per-bucket values) is needed here, the first is discarded.
_, return_buckets_values = utils.get_uniform_buckets_edges_values(
    num_return_buckets
)

# Bind the predictor to the loaded parameters, evaluating with batch size 1.
predict_fn = neural_engines.wrap_predict_fn(predictor, params, batch_size=1)

# Look up the engine class registered for the chosen policy and instantiate it.
neural_engine = neural_engines.ENGINE_FROM_POLICY[policy](
    predict_fn=predict_fn,
    return_buckets_values=return_buckets_values,
    temperature=0.005,
)

In [6]:
# @title Play a move with the agent

# Start from the standard initial position and ask the engine for its move.
board = chess.Board()
best_move = neural_engine.play(board)
print(f'Best move: {best_move}')

# Display the raw per-bucket log-probabilities from the engine's analysis as
# the cell's output value.
neural_engine.analyse(board)['log_probs']

Best move: g1f3


array([[-16.972721, -18.297773, -17.786018, ..., -19.471888, -19.739819,
        -17.669365],
       [-17.332315, -19.460886, -18.252579, ..., -18.685335, -19.400723,
        -18.843721],
       [-18.227547, -21.45101 , -20.34803 , ..., -18.866692, -19.54095 ,
        -18.878246],
       ...,
       [-16.26824 , -17.879591, -17.651842, ..., -18.692684, -18.872377,
        -17.201775],
       [-17.052387, -19.141958, -18.439047, ..., -19.566021, -19.94304 ,
        -18.54868 ],
       [-16.913742, -18.775784, -18.253054, ..., -19.214146, -19.735785,
        -17.74015 ]], shape=(20, 128), dtype=float32)

In [13]:
import math


def win_probability_to_centipawns(win_probability: float) -> int:
  """Returns the centipawn score converted from the win probability (in [0, 1]).

  Computes `-1 / 0.00368208 * ln((1 - p) / p)` and truncates the result toward
  zero with `int()`. A probability of 0.5 maps to 0 centipawns; larger
  probabilities map to positive scores and smaller ones to negative scores.

  The exact endpoints 0 and 1 have no finite value under this formula (they
  would divide by zero or take `log(0)`), so the probability is first clamped
  into `[1e-6, 1 - 1e-6]`. This keeps every input in the documented range
  valid and bounds the result to roughly +/-3752 centipawns.

  Args:
    win_probability: The win probability in the range [0, 1].

  Returns:
    The centipawn score as an integer (truncated toward zero).

  Raises:
    ValueError: If `win_probability` is outside [0, 1]. NaN also fails the
      range check and raises.
  """
  if not 0 <= win_probability <= 1:
    raise ValueError("Win probability must be in the range [0, 1].")

  # Keep the probability strictly inside (0, 1) so that neither the division
  # nor the logarithm below can fail for the endpoint inputs 0 and 1.
  epsilon = 1e-6
  clamped = min(max(win_probability, epsilon), 1 - epsilon)

  centipawns = -1 / 0.00368208 * math.log((1 - clamped) / clamped)
  return int(centipawns)

In [19]:
# @title Compute the win percentages for all legal moves

board = chess.Board()
results = neural_engine.analyse(board)
buckets_log_probs = results['log_probs']

# Expected return per move: exponentiate the log-probabilities and take the
# inner product with the bucket values.
win_probs = np.inner(np.exp(buckets_log_probs), return_buckets_values)
sorted_legal_moves = engine.get_ordered_legal_moves(board)

print(sorted_legal_moves)
print('Win percentages:')
print(max(win_probs))

# Walk the moves from highest to lowest win probability. Each move's index
# into `sorted_legal_moves` is printed on its own line before its summary.
descending_order = np.argsort(win_probs)[::-1]
for move_idx in descending_order:
  print(move_idx)
  move_win_prob = win_probs[move_idx]
  cp = win_probability_to_centipawns(move_win_prob)
  move_uci = sorted_legal_moves[move_idx].uci()
  print(f'  {move_uci} -> {100*move_win_prob:.1f}% cp: {cp}')

[Move.from_uci('b1a3'), Move.from_uci('b1c3'), Move.from_uci('g1f3'), Move.from_uci('g1h3'), Move.from_uci('a2a3'), Move.from_uci('a2a4'), Move.from_uci('b2b3'), Move.from_uci('b2b4'), Move.from_uci('c2c3'), Move.from_uci('c2c4'), Move.from_uci('d2d3'), Move.from_uci('d2d4'), Move.from_uci('e2e3'), Move.from_uci('e2e4'), Move.from_uci('f2f3'), Move.from_uci('f2f4'), Move.from_uci('g2g3'), Move.from_uci('g2g4'), Move.from_uci('h2h3'), Move.from_uci('h2h4')]
Win percentages:
0.5351641747458535
13
  e2e4 -> 53.5% cp: 38
11
  d2d4 -> 53.1% cp: 33
2
  g1f3 -> 52.6% cp: 28
9
  c2c4 -> 52.4% cp: 26
12
  e2e3 -> 51.6% cp: 17
1
  b1c3 -> 51.3% cp: 13
16
  g2g3 -> 50.8% cp: 8
8
  c2c3 -> 50.0% cp: 0
4
  a2a3 -> 49.6% cp: -4
10
  d2d3 -> 49.5% cp: -5
6
  b2b3 -> 49.1% cp: -9
18
  h2h3 -> 48.8% cp: -12
5
  a2a4 -> 48.0% cp: -21
15
  f2f4 -> 48.0% cp: -22
19
  h2h4 -> 47.2% cp: -30
7
  b2b4 -> 45.7% cp: -46
0
  b1a3 -> 44.9% cp: -56
3
  g1h3 -> 44.1% cp: -64
14
  f2f3 -> 43.7% cp: -68
17
  g2g4 -> 