<a href="https://colab.research.google.com/github/caleb-code/self-attention-chess/blob/main/Chess_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install chess
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip install berserk

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 55.6 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... done
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=0b83eda19749b195b8fdf49c655be0db0daac58c7a6f0a832095a17a3

In [None]:
#@title Initialize functions
import numpy as np
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Post-norm residual MLP layer: ``LayerNorm(x + Dropout(ReLU(Linear(x))))``.

    Input and output both have shape ``[..., hidden_dim]``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        residual = x
        out = self.linear(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = out + residual
        out = self.norm(out)
        return out


class ChessAttentionModel(nn.Module):
    """Single-head self-attention value network over the board squares.

    Each square's ``feature_dim`` piece vector is projected to ``hidden_dim``
    and given a learned per-square positional embedding. The squares are then
    mixed by one self-attention layer wrapped in a post-norm residual, refined
    by four ``ResidualBlock``s and mean-pooled over squares. Finally they are
    mapped to a scalar in (-1, 1) by ``tanh``.

    Args:
        hidden_dim: Width of the per-square embedding.
        num_squares: Sequence length; 64 for a standard chessboard.
        feature_dim: Per-square input features; 12 = 6 white + 6 black piece types.
    """

    def __init__(self, hidden_dim, num_squares=64, feature_dim=12):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_squares = num_squares  # 64 for a standard chessboard
        self.feature_dim = feature_dim  # 12 = 6 white + 6 black piece types
        # Attention temperature sqrt(d). A Python float broadcasts against
        # tensors on any device, so it never needs to follow .to(device).
        self.scale = float(hidden_dim) ** 0.5

        self.input_projection = nn.Linear(feature_dim, hidden_dim)

        # One learned embedding per square, broadcast over the batch:
        # shape [1, num_squares, hidden_dim].
        self.pos_embeddings = nn.Parameter(torch.randn(1, num_squares, hidden_dim))

        self.w_V = nn.Linear(hidden_dim, hidden_dim)
        self.w_K = nn.Linear(hidden_dim, hidden_dim)
        self.w_Q = nn.Linear(hidden_dim, hidden_dim)
        self.final_linear = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(0.12)
        self.norm = nn.LayerNorm(hidden_dim)
        self.layers = nn.ModuleList([ResidualBlock(hidden_dim) for _ in range(4)])
        self.tanh = nn.Tanh()

    def attention(self, x):
        """Scaled dot-product self-attention (one head, no output projection)."""
        V = self.w_V(x)
        K = self.w_K(x)
        Q = self.w_Q(x)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def forward(self, x):
        """Score a batch of encoded boards.

        Args:
            x: Float tensor of shape ``[batch, num_squares, feature_dim]``.

        Returns:
            Tensor of shape ``[batch, 1]`` with values in (-1, 1).
        """
        x = self.input_projection(x)  # [batch, num_squares, hidden_dim]

        assert x.shape[1] == self.pos_embeddings.shape[1], f"Sequence length mismatch: x dim1={x.shape[1]}, pos dim1={self.pos_embeddings.shape[1]}"

        x = x + self.pos_embeddings  # broadcasts over the batch dimension
        # Post-norm residual around attention: the skip path carries the
        # embedded input, and dropout applies only to the attention branch.
        # NOTE: checkpoints trained without this skip path score differently.
        x_o = self.norm(x + self.dropout(self.attention(x)))
        for layer in self.layers:
            x_o = layer(x_o)
        x_o = x_o.mean(dim=1)  # mean-pool over squares: [batch, hidden_dim]
        return self.tanh(self.final_linear(x_o))  # [batch, 1]

def board_to_tensor(board):
    """Encode a board as a ``[1, 64, 12]`` float32 one-hot piece tensor.

    Row ``i`` is python-chess square index ``i`` (0..63). A white piece sets
    channel ``piece_type - 1`` (0=P, 1=N, 2=B, 3=R, 4=Q, 5=K). A black piece
    sets channel ``6 + piece_type - 1``. Empty squares stay all zeros. Only
    piece placement is encoded (no side to move, castling or en passant).

    Args:
        board: A ``chess.Board`` (anything exposing ``piece_map()``).

    Returns:
        torch.Tensor of shape ``[1, 64, 12]``.
    """
    tensor = torch.zeros(1, 64, 12, dtype=torch.float32)
    # piece_map() yields only occupied squares instead of 64 piece_at() probes.
    for square, piece in board.piece_map().items():
        # Piece.color is a bool (chess.WHITE is True, chess.BLACK is False).
        # Testing it directly means this cell does not need `chess` imported.
        channel = piece.piece_type - 1 + (0 if piece.color else 6)
        tensor[0, square, channel] = 1.0
    return tensor
def get_first_cp(item):
    """Return the first centipawn score found in a Lichess eval record.

    Scans ``item['evals']`` in order and, within each eval, its ``'pvs'`` list
    in order. Returns the first ``'cp'`` value encountered; PV entries without
    a ``'cp'`` key are skipped. Missing ``'evals'``/``'pvs'`` keys are treated
    as empty. Returns ``None`` when no centipawn score exists anywhere.
    """
    centipawns = (
        pv['cp']
        for evaluation in item.get('evals', [])
        for pv in evaluation.get('pvs', [])
        if 'cp' in pv
    )
    return next(centipawns, None)

In [None]:
#@title Download dataset from lichess
!apt install zstd
!wget https://database.lichess.org/lichess_db_eval.jsonl.zst
!zstd -d lichess_db_eval.jsonl.zst -o data.json
!pip install ijson

import json
import chess
from tqdm import tqdm

from itertools import islice

x_list, y_list = [], []

max_items = 10_000_000

# JSON Lines: one position record per line. Read as UTF-8 explicitly rather
# than relying on the platform's locale-dependent default encoding.
with open("data.json", "r", encoding="utf-8") as f:
    # islice stops after max_items lines (records with or without a cp score
    # both count toward the limit) and lets tqdm show a known total.
    for line in tqdm(islice(f, max_items), total=max_items):
        item = json.loads(line)

        cp = get_first_cp(item)
        if cp is not None:
            x_list.append(board_to_tensor(chess.Board(item['fen'])))
            # Squash centipawns into (-1, 1) to match the model's tanh head.
            y_list.append(torch.tanh(torch.tensor(cp / 400, dtype=torch.float32)))

In [None]:
#@title Save data to .pt
# Stack the per-position samples: boards -> [N, 1, 64, 12], labels -> [N].
x_tensor, y_tensor = torch.stack(x_list), torch.stack(y_list)
torch.save(dict(x=x_tensor, y=y_tensor), 'data_chunk.pt')

In [None]:
#@title Save model to .pt
# torch.compile returns a wrapper (OptimizedModule) holding the real module in
# `_orig_mod`. Pickle the plain module so the file carries no compiler wrapper
# and can be compiled cleanly after loading. Uncompiled models pass through.
torch.save(getattr(model, "_orig_mod", model), "model.pt")

In [None]:
#@title Load dataset
# The chunk is a plain dict of tensors, so the safe weights-only loader suffices.
l = torch.load("data_chunk.pt", weights_only=True)
x_lim = -1 #@param
y_lim = -1 #@param
# A negative limit means "use every sample". Slicing with [0:-1] would
# silently drop the last sample, so negatives map to None (no upper bound).
x_tensor = l['x'][:None if x_lim < 0 else x_lim]
y_tensor = l['y'][:None if y_lim < 0 else y_lim]

In [None]:
#@title Initialize model for training
hidden_dim = 512 #@param
# Prefer the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build, move and wrap in one expression; compilation itself happens lazily
# on the first forward call.
model = torch.compile(ChessAttentionModel(hidden_dim=hidden_dim).to(device))
print("Model compiled successfully!")
torch.set_float32_matmul_precision("high")

Model compiled successfully!


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.pt is a full pickled nn.Module written by this notebook, so
# weights_only=False is required. Never point this at an untrusted file.
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
model = torch.load("model.pt", weights_only=False, map_location=device)
# If the pickle holds a torch.compile wrapper, unwrap it so we do not
# compile an already-compiled module.
model = getattr(model, "_orig_mod", model)
model = torch.compile(model.to(device))
torch.set_float32_matmul_precision("high")

In [None]:
#@title Count model parameters
def count_parameters(model):
    """Return the total number of scalar values in *model*'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted. Each parameter is
    counted once, as yielded by ``model.parameters()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

num_params = count_parameters(model)
# Thousands separators keep the large count readable.
print("The model has {:,} trainable parameters.".format(num_params))

The model has 1,883,649 trainable parameters.


In [None]:
#@title Train Model
from tqdm import tqdm
print(device)
num_epochs = 50 #@param
batch_size = 64 #@param
learning_rate = 0.0001 #@param

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

print("Starting training on loaded data...")

# Create DataLoader for batching
dataset = torch.utils.data.TensorDataset(x_tensor, y_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Samples processed so far this epoch. This is the correct denominator
    # for the live average; tqdm's `n` lags the current batch and a fixed
    # batch size is wrong for the final, possibly smaller, batch.
    samples_seen = 0
    pbar = tqdm(dataloader)
    for inputs, labels in pbar:
        # Stored boards are [N, 1, 64, 12]; the model expects [batch, 64, 12].
        inputs = inputs.squeeze(1).to(device)
        # Labels are [batch]; add a dimension to match the [batch, 1] output.
        labels = labels.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # MSELoss averages over the batch; re-weight by batch size so the
        # running and epoch averages are per-sample.
        batch_len = inputs.size(0)
        running_loss += loss.item() * batch_len
        samples_seen += batch_len
        pbar.set_postfix({"loss": f"{running_loss / samples_seen:.2f}"})
    epoch_loss = running_loss / len(dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

print("Training on loaded data finished.")

cuda
Starting training on loaded data...


100%|██████████| 40803/40803 [04:10<00:00, 163.09it/s, loss=0.19]


Epoch [1/50], Loss: 0.1882


100%|██████████| 40803/40803 [04:08<00:00, 163.92it/s, loss=0.19]


Epoch [2/50], Loss: 0.1866


100%|██████████| 40803/40803 [04:11<00:00, 162.39it/s, loss=0.19]


Epoch [3/50], Loss: 0.1856


100%|██████████| 40803/40803 [04:08<00:00, 164.02it/s, loss=0.19]


Epoch [4/50], Loss: 0.1860


100%|██████████| 40803/40803 [04:09<00:00, 163.34it/s, loss=0.19]


Epoch [5/50], Loss: 0.1861


100%|██████████| 40803/40803 [04:11<00:00, 162.26it/s, loss=0.19]


Epoch [6/50], Loss: 0.1858


100%|██████████| 40803/40803 [04:09<00:00, 163.51it/s, loss=0.19]


Epoch [7/50], Loss: 0.1852


100%|██████████| 40803/40803 [04:08<00:00, 164.06it/s, loss=0.18]


Epoch [8/50], Loss: 0.1841


100%|██████████| 40803/40803 [04:09<00:00, 163.66it/s, loss=0.18]


Epoch [9/50], Loss: 0.1838


100%|██████████| 40803/40803 [04:08<00:00, 163.97it/s, loss=0.18]


Epoch [10/50], Loss: 0.1833


100%|██████████| 40803/40803 [04:11<00:00, 161.94it/s, loss=0.18]


Epoch [11/50], Loss: 0.1834


100%|██████████| 40803/40803 [04:10<00:00, 162.89it/s, loss=0.18]


Epoch [12/50], Loss: 0.1836


100%|██████████| 40803/40803 [04:10<00:00, 162.86it/s, loss=0.18]


Epoch [13/50], Loss: 0.1832


100%|██████████| 40803/40803 [04:09<00:00, 163.25it/s, loss=0.18]


Epoch [14/50], Loss: 0.1830


100%|██████████| 40803/40803 [04:09<00:00, 163.24it/s, loss=0.18]


Epoch [15/50], Loss: 0.1836


 23%|██▎       | 9425/40803 [00:57<03:08, 166.06it/s, loss=0.18]

In [None]:
def get_best_move(board: chess.Board, model, turn, *, verbose=False):
  """Pick the legal move whose resulting position the model scores best for `turn`.

  Every legal move is played on a copy of `board`. The resulting positions
  are encoded with `board_to_tensor` and scored in a single batch.
  `board_to_tensor` has no side-to-move plane, so the score is from a fixed
  colour's point of view. That colour is assumed to be White, matching
  Lichess centipawn evals (TODO confirm against the dataset docs). White
  therefore picks the highest score and Black the lowest.

  Args:
    board: Position to move from; it is not modified.
    model: Module mapping [n, 64, 12] boards to [n, 1] scores.
    turn: Side to move, chess.WHITE (True) or chess.BLACK (False).
    verbose: If True, print the batch shape and score statistics.

  Returns:
    The chosen chess.Move.

  Raises:
    ValueError: If `board` has no legal moves.
  """
  possible_moves = list(board.legal_moves)
  if not possible_moves:
    raise ValueError("No legal moves available in this position.")

  possible_boards = []
  for move in possible_moves:
    b = board.copy()
    b.push(move)
    possible_boards.append(board_to_tensor(b).squeeze(0))

  # Send the batch to wherever the model's weights live.
  model_device = next(model.parameters()).device
  p_b = torch.stack(possible_boards).to(model_device)

  model.eval()
  # Inference only: skip building an autograd graph.
  with torch.no_grad():
    prob_tens = model(p_b).cpu()

  if verbose:
    print(p_b.shape)
    print(torch.std(prob_tens), torch.mean(prob_tens), torch.max(prob_tens), torch.min(prob_tens))

  scores = prob_tens.squeeze(1)
  # Ties resolve to the first move in legal_moves order for either choice.
  idx = torch.argmax(scores) if turn == chess.WHITE else torch.argmin(scores)
  return possible_moves[int(idx)]

In [None]:
def play_text_game(model):
    """Play an interactive text game in which the model is White and you are Black.

    The model moves first via `get_best_move`. You are then prompted for a
    UCI move (e.g. "e7e5") until a legal one is entered. The final result is
    printed exactly once when the game ends.
    """
    board = chess.Board()
    print("Starting a text-based chess game against the model.")
    print(board)

    while not board.is_game_over():
        # Model's turn (it always moves first, i.e. plays White).
        print("\nModel is thinking...")
        model_move = get_best_move(board, model, board.turn)
        board.push(model_move)
        print("Model's move:")
        print(board)

        if board.is_game_over():
            break

        # Player's turn: keep prompting until a legal UCI move is entered.
        while True:
            move_uci = input("Enter your move (in UCI format, e.g., e2e4): ")
            try:
                move = board.parse_uci(move_uci)
            except chess.IllegalMoveError:
                # Well-formed UCI that is not legal in this position.
                print("Illegal move. Try again.")
                continue
            except ValueError:
                # Syntactically invalid UCI (chess.InvalidMoveError).
                print("Invalid UCI format. Try again.")
                continue
            if move in board.legal_moves:
                break
            # parse_uci returns the null move for "0000" instead of raising.
            print("Illegal move. Try again.")

        board.push(move)
        print("\nYour move:")
        print(board)

    # Single place that reports the result, whichever side ended the game.
    print(f"Game over: {board.result()}")

# `model` must already exist, either trained above or loaded from disk, e.g.:
# model = torch.load("model.pt", weights_only=False)
# Start an interactive game with the model playing White.

play_text_game(model=model)