In [3]:
import torch
import torch.nn as nn

from datasets import load_dataset, Dataset as HFDataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer
import numpy as np
from prodigyopt import Prodigy

from IPython.display import HTML
import re

from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer, models, trainers
from tokenizers.pre_tokenizers import Sequence, Split
from tokenizers.normalizers import Replace, Sequence as NormSequence


# check if we are in a colab environment
try:
    from google.colab import output
    output.enable_custom_widget_manager()
    IN_COLAB = True
except ImportError:
    pass


from IPython.display import display, SVG, clear_output
import ipywidgets as widgets

import chess
import chess.svg


subset = []

In [25]:
# Load and subset the dataset

TRAIN_SIZE = 100

def clean_chess_moves(movetext):
    # Remove move numbers (e.g., '1.', '2.') using regex
    movetext = re.sub(r'\d+\.', '', movetext)

    # Remove final score (e.g., '1-0', '0-1', or '1/2-1/2')
    movetext = re.sub(r'\b(1-0|0-1|1/2-1/2)\b', '', movetext)

    # Remove extra spaces
    movetext = ' '.join(movetext.split())

    return movetext

if len(subset) == 0:
  ds = load_dataset("Lichess/standard-chess-games", split='train', streaming=True)
  for idx, example in enumerate(tqdm(ds, desc="Loading dataset")):
      subset.append(example['movetext'])
      if idx == TRAIN_SIZE:
          break

chess_dataset = HFDataset.from_dict({'text': subset})
chess_dataset = chess_dataset.map(lambda x: {'text' : clean_chess_moves(x['text'])})

# save to file 

save_path = "data/chess_dataset.txt"
with open(save_path, "w") as f:
    for movetext in chess_dataset: 
        f.write(movetext['text'])
    
print(f"Dataset saved to {save_path}")

Map: 100%|██████████| 101/101 [00:00<00:00, 11125.77 examples/s]

Dataset saved to data/chess_dataset.txt





In [26]:
# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'mps' if torch.mps.is_available() else device
print(f"Using device: {device}")

if device == 'cuda':
    torch.cuda.empty_cache()

Using device: mps


In [27]:
# let's train a smaller tokenizer 

tokenizer = Tokenizer(models.WordLevel(unk_token="[UNK]"))  

# Custom normalizer to ensure + and # are attached to moves
normalizer = NormSequence([
    Replace(r"(\S+)\s+([+#])", r"\1\2")  
])
tokenizer.normalizer = normalizer

# Custom pre-tokenizer sequence
pre_tokenizer_sequence = Sequence([
    # First split by whitespace
    Split(pattern=" ", behavior="removed"),
    # Then merge + and # with previous token if they were split
    Split(pattern=r"([+#])", behavior="merged_with_previous")
])

tokenizer.pre_tokenizer = pre_tokenizer_sequence

trainer = trainers.WordLevelTrainer(
    vocab_size=1000,
    special_tokens=["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"],
    min_frequency=1,
    show_progress=True
)

# Train tokenizer on the chess-dataset
tokenizer.train(files=["data/chess_dataset.txt"], trainer=trainer)

# Save the trained tokenizer to a file (not a directory)
tokenizer.save("data/custom_chess_tokenizer.json")

# 2. Wrap the tokenizer in a Hugging Face-compatible wrapper
hf_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
hf_tokenizer.save_pretrained("data/custom_chess_tokenizer")  # Now save to a directory for Hugging Face compatibility

# 3. Load and use the tokenizer with AutoTokenizer
trained_tokenizer = AutoTokenizer.from_pretrained("data/custom_chess_tokenizer")
trained_tokenizer.add_special_tokens({'pad_token': '[PAD]', 'eos_token' : '[EOS]'})
new_tokens = ["O-O", "O-O-O"]
trained_tokenizer.add_tokens(new_tokens)

# Test the trained tokenizer
test_text = "e4 e6 d4 b6 a3 Bb7 Nc3+ O-O"


tokens = trained_tokenizer.tokenize(test_text)
print(tokens)

# print vocabulary
vocab = trained_tokenizer.get_vocab()
print(f"Vocabulary size: {len(vocab)}")
print(list(vocab.keys())[:100])


['e4', 'e6', 'd4', 'b6', 'a3', 'Bb7', '[UNK]', 'O-O']
Vocabulary size: 1001
['Rcd1', 'Ng5+', 'Rfa8', 'bxc6', 'Bxc6', 'Rxf4', 'Ndxf2', 'Bxg4', 'cxd6', 'Qxf6+', 'Rxb5', 'Re7', 'Qe6', 'Qh7#e4', 'Kxf2', 'Qxa8', 'Rxh7', 'Kxb3', 'Rhb1', 'Rf1+', 'Rxc1+', 'axb6', 'Nxf1', 'Rad1+', 'Rxc3+', 'h2', 'Rxc6', 'Kc5e4', 'Nd7+e4', 'dxc6', 'Rf4#e4', 'fxe4', 'Nc1', 'Rhe1', 'Nxg6', 'Bc4', 'Nec1', 'Bxe6+', 'Rxc2+', 'Bg2', 'Bxh8e4', 'Rxa3', 'fxg3+', 'Rc2+', 'Bd7', 'Kb7', 'Kxe3', 'Qxa1+', 'Rd2+', 'Qxh1', 'Qg4+', 'Rxf7', 'Bxe6', 'Qxd5', 'Nf6', 'Qa4+', 'Bb4+', 'Qe4e6', 'Nh5', 'Nd5+', 'Qe5', 'Rbf8', 'Rc2#e4', 'Rxb6', 'Nbc6', 'Nfxd5', 'Qxc6', 'Rxe5', 'Nc1d3', 'Rxh3', 'h7', 'Nxe6', 'Nb5', 'Bf1', 'Rad1', 'Rb8+', 'Rbb3', 'Bxd2', 'bxa4', 'Bxd4+', 'Rdf8', 'Nxd2', 'Kc2', 'Bf2#e4', 'Kh7', 'Rh5', 'cxb3', 'Be7', 'cxb6', 'Re8+', 'Rc1+', 'Qe8e7+', 'Nh4', 'Qd3', 'Nxd3', 'Nbd2', 'Qxe4+', 'Bb5', 'Qd1#e4', 'Qg2+']


In [28]:

class LinearTokenPredictor(nn.Module):
    """
    This model is a simple linear model that predicts the next token in a sequence.

    Originally three layers:
    1. An input embedding layer with dimension $d=256$
    2. A linear layer mapping dxT to dxT (with standard masking for the next tokens during training)
    3. An output embedding layer mapping vectors of dimension $d$ to the vocabulary.

    But I also added a layer norm between 2 and 3 to help a bit.
    """
    def __init__(self, tokenizer, vocab_size:int , context_size: int=64, d:int =256, device:str = 'cuda'):
        super(LinearTokenPredictor, self).__init__()
        self.vocab_size = vocab_size
        self.d = d
        self.context_size = context_size
        linear_dim = context_size*d
        self.device = device
        self.tokenizer = tokenizer


        self.embedding = nn.Embedding(vocab_size, d)
        self.linear = nn.Parameter(torch.randn(linear_dim, linear_dim) / np.sqrt(linear_dim))
        self.output = nn.Linear(d, vocab_size, bias=False)

        self.layer_norm = nn.LayerNorm(linear_dim)

        self.mask = self.create_mask(d, context_size).T.to(device)


    def forward(self, x:torch.Tensor):
        """
        For training, we use a causal mask to limit the
        linear layer to only consider previous tokens.
        """
        x = self.embedding(x)


        # map from batch x seq x d to batch x (seq*d)
        x = x.view(x.size(0), -1)

        x = x @(self.linear*self.mask)
        x = self.layer_norm(x) # small addition

        # map back to batch x seq x d
        x = x.view(x.size(0), -1, self.d)
        x = self.output(x)

        return x

    def generate(self, token_list: torch.Tensor | str, n:int = 1,
                 return_html=True, html_font_size:int=12):
        """
        Given a tensor of token-ids, generate n tokens.
        """
        if isinstance(token_list, str):
          token_list = self.tokenizer(token_list)
          token_list = token_list['input_ids']
        else:
          token_list = token_list.tolist()

        len_list  = len(token_list)

        if len_list < self.context_size:
            token_list = token_list + [self.tokenizer.eos_token_id] * (self.context_size - len(token_list))


        with torch.no_grad():
          for i in range(n):
              # keep token list within context by using the last T tokens
              x = torch.tensor(
                  token_list[-self.context_size:]
                  ).unsqueeze(0).to(self.device)
              logits = self.forward(x)

              if len_list + i - 1 < self.context_size:
                curr_token_index = len_list + i - 1 # -1 because the first logit corresponds to P(token_1|token_0)
              else:
                curr_token_index = -1

              logits = logits[:,curr_token_index]
              tok = torch.argmax(logits, dim=-1).item()

              if (len_list + i) < self.context_size:
                token_list[len_list + i] = tok
              else:
                token_list.append(tok)


        answer = self.tokenizer.decode(token_list[len_list:], skip_special_tokens=False)
        if return_html:
          prompt = self.tokenizer.decode(token_list[:len_list], skip_special_tokens=True)
          return (f'<span style="font-size: {html_font_size}px"> <span style="color: green;">'
                  + prompt + '</span> '
                  + answer + '</span></br>')
        else:
          return answer


    def create_mask(self, d:int , T:int):
      mask = np.tril(np.ones((T, T)))
      expanded_mask = np.kron(mask, np.ones((d, d)))
      expanded_mask = torch.tensor(expanded_mask, dtype=torch.float32)
      return expanded_mask



In [29]:
train_size = TRAIN_SIZE
MAX_LENGTH = 128



def truncate(text, max_length):
    """
    Randomly shuffle the start of the text to create different starting points,
    and truncate to max_length if necessary.
    """
    tokens = text.split()

    return ' '.join(tokens[:max_length])


# TODO remove numbers from text, only keep moves.

processed_ds = chess_dataset.map(lambda x:
                                      {'text': truncate(x['text'], MAX_LENGTH)})

train_set = processed_ds.map(lambda x: trained_tokenizer(x['text'],
                                                 padding='max_length',
                                                 max_length=MAX_LENGTH,
                                                 truncation=True,
                                                 return_tensors='pt'), batched=True)

train_set.set_format(type='torch', columns=['input_ids'])
train_loader = torch.utils.data.DataLoader(train_set,
                                           batch_size=1024, shuffle=True)

Map: 100%|██████████| 101/101 [00:00<00:00, 2817.68 examples/s]
Map: 100%|██████████| 101/101 [00:00<00:00, 2181.90 examples/s]


In [30]:
vocab_size = trained_tokenizer.vocab_size
print(vocab_size)

context_size = MAX_LENGTH -1
d = 32 # if train_size is small, you may want to keep this small as well

model = None
model = LinearTokenPredictor(trained_tokenizer,
                             len(trained_tokenizer),
                             context_size=context_size,
                             d=d, device=device).to(device)

1000


In [31]:
print(f'parameter count: {sum(p.numel() for p in model.parameters())}')

print(f'Using device: {device}')


criterion = nn.CrossEntropyLoss(ignore_index = trained_tokenizer.pad_token_id)
optimizer = Prodigy(model.parameters())

loss_tracking = []

EPOCHS = 100
TOL_EARLY_STOP = 0.00001

for epoch in tqdm(range(EPOCHS)):
    for batch in train_loader:
        input_ids = batch['input_ids']

        inputs = input_ids[:, :-1]  # All tokens except the last one
        targets = input_ids[:, 1:] # All tokens except the first one

        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs).permute(0,2,1)

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        loss_tracking.append(loss.item())

    last_loss = sum(loss_tracking)/len(loss_tracking)

    if epoch % 10 == 0:
      print(f'Epoch {epoch} loss: {last_loss}')

    if last_loss < TOL_EARLY_STOP:
      break

    loss_tracking = []

parameter count: 16588288
Using device: mps


  3%|▎         | 3/100 [00:05<02:12,  1.37s/it]

Epoch 0 loss: 7.003628730773926


 13%|█▎        | 13/100 [00:05<00:14,  5.95it/s]

Epoch 10 loss: 4.917160511016846


 23%|██▎       | 23/100 [00:06<00:06, 11.34it/s]

Epoch 20 loss: 0.34555402398109436


 33%|███▎      | 33/100 [00:07<00:04, 14.18it/s]

Epoch 30 loss: 0.06236353516578674


 43%|████▎     | 43/100 [00:08<00:03, 14.52it/s]

Epoch 40 loss: 0.037004031240940094


 53%|█████▎    | 53/100 [00:08<00:03, 13.77it/s]

Epoch 50 loss: 0.023584822192788124


 63%|██████▎   | 63/100 [00:09<00:02, 14.04it/s]

Epoch 60 loss: 0.016063058748841286


 73%|███████▎  | 73/100 [00:10<00:01, 14.44it/s]

Epoch 70 loss: 0.011385445483028889


 83%|████████▎ | 83/100 [00:10<00:01, 14.63it/s]

Epoch 80 loss: 0.008320871740579605


 93%|█████████▎| 93/100 [00:11<00:00, 14.05it/s]

Epoch 90 loss: 0.0062261163257062435


100%|██████████| 100/100 [00:12<00:00,  8.25it/s]


In [32]:
HTML(model.generate(trained_tokenizer.encode("d4 d6",

                                                 return_tensors='pt' ).squeeze(),
                    n=50))

In [33]:

i = 2


def parse_chess_moves(game_string):
    # Remove move numbers and extra spaces
    cleaned_moves = re.sub(r"\d+\.\s*", "", game_string).strip()
    # Split moves into a list
    move_list = cleaned_moves.split()
    return move_list

print(trained_tokenizer.decode(inputs[i]))
HTML(model.generate(inputs[i][0:1], n=100, return_html=True, html_font_size=15))

e4 e5 Nf3 Nc6 Bc4 Nf6 Nc3 Bc5 a3 Bxf2+ Kxf2 Nd4 d3 Ng4+ Kf1 Qf6 h3 d5 Nxd5 Qe6 [UNK] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


In [None]:
if IN_COLAB:
    # Initialize the chess board
    board = chess.Board()
    tokenizer = trained_tokenizer

    def model_move(model, board, tokenizer, n, game):
        print(game)
        generated_moves = model.generate(game, n=n, return_html=False).split()
        print(generated_moves)

        for move in generated_moves:
            try:
                board.push_san(move)
                return move
            except ValueError:
                continue
        return None

    # Create a vertical box layout for the game display
    game_display = widgets.Output()
    controls = widgets.VBox([
        widgets.Text(
            value='',
            placeholder='Type e4, Nf3, etc.',
            description='Move:',
            disabled=False
        ),
        widgets.Button(description='Make Move')
    ])

    # Create the main layout
    main_box = widgets.VBox([game_display, controls])

    # Game state
    game_state = {'game': '', 'is_game_over': False}

    def update_display():
        with game_display:
            clear_output(wait=True)
            display(SVG(chess.svg.board(board, size=400)))

    def on_button_click(b):
        if game_state['is_game_over']:
            return

        player_move = controls.children[0].value  # Get move from text input

        try:
            board.push_san(player_move)
            game_state['game'] += player_move + ' '

            update_display()
            print("Your move:", player_move)

            if board.is_game_over():
                print("Game Over!")
                print("Result:", board.result())
                game_state['is_game_over'] = True
                return

            # Model's turn
            print("Model's turn...")
            model_move_str = model_move(model, board, tokenizer, 1, game_state['game'])

            if model_move_str is None:
                print("Model failed to generate a move. Game over.")
                game_state['is_game_over'] = True
                return

            game_state['game'] += model_move_str + ' '
            print(f"Model plays: {model_move_str}")

            update_display()

            if board.is_game_over():
                print("Game Over!")
                print("Result:", board.result())
                game_state['is_game_over'] = True
                return

        except ValueError:
            print("Invalid move! Try again.")

        controls.children[0].value = ''  # Clear input field

    # Connect the button click event
    controls.children[1].on_click(on_button_click)

    # Initial display
    print("Welcome to Chess! You are playing as White. Enter moves in algebraic notation (e.g., e4, Nf3).")
    update_display()

    # Show the interface
    display(main_box)

Welcome to Chess! You are playing as White. Enter moves in algebraic notation (e.g., e4, Nf3).
