In [1]:
import os

import chess
import chess.svg
from jax import random as jrandom
import numpy as np

In [2]:
from searchless_chess.src import tokenizer
from searchless_chess.src import training_utils
from searchless_chess.src import transformer
from searchless_chess.src import utils
from searchless_chess.src.engines import engine
from searchless_chess.src.engines import neural_engines
import searchless_chess.src.engines.constants as constants

In [3]:
# @title Create the predictor.

policy = 'action_value'
num_return_buckets = 128

match policy:
  case 'action_value':
    output_size = num_return_buckets
  case 'behavioral_cloning':
    output_size = utils.NUM_ACTIONS
  case 'state_value':
    output_size = num_return_buckets
  case _:
    raise ValueError(f'Unknown policy: {policy}')

predictor_config = transformer.TransformerConfig(
    vocab_size=utils.NUM_ACTIONS,
    output_size=output_size,
    pos_encodings=transformer.PositionalEncodings.LEARNED,
    max_sequence_length=tokenizer.SEQUENCE_LENGTH + 2,
    num_heads=4,
    num_layers=4,
    embedding_dim=64,
    apply_post_ln=True,
    apply_qk_layernorm=False,
    use_causal_mask=False,
)

predictor = transformer.build_transformer_predictor(config=predictor_config)

In [4]:
# @title Load the predictor parameters

checkpoint_dir = os.path.join(
    os.getcwd(),
    f'../checkpoints/9M',
)
dummy_params = predictor.initial_params(
    rng=jrandom.PRNGKey(6400000),
    targets=np.zeros((1, 1), dtype=np.uint32),
)
params = training_utils.load_parameters(
    checkpoint_dir=checkpoint_dir,
    params=dummy_params,
    use_ema_params=True,
    step=-1,
)

Expected keys: dict_keys(['embed', 'embed_1', 'layer_norm', 'multi_head_dot_product_attention/linear', 'multi_head_dot_product_attention/linear_1', 'multi_head_dot_product_attention/linear_2', 'multi_head_dot_product_attention/linear_3', 'layer_norm_1', 'linear', 'linear_1', 'linear_2', 'layer_norm_2', 'multi_head_dot_product_attention_1/linear', 'multi_head_dot_product_attention_1/linear_1', 'multi_head_dot_product_attention_1/linear_2', 'multi_head_dot_product_attention_1/linear_3', 'layer_norm_3', 'linear_3', 'linear_4', 'linear_5', 'layer_norm_4', 'multi_head_dot_product_attention_2/linear', 'multi_head_dot_product_attention_2/linear_1', 'multi_head_dot_product_attention_2/linear_2', 'multi_head_dot_product_attention_2/linear_3', 'layer_norm_5', 'linear_6', 'linear_7', 'linear_8', 'layer_norm_6', 'multi_head_dot_product_attention_3/linear', 'multi_head_dot_product_attention_3/linear_1', 'multi_head_dot_product_attention_3/linear_2', 'multi_head_dot_product_attention_3/linear_3', 'l

ValueError: Dict key mismatch; expected keys: ['embed', 'embed_1', 'layer_norm', 'layer_norm_1', 'layer_norm_10', 'layer_norm_11', 'layer_norm_12', 'layer_norm_13', 'layer_norm_14', 'layer_norm_15', 'layer_norm_16', 'layer_norm_2', 'layer_norm_3', 'layer_norm_4', 'layer_norm_5', 'layer_norm_6', 'layer_norm_7', 'layer_norm_8', 'layer_norm_9', 'linear', 'linear_1', 'linear_10', 'linear_11', 'linear_12', 'linear_13', 'linear_14', 'linear_15', 'linear_16', 'linear_17', 'linear_18', 'linear_19', 'linear_2', 'linear_20', 'linear_21', 'linear_22', 'linear_23', 'linear_24', 'linear_3', 'linear_4', 'linear_5', 'linear_6', 'linear_7', 'linear_8', 'linear_9', 'multi_head_dot_product_attention/linear', 'multi_head_dot_product_attention/linear_1', 'multi_head_dot_product_attention/linear_2', 'multi_head_dot_product_attention/linear_3', 'multi_head_dot_product_attention_1/linear', 'multi_head_dot_product_attention_1/linear_1', 'multi_head_dot_product_attention_1/linear_2', 'multi_head_dot_product_attention_1/linear_3', 'multi_head_dot_product_attention_2/linear', 'multi_head_dot_product_attention_2/linear_1', 'multi_head_dot_product_attention_2/linear_2', 'multi_head_dot_product_attention_2/linear_3', 'multi_head_dot_product_attention_3/linear', 'multi_head_dot_product_attention_3/linear_1', 'multi_head_dot_product_attention_3/linear_2', 'multi_head_dot_product_attention_3/linear_3', 'multi_head_dot_product_attention_4/linear', 'multi_head_dot_product_attention_4/linear_1', 'multi_head_dot_product_attention_4/linear_2', 'multi_head_dot_product_attention_4/linear_3', 'multi_head_dot_product_attention_5/linear', 'multi_head_dot_product_attention_5/linear_1', 'multi_head_dot_product_attention_5/linear_2', 'multi_head_dot_product_attention_5/linear_3', 'multi_head_dot_product_attention_6/linear', 'multi_head_dot_product_attention_6/linear_1', 'multi_head_dot_product_attention_6/linear_2', 'multi_head_dot_product_attention_6/linear_3', 'multi_head_dot_product_attention_7/linear', 'multi_head_dot_product_attention_7/linear_1', 'multi_head_dot_product_attention_7/linear_2', 'multi_head_dot_product_attention_7/linear_3']; dict: {'embed': {'embeddings': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(1968, 64), shape=(1968, 64), strict=True)}, 'embed_1': {'embeddings': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(79, 64), shape=(79, 64), strict=True)}, 'layer_norm': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_1': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_2': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_3': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_4': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_5': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_6': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_7': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'layer_norm_8': {'offset': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True), 'scale': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64,), shape=(64,), strict=True)}, 'linear': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_1': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_10': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_11': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(256, 64), shape=(256, 64), strict=True)}, 'linear_12': {'b': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(128,), shape=(128,), strict=True), 'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 128), shape=(64, 128), strict=True)}, 'linear_2': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(256, 64), shape=(256, 64), strict=True)}, 'linear_3': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_4': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_5': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(256, 64), shape=(256, 64), strict=True)}, 'linear_6': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_7': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'linear_8': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(256, 64), shape=(256, 64), strict=True)}, 'linear_9': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 256), shape=(64, 256), strict=True)}, 'multi_head_dot_product_attention/linear': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention/linear_1': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention/linear_2': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention/linear_3': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_1/linear': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_1/linear_1': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_1/linear_2': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_1/linear_3': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_2/linear': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_2/linear_1': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_2/linear_2': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_2/linear_3': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_3/linear': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_3/linear_1': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_3/linear_2': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}, 'multi_head_dot_product_attention_3/linear_3': {'w': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host), global_shape=(64, 64), shape=(64, 64), strict=True)}}.

In [40]:
# @title Create the engine

predict_fn = neural_engines.wrap_predict_fn(predictor, params, batch_size=1)
_, return_buckets_values = utils.get_uniform_buckets_edges_values(
    num_return_buckets
)

neural_engine = neural_engines.ENGINE_FROM_POLICY[policy](
    return_buckets_values=return_buckets_values,
    predict_fn=predict_fn,
    temperature=0.005,
)

NameError: name 'params' is not defined

In [48]:
neural_engine = constants.ENGINE_BUILDERS['9M']()

In [49]:
# @title Play a move with the agent
board = chess.Board()
best_move = neural_engine.play(board)
print(f'Best move: {best_move}')
(neural_engine.analyse(board)['log_probs'])

Best move: e2e4


array([[-16.969692, -18.26804 , -17.82215 , ..., -19.248909, -19.558874,
        -17.525417],
       [-17.43915 , -19.513908, -18.2545  , ..., -18.77199 , -19.534468,
        -19.028124],
       [-18.3756  , -21.608381, -20.526136, ..., -18.933329, -19.637878,
        -18.978556],
       ...,
       [-16.49742 , -18.085724, -17.889166, ..., -18.715448, -18.902546,
        -17.23208 ],
       [-17.110027, -19.15794 , -18.437391, ..., -19.49886 , -19.90754 ,
        -18.534254],
       [-16.919296, -18.759192, -18.264711, ..., -19.109028, -19.664448,
        -17.625923]], shape=(20, 128), dtype=float32)

In [None]:
# @title Compute the win percentages for all legal moves

board = chess.Board()
results = neural_engine.analyse(board)
buckets_log_probs = results['log_probs']

# Compute the expected return.
win_probs = np.inner(np.exp(buckets_log_probs), return_buckets_values)
sorted_legal_moves = engine.get_ordered_legal_moves(board)

print(board.fen())
print(f'Win percentages:')
for i in np.argsort(win_probs)[::-1]:
  print(f'  {sorted_legal_moves[i].uci()} -> {100*win_probs[i]:.1f}%')

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.