In [None]:
# Author: Kerem Kazan
# Title: Chess Commentator Transformer - Training

In [None]:
# !pip install transformers==4.45.2
# !pip install chess

## Disclaimer

This notebook is an annotated copy of a Google Colab training run; the annotations were added after the run finished. The original can be found here: https://colab.research.google.com/drive/1tCtRpZRnNv6Ta6qG9ZFcOab-E2_0NbpW?usp=sharing

In [3]:
# We need to mount google drive to the colab filesystem so that we can save our model checkpoints.

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Paths used throughout the notebook. Everything lives under one folder on the
# mounted Google Drive: the dataset CSV, the encoder/decoder tokenizer
# directories, and the directory that training output is written to.

DATASET_ROOT = "/content/drive/MyDrive/chess-7"
DF_FILE = f"{DATASET_ROOT}/simple_encoder_dataset.csv"
DECODER_TOKENIZER_DIR = f"{DATASET_ROOT}/simple_decoder_tokenizer"
ENCODER_TOKENIZER_DIR = f"{DATASET_ROOT}/simple_encoder_tokenizer"
OUTPUT_DIR = f"{DATASET_ROOT}/output-26"

In [5]:
# Let's first load our dataframe:

import pandas as pd

df = pd.read_csv(DF_FILE)

print(f"Loaded {len(df)} rows")

Loaded 267877 rows


In [6]:
# Let's create a helper visualization function that will let us inspect the data.

import chess
import chess.svg
from IPython.display import display, HTML

def display_data_item(data_item):
  """Show one dataset row: a summary table plus the board with the move drawn as an arrow.

  Args:
    data_item: Row-like object indexable by the keys 'fen_before', 'uci',
      'full_move_number' and 'comment' (for example a row of `df`).
  """
  from html import escape  # local import: only this helper needs it

  print(data_item["fen_before"])
  board_before = chess.Board(data_item['fen_before'])
  move = chess.Move.from_uci(data_item['uci'])
  # The FEN records whose move it is, so read the side to move from the board
  # instead of hard-coding it.
  turn = "White" if board_before.turn == chess.WHITE else "Black"
  size = 390

  svg = chess.svg.board(
    board_before,
    arrows=[chess.svg.Arrow(move.from_square, move.to_square)],
    size=size
  )

  # Dataset text is free-form; escape it so characters such as '<' or '&'
  # in a comment cannot break the generated markup.
  uci_text = escape(str(data_item['uci']))
  move_number_text = escape(str(data_item['full_move_number']))
  comment_text = escape(str(data_item['comment']))

  # Summary table followed by the rendered board.
  markup = f"""
  <div>
    <table>
      <tr>
        <td>Move</td>
        <td>{uci_text}</td>
      </tr>
      <tr>
        <td>Move #</td>
        <td>{move_number_text}</td>
      </tr>
      <tr>
        <td>Turn</td>
        <td>{turn}</td>
      </tr>
      <tr>
        <td>Comment</td>
        <td>{comment_text}</td>
      </tr>
    </table>
    <div style="display: flex; gap: 20px;">
      <svg width="{size}" height="{size}">{svg}</svg>
    </div>
  </div>
  """
  display(HTML(markup))


In [7]:
display_data_item(df.iloc[0])

rnb1kbnr/ppp2ppp/8/3qp3/8/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 4


0,1
Move,b1c3
Move #,4
Turn,White
Comment,my basic opening.


In [8]:
from transformers import BertConfig, BertModel
from transformers import PreTrainedTokenizerFast
from transformers import EncoderDecoderModel, GPT2Config, GPT2LMHeadModel
from transformers import GPT2Tokenizer


# We load the tokenizers we trained in the previous notebook.

decoder_tokenizer = PreTrainedTokenizerFast.from_pretrained(DECODER_TOKENIZER_DIR)
encoder_tokenizer = PreTrainedTokenizerFast.from_pretrained(ENCODER_TOKENIZER_DIR)


0

In [None]:
# [BOS] = Beginning of sentence
# [EOS] = End of sentence
# [PAD] = padding token

# We explicitly register these tokens on our decoder tokenizer because we forgot to do so in the previous notebook. They are necessary for the decoder to understand where a comment begins, where it ends, and which tokens to ignore.

decoder_tokenizer.add_special_tokens({
  "bos_token": "[BOS]",
  "eos_token": "[EOS]",
  "pad_token": "[PAD]",
})

In [9]:
# We need to explicitly extend the comment field with an EOS token. Otherwise the transformer does not learn when to stop.
# Rows that already end with the EOS marker are left untouched, so re-running this cell
# neither lower-cases the marker nor appends a second one.

eos_suffix = f" {decoder_tokenizer.eos_token}"
needs_eos = ~df["comment"].str.endswith(eos_suffix, na=False)
df.loc[needs_eos, "comment"] = (
    df.loc[needs_eos, "comment"].str.lower().apply(lambda x: f"{x}{eos_suffix}")
)

In [10]:
# We then go through our dataset to find the maximum length of a tokenized comment. We will later use this in our neural network params, as well as in output generation.

# `Series.max()` returns a numpy integer; convert it to a plain Python int because this value
# is later stored in model/training configs, and numpy integers are not JSON serializable by
# the standard `json` module.
max_decoded_length = int(
    df["comment"].apply(lambda x: len(decoder_tokenizer.encode(x, add_special_tokens=True))).max()
)

max_decoded_length += 5  # padding for safety
print(f"Max decoded length: {max_decoded_length}")

Max decoded length: 153


In [11]:
# Longest tokenized encoder input in the dataset, plus a small safety margin.

# `Series.max()` returns a numpy integer; convert it to a plain Python int because this value
# is later stored in the model config, and numpy integers are not JSON serializable by the
# standard `json` module.
max_encoded_length = int(
    df["input_sequence"].apply(lambda x: len(encoder_tokenizer.encode(x, add_special_tokens=True))).max()
)
max_encoded_length += 5  # padding for safety
print(f"Max encoded length: {max_encoded_length}")

Max encoded length: 150


In [12]:
# We then create our encoder-decoder model. Our architecture choice is BERT for the encoder, and GPT2 for the decoder.

from transformers import BertConfig, BertModel, EncoderDecoderConfig

def get_empty_model():
  """Build a freshly initialised (untrained) BERT-encoder / GPT-2-decoder model.

  Reads the module-level `encoder_tokenizer`, `decoder_tokenizer`,
  `max_encoded_length` and `max_decoded_length`.

  Returns:
    An `EncoderDecoderModel` whose decoder has cross-attention enabled and
    whose config carries the decoder's BOS/EOS/PAD ids.
  """
  # === Encoder Config Parameters ===
  ENCODER_DIM = 256
  ENCODER_NUM_LAYERS = 6
  ENCODER_NUM_HEADS = 8
  ENCODER_INTERMEDIATE_SIZE = ENCODER_DIM * 4
  ENCODER_MAX_POSITION_EMBEDDINGS = max_encoded_length  # based on the input length stats
  # len(tokenizer) counts tokens added after training as well, whereas
  # `tokenizer.vocab_size` only reports the base vocabulary. Sizing the
  # embedding matrix from the full length keeps every token id in range.
  ENCODER_VOCAB_SIZE = len(encoder_tokenizer)
  ENCODER_PAD_TOKEN_ID = encoder_tokenizer.pad_token_id

  encoder_config = BertConfig(
      vocab_size=ENCODER_VOCAB_SIZE,
      hidden_size=ENCODER_DIM,
      num_hidden_layers=ENCODER_NUM_LAYERS,
      num_attention_heads=ENCODER_NUM_HEADS,
      intermediate_size=ENCODER_INTERMEDIATE_SIZE,
      max_position_embeddings=ENCODER_MAX_POSITION_EMBEDDINGS,
      pad_token_id=ENCODER_PAD_TOKEN_ID,
  )

  encoder_model = BertModel(config=encoder_config)

  # === Decoder Config Parameters ===
  DECODER_DIM = 256
  DECODER_NUM_LAYERS = 6
  DECODER_NUM_HEADS = 8
  DECODER_MAX_POSITION_EMBEDDINGS = max_decoded_length  # from the comment length stats
  # Same reasoning as for the encoder: [BOS]/[EOS]/[PAD] are registered on the
  # decoder tokenizer after it was trained, so use the full vocabulary length.
  DECODER_VOCAB_SIZE = len(decoder_tokenizer)

  DECODER_BOS_TOKEN_ID = decoder_tokenizer.bos_token_id
  DECODER_EOS_TOKEN_ID = decoder_tokenizer.eos_token_id
  DECODER_PAD_TOKEN_ID = decoder_tokenizer.pad_token_id

  decoder_config = GPT2Config(
      vocab_size=DECODER_VOCAB_SIZE,
      n_embd=DECODER_DIM,
      n_layer=DECODER_NUM_LAYERS,
      n_head=DECODER_NUM_HEADS,
      n_positions=DECODER_MAX_POSITION_EMBEDDINGS,
      bos_token_id=DECODER_BOS_TOKEN_ID,
      eos_token_id=DECODER_EOS_TOKEN_ID,
      pad_token_id=DECODER_PAD_TOKEN_ID,
      add_cross_attention=True,  # required for encoder-decoder use
  )

  decoder_model = GPT2LMHeadModel(config=decoder_config)

  encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config,
    decoder_config
  )

  # Special-token ids on the combined config. Label shifting and generation
  # start the decoder sequence from `decoder_start_token_id`.
  encoder_decoder_config.decoder_start_token_id = DECODER_BOS_TOKEN_ID
  encoder_decoder_config.decoder_end_token_id = DECODER_EOS_TOKEN_ID
  # `eos_token_id` is the standard config attribute name for the end-of-sequence
  # id; set it as well so the stop token is explicit at the top level.
  encoder_decoder_config.eos_token_id = DECODER_EOS_TOKEN_ID
  encoder_decoder_config.pad_token_id = DECODER_PAD_TOKEN_ID
  encoder_decoder_config.cls_token_id = encoder_tokenizer.cls_token_id
  encoder_decoder_config.sep_token_id = encoder_tokenizer.sep_token_id

  model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model, config=encoder_decoder_config)

  return model

In [13]:
# Let's do a quick check sequence to see if everything is wired correctly.

# torch is used below but is otherwise first imported in a later cell; import it
# here so this cell also works on a fresh kernel (Restart & Run All).
import torch

model = get_empty_model()
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

row = df.iloc[0]
input_str = row["input_sequence"]
target_str = row.get("comment", None)

# Tokenize a single input sequence as a batch of one.
input_tokens = encoder_tokenizer(
    input_str,
    return_tensors="pt",
    padding=True,
).to(device)

# The model is untrained, so the generated text is meaningless; this only
# checks that tokenizer, encoder and decoder are wired together correctly.
with torch.no_grad():
    generated_ids = model.generate(
        **input_tokens,
        max_new_tokens=max_decoded_length,
        num_beams=4,
        early_stopping=True
    )

decoded_output = decoder_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

print(input_tokens)
print(decoded_output)

EncoderDecoderModel has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


{'input_ids': tensor([[96, 26, 36, 25, 46,  2, 96,  5, 12, 26, 56,  2, 96, 18, 98, 26, 52, 97,
         26, 29,  2, 96, 17, 36, 52, 29, 61, 31, 63, 40, 56,  2,  2, 99, 26, 28,
         96, 26, 36, 92, 26, 44, 98, 26, 52, 95, 26, 60, 92, 26, 68, 99, 26, 84,
         97, 26, 29, 97, 26, 37, 97, 26, 45, 97, 26, 53, 97, 26, 69, 97, 26, 77,
         97, 26, 85, 96, 26, 70, 12, 26, 56, 11, 26, 64, 11, 26, 34, 11, 26, 42,
         11, 26, 50, 11, 26, 74, 11, 26, 82, 11, 26, 90, 13, 26, 35, 10, 26, 43,
          6, 26, 51,  9, 26, 67,  6, 26, 75, 10, 26, 83, 13, 26, 91,  2, 93, 94,
          7,  8,  2]], device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 

# Final Training Preparations

In [14]:
# We create our dataset from our dataframe. This one runs encoder and decoder tokenizers on the input and target sequences.

from torch.utils.data import Dataset
import torch

class ChessCommentaryDataset(Dataset):
    """Torch dataset that tokenizes (input_sequence, comment) rows on access.

    Each item is a dict with `input_ids` and `attention_mask` for the encoder
    and `labels` for the decoder, where padding positions are set to -100.
    """

    def __init__(self, df, encoder_tokenizer, decoder_tokenizer, max_input_length, max_output_length):
        # Positional access through iloc is used below, so normalise the index.
        self.df = df.reset_index(drop=True)
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.max_input_length = max_input_length
        self.max_output_length = max_output_length
        self.cls_token_id = decoder_tokenizer.cls_token_id

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]

        # Source side: fixed-length encoder input.
        source = self.encoder_tokenizer(
            record["input_sequence"],
            padding="max_length",
            truncation=True,
            max_length=self.max_input_length,
            return_tensors="pt",
        )

        # Target side: fixed-length comment tokens that become the labels.
        target = self.decoder_tokenizer(
            record["comment"],
            padding="max_length",
            truncation=True,
            max_length=self.max_output_length,
            return_tensors="pt",
            add_special_tokens=True
        )

        pad_id = self.decoder_tokenizer.pad_token_id
        labels = target.input_ids.squeeze(0)

        # A leading [CLS] must not survive into the labels: it would sit at the
        # front of the sequence when the labels are shifted (the [CLS][CLS]
        # issue), and with it in place the model could not even memorise a
        # ten-row dataset. Background:
        # https://discuss.huggingface.co/t/from-transformers-version-v4-12-0-onwards-the-example-colab-bert2bert-is-wrong-things-to-keep-in-mind-when-using-from-transformers-import-encoderdecodermodel/73491
        if labels[0].item() == self.cls_token_id:
            # Drop the first token and append one pad so the length is unchanged.
            labels = torch.cat([labels[1:], torch.tensor([pad_id])], dim=0)

        # Padding positions are excluded from the loss via the -100 marker.
        labels[labels == pad_id] = -100

        return {
            "input_ids": source.input_ids.squeeze(0),
            "attention_mask": source.attention_mask.squeeze(0),
            "labels": labels,
        }


In [16]:
# Now, let's split our dataframe into train and test sets:

train_df = df.sample(frac=0.95, random_state=42)
eval_df = df.drop(train_df.index)

print(f"Train set size: {len(train_df)}")
print(f"Eval set size: {len(eval_df)}")

Train set size: 254483
Eval set size: 13394


In [17]:
train_dataset = ChessCommentaryDataset(
  df = train_df,
  encoder_tokenizer = encoder_tokenizer,
  decoder_tokenizer = decoder_tokenizer,
  max_input_length = max_encoded_length,
  max_output_length = max_decoded_length
)

eval_dataset = ChessCommentaryDataset(
  df = eval_df,
  encoder_tokenizer = encoder_tokenizer,
  decoder_tokenizer = decoder_tokenizer,
  max_input_length = max_encoded_length,
  max_output_length = max_decoded_length
)

In [18]:
# Let's run a quick sanity check to make sure our dataset is working as expected.

original = train_df.iloc[0]

item = train_dataset[0]
labels = item["labels"].tolist()
labels_recovered = []
for label in labels:
  if label == -100:
    break
  labels_recovered.append(label)

print(labels_recovered)

decoded = decoder_tokenizer.decode(
  labels_recovered,
  skip_special_tokens=False
)
print(decoded)

print(original["comment"])


[932, 682, 24, 16, 137, 259, 638, 77, 72, 200, 650, 224, 487, 89, 77, 84, 749, 16, 156, 137, 377, 477, 246, 77, 429, 59, 16, 70, 3]
 thinking bd6. this lets the a rook slide over to the h rank. but this does even up the points. [EOS]
thinking bd6. this lets the a rook slide over to the h rank. but this does even up the points. [EOS]


In [19]:
# Let's inspect the tensor shapes to make sure our dataset and model are working as expected.

from torch.utils.data import DataLoader

test_model = get_empty_model()
data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

for batch in data_loader:
  print(batch["input_ids"].shape)
  print(batch["attention_mask"].shape)
  out_2 = test_model(**batch)
  print(out_2.logits.shape)
  break


torch.Size([1, 150])
torch.Size([1, 150])
torch.Size([1, 153, 1000])


  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)


In [21]:
# We'll make use of this context manager to switch between training and evaluation mode. It's handy in Trainer callbacks.

from contextlib import contextmanager

@contextmanager
def evaluation(model):
    """Run the enclosed block with `model` in eval mode.

    On exit (including when the block raises) the model is switched back to
    training mode, but only if it was in training mode on entry.
    """
    restore_training = model.training
    model.eval()
    try:
        yield
    finally:
        if restore_training:
            model.train()

# Critical Takeaway

In the earlier iterations of this project, my transformer would very quickly collapse and keep generating the same output - like "A very good move" - regardless of the input. It's much easier to debug issues like tensor size mismatches - the runtime will error and print the error message and a stack trace. But this is one of those more Byzantine issues that are hard to debug. Training a model is no easy task - especially for a beginner. You need to be patient and keep fixing issues one by one. That's why it's extremely important to have mental checkpoints and sanity checks that gradually become more complex.

Here's a very useful tool to help with that: create a very small training dataset and try to overfit it. If your model can't even memorize a dataset of 10 rows, no amount of hyperparameter tuning will help. You have a systemic issue, and likely a bug. This was the case for me. The next cell contains a sanity check that I used to debug this issue.

In [22]:
# Sanity check training:
from transformers import TrainingArguments, Trainer

# Overfit sanity check
small_train_df = train_df.sample(10, random_state=42).reset_index(drop=True)
small_eval_df = eval_df.sample(1, random_state=55).reset_index(drop=True)

small_dataset = ChessCommentaryDataset(
  df=small_train_df,
  encoder_tokenizer=encoder_tokenizer,
  decoder_tokenizer=decoder_tokenizer,
  max_input_length=max_encoded_length,
  max_output_length=max_decoded_length,
)

small_eval_dataset = ChessCommentaryDataset(
  df=small_eval_df,
  encoder_tokenizer=encoder_tokenizer,
  decoder_tokenizer=decoder_tokenizer,
  max_input_length=max_encoded_length,
  max_output_length=max_decoded_length,
)

small_dataset[0]

{'input_ids': tensor([92, 26, 86, 25, 58,  2, 92, 15, 10, 26, 58,  2, 92,  5,  6, 26, 51,  2,
         92, 17, 86, 31, 79, 40, 72, 49, 65, 51, 67,  2,  2, 99, 26, 28, 95, 26,
         60, 99, 26, 84, 97, 26, 29, 97, 26, 37, 98, 26, 61, 97, 26, 69, 97, 26,
         85, 97, 26, 78, 92, 26, 86, 11, 26, 55, 96, 26, 63, 92, 26, 71, 97, 26,
         56, 13, 26, 64, 11, 26, 81, 11, 26, 34, 11, 26, 42, 11, 26, 50, 10, 26,
         58,  9, 26, 66,  6, 26, 82, 11, 26, 90, 13, 26, 35,  6, 26, 51, 12, 26,
         59,  2, 93, 94,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [23]:
from transformers import TrainerCallback
import torch
import random

# Here's another extremely important tool: Don't just count on training loss and eval loss to tell you if your model is learning. Transformers can be very deceptive. There were many cases where the training loss approached 0 during my sanity-check training runs - but the model was still not learning. This is why it's important to have a mechanism to clearly show the output of the model. 

# Huggingface trainer library has a callback mechanism that allows you to run custom code at certain points in the training process. The class below is our way to print out a sample of the model's output during evaluation.

class CommentGenerationCallback(TrainerCallback):
    """Trainer callback that prints generated comments next to the ground truth at each evaluation.

    Args:
        encoder_tokenizer: Tokenizer used for the encoder inputs.
        decoder_tokenizer: Tokenizer used to decode generated and gold comments.
        dataset: Indexable collection of examples to sample from.
        num_examples: Maximum number of examples printed per evaluation.
    """

    def __init__(self, encoder_tokenizer, decoder_tokenizer, dataset, num_examples=3):
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.dataset = dataset
        self.num_examples = num_examples

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Generate and print comments for a few randomly chosen examples.

        Reads the module-level `max_encoded_length` and `max_decoded_length`,
        and uses the `evaluation` context manager to put `model` in eval mode
        for the duration of the sampling.
        """
        print("\n🔍 Generating sample comments from eval dataset...")

        with evaluation(model):
            # Sample without replacement; never ask for more items than exist.
            indices = random.sample(range(len(self.dataset)), min(self.num_examples, len(self.dataset)))
            for i in indices:
                example = self.dataset[i]

                # Convert input_ids to text if raw text is unavailable
                if isinstance(example, dict) and "input_ids" in example:
                    # Already-tokenized example: add a batch dimension and move to the training device.
                    input_ids = example["input_ids"].unsqueeze(0).to(args.device)
                    attention_mask = example["attention_mask"].unsqueeze(0).to(args.device)
                    gold_label_ids = example["labels"]
                    # NOTE(review): input_text is not used after this point in this branch.
                    input_text = self.encoder_tokenizer.decode(example["input_ids"], skip_special_tokens=True)
                else:
                    # Or just use the raw string input
                    input_text = example["input_sequence"]
                    encoded = self.encoder_tokenizer(
                        input_text,
                        return_tensors="pt",
                        truncation=True,
                        padding="max_length",
                        max_length=max_encoded_length
                    ).to(args.device)
                    input_ids = encoded["input_ids"]
                    attention_mask = encoded["attention_mask"]
                    # Tokenize the reference comment so both branches yield token ids.
                    gold_label_ids = self.decoder_tokenizer(
                        example["comment"],
                        padding="max_length",
                        truncation=True,
                        max_length=max_decoded_length,
                        return_tensors="pt"
                    ).input_ids.squeeze().tolist()

                # Generate comment
                with torch.no_grad():
                    generated_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        decoder_start_token_id=model.config.decoder_start_token_id,
                        max_new_tokens=max_decoded_length,
                        num_beams=4,
                        early_stopping=True,
                    )

                generated_comment = self.decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                # Drop the -100 loss-mask markers and pad ids before decoding the reference.
                gold_comment = self.decoder_tokenizer.decode(
                    [t for t in gold_label_ids if t != -100 and t != self.decoder_tokenizer.pad_token_id],
                    skip_special_tokens=True
                )

                print(f"🟢 Predicted: {generated_comment}")
                print(f"🔴 Ground   : {gold_comment}")
                print("")
            print("-" * 80)


In [25]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# We set batch sizes to 1, because we actually want to overfit here. Again, it's just a sanity check.

sanity_check_model = get_empty_model()

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    bf16=True,
    output_dir="./",
    logging_steps=100,
    save_steps=100000,
    num_train_epochs=1000,
    # logging_steps=1000,
    # save_steps=500,
    eval_steps=200,
    report_to="none",
    # warmup_steps=2000,
    # save_total_limit=3,
    generation_max_length=max_decoded_length,
)

trainer = Seq2SeqTrainer(
    model=sanity_check_model,
    args=training_args,
    train_dataset=small_dataset,
    eval_dataset=small_dataset,
    tokenizer=decoder_tokenizer,
    callbacks=[
        CommentGenerationCallback(
            encoder_tokenizer=encoder_tokenizer,
            decoder_tokenizer=decoder_tokenizer,
            dataset=small_dataset,
            num_examples=3
        )
    ]
)
trainer.train()

  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)


Step,Training Loss,Validation Loss
200,3.3831,2.559287
400,1.8359,1.29012
600,0.9724,0.63388
800,0.4642,0.274225
1000,0.2033,0.111641
1200,0.0887,0.050303
1400,0.0462,0.028845
1600,0.0306,0.01986
1800,0.0224,0.014951
2000,0.0173,0.011729



🔍 Generating sample comments from eval dataset...
🟢 Predicted:  winning the knight. 
🔴 Ground   :  line up the rook to save mine... 

🟢 Predicted:  winning the knight. 
🔴 Ground   :  objectively strongest, but 16 0-0, securing the king and bringing the rook into play was a viable option. the straightforward 16 bxe5 was also winning for white. the added umph from 16 bxd7 is the weakening of e5 and giving access to c5 to the white knight. 

🟢 Predicted:  winning the knight. 
🔴 Ground   :  now if i make a move like 38) kg4, he will have the possibility of 38) ... rf2, winning either the a-pawn or the h-pawn. 

--------------------------------------------------------------------------------

🔍 Generating sample comments from eval dataset...
🟢 Predicted:  winning the knight. 
🔴 Ground   :  i was getting quite excited here. i was thinking that he would retreat the bf4 allowing me to play e5-e4 forking his knight and bishop and winning a piece. he comes up with the correct reply. 

🟢 Predict

KeyboardInterrupt: 

In [None]:
# Don't mind the KeyboardInterrupt - this is intentional. If you scroll all the way down in the last cell, you'll see that by step 2000, the model perfectly matches the dataset. So the model is at least able to extract some form of meaning from the dataset. This is a good sign, and it's our last sanity check before we move on to the real training.

In [26]:
model = get_empty_model()

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# We set the batch size to 64. Training on A100, there's plenty of memory to spare.

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    eval_strategy="steps",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    bf16=True,
    output_dir=OUTPUT_DIR,
    logging_steps=100,
    save_steps=500,
    num_train_epochs=1000,
    eval_steps=400,
    report_to="none",
    warmup_steps=2000,
    save_total_limit=2,
    generation_max_length=max_decoded_length,
    logging_strategy="steps",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=decoder_tokenizer,
    callbacks=[
        CommentGenerationCallback(
            encoder_tokenizer=encoder_tokenizer,
            decoder_tokenizer=decoder_tokenizer,
            dataset=eval_dataset,
            num_examples=3
        )
    ]
)
trainer.train()

In [40]:
# after about 150,000 steps with batch size 64, we have achieved a training loss of 2.53,
# and a validation loss of 2.62. The difference is not high, so there's likely no overfitting.
# However, the loss numbers are still high, so we are likely underfitting. But let's take a look now:

def predict_and_print(ds_item, data_item=None):
  """Generate a comment for one tokenized example and print it next to the ground truth.

  Uses the module-level `model`, `device`, `decoder_tokenizer` and
  `max_decoded_length`.

  Args:
    ds_item: Dict with "input_ids", "attention_mask" and "labels" tensors for
      a single example (no batch dimension).
    data_item: Optional row providing the reference text under 'comment'.
      When omitted, the reference is decoded from `ds_item["labels"]`, so the
      function no longer depends on a global variable set by the caller.
  """
  with evaluation(model):
    input_ids = ds_item["input_ids"].unsqueeze(0).to(device)
    attention_mask = ds_item["attention_mask"].unsqueeze(0).to(device)
    generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_start_token_id=model.config.decoder_start_token_id,
            max_new_tokens=max_decoded_length,
            num_beams=4,
            early_stopping=True,
        )

    generated_comment = decoder_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    if data_item is not None:
      ground_truth = data_item['comment']
    else:
      # -100 marks positions excluded from the loss; skip them before decoding.
      label_ids = [t for t in ds_item["labels"].tolist() if t != -100]
      ground_truth = decoder_tokenizer.decode(label_ids, skip_special_tokens=False)

    print(f"🟢 Predicted: {generated_comment}")
    print(f"🔴 Ground   : {ground_truth}")
    print("")

# Let's modify our helper visualization function to make it a bit easier to see
# the difference between prediction and ground truth:

import chess
import chess.svg
from IPython.display import display, HTML

def display_data_item(data_item):
  """Show only the board for a dataset row, with the move drawn as an arrow.

  This deliberately replaces the earlier helper of the same name: the
  summary table is dropped because prediction and ground truth are printed
  separately by `predict_and_print`.

  Args:
    data_item: Row-like object indexable by 'fen_before' and 'uci'.
  """
  board_before = chess.Board(data_item['fen_before'])
  move = chess.Move.from_uci(data_item['uci'])
  size = 390

  svg = chess.svg.board(
    board_before,
    arrows=[chess.svg.Arrow(move.from_square, move.to_square)],
    size=size
  )

  html = f"""
  <div>
    <div style="display: flex; gap: 20px;">
      <svg width="{size}" height="{size}">{svg}</svg>
    </div>
  </div>
  """
  display(HTML(html))


# Now let's take the first 50 rows of the evaluation set and see if the model can generate logical comments for them.
# Row i of eval_df and item i of eval_dataset refer to the same example, because the dataset
# only resets the dataframe index and keeps the row order.
for i in range(50):
  data_item = eval_df.iloc[i]
  ds_item = eval_dataset[i]
  predict_and_print(ds_item, data_item)
  display_data_item(data_item)
  print("_" * 80)
  print("")

🟢 Predicted:  cxd4 - pawn trade. 
🔴 Ground   : i won't have played this one as white. [EOS]



________________________________________________________________________________

🟢 Predicted:  b5 - i now start a pawn push on the queenside. 
🔴 Ground   : i overlooked here a tactic i would have certainly seen being the white player: nc6 is undefended after nxb5 or nxd5. [EOS]



________________________________________________________________________________

🟢 Predicted:  preparing to push the c-pawn. 
🔴 Ground   : a waste of a move, following a simple exchange game with no... well, let's see. [EOS]



________________________________________________________________________________

🟢 Predicted:  i take the knight with my bishop. 
🔴 Ground   : bye bye, dearest! your blind brother will spread his wings thanks to your sacrifice. [EOS]



________________________________________________________________________________

🟢 Predicted:  the bishop comes back to d1 to attack the b pawn. 
🔴 Ground   : kf7 was necessary and probably enough to draw. [EOS]



________________________________________________________________________________

🟢 Predicted:  the king comes across. 
🔴 Ground   : after this nonsense dance, i'll lose another pawn. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a good move. it blocks the queen's line of fianchetto. 
🔴 Ground   : block thed file. not good. [EOS]



________________________________________________________________________________

🟢 Predicted:  i retreat my bishop to safety. 
🔴 Ground   : keeping the bishop on the d8-h4 diagonal to attack squares near the black king. [EOS]



________________________________________________________________________________

🟢 Predicted:  i take the bishop. 
🔴 Ground   : allowing black to win a pawn at the expense of an exchange reducing material left on the board. [EOS]



________________________________________________________________________________

🟢 Predicted:  both sides have castled. 
🔴 Ground   : its very good to castle early in the game. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a good move. it allows me to double my rooks on the b-file. 
🔴 Ground   : really bad move here [EOS]



________________________________________________________________________________

🟢 Predicted:  the only move. 
🔴 Ground   : look at this position! the black king is stalemated, white just needs a check and its mate...even without queens mating attacks occur! [EOS]



________________________________________________________________________________

🟢 Predicted:  now the bishop is pinned to the king. 
🔴 Ground   : expecting an exchange [EOS]



________________________________________________________________________________

🟢 Predicted:  the only move. 
🔴 Ground   : scattered rooks! [EOS]



________________________________________________________________________________

🟢 Predicted:  the exchange variation. 
🔴 Ground   : if 3... c6 4. nc3 cxd5 5. bf4 e6 (5... bf5 6. qb3) 6. e3 bd6 7. bxd6 qxd6 8. nf3 or 8. rc1. [EOS]



________________________________________________________________________________

🟢 Predicted:  this is a mistake, as it opens up the e-file for my rook. 
🔴 Ground   : makes things worse. [EOS]



________________________________________________________________________________

🟢 Predicted:  axb4 - pawn trade. 
🔴 Ground   : black activates the ra8. now the rook is hitting the a3 pawn. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake. i think black should have taken the pawn with the knight. 
🔴 Ground   : black strengthens his knight on f5. this is because now white is discouraged from playing g4 and kicking the knight out. white can still play g4, though, because of his h3 pawn. [EOS]



________________________________________________________________________________

🟢 Predicted:  kf2 - henry moves towards the centre. 
🔴 Ground   : it doesn't seem like black will be able to get anywhere, as the white position is very solid. so black sees that the only way to make a breakthrough is to eventually play rxa3, and then after nxa3 rxa3, to achieve a winning rook and pawn endgame. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this was a mistake. i was expecting 38...rc4, but i didn't see it. 
🔴 Ground   : black wastes a move, allowing white to move his king away from the defense of the g3 pawn. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is an inaccuracy because it weakens the e5 square. 
🔴 Ground   : this seems to be a weakening pawn move, preparing for ...e5 but leaving the king's pawn weak [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake because it allows me to castle queen side. 
🔴 Ground   : this move seems to actually hinder black's development, since the white square bishop doesn't gain much in the way of development and severely limits the mobility of both the queen and the knight. [EOS]



________________________________________________________________________________

🟢 Predicted:  b3 - to prevent a knight check on c4. 
🔴 Ground   : now i'm on the defensive, and poised to be picked apart by some precise play. [EOS]



________________________________________________________________________________

🟢 Predicted:  black resigned here. many thanks for reading, please leave a comment or two and rate it on the star system. until next time dear reader! 
🔴 Ground   : according the to engine, this (and move 57) is where i entirely lost my advantage. it suggests nc6, kg4, nc4, ke4, or nd3... pretty much any move but the one i made. [EOS]



________________________________________________________________________________

🟢 Predicted:  the only move. 
🔴 Ground   : the worst place for the king, as he will be exposed to a crushing attack by black. however 15.kf1 loses the queen to nxe3 and 15. kf2 is met with ... qxe3 15. kf1 nxa1 and white is lost in any case. [EOS]



________________________________________________________________________________

🟢 Predicted:  kxd3 - rook trade. 
🔴 Ground   : missed the mate in two, however 16. kc1 leads soon to a mate as well [EOS]



________________________________________________________________________________

🟢 Predicted:  the standard move in the king's indian attack. 
🔴 Ground   : first time i have played against this opening so i decide to just develop [EOS]



________________________________________________________________________________

🟢 Predicted:  white develops his lb to e2 and prepares to castle. 
🔴 Ground   : he removes the pin [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a good move, but it does open up the e-file for my rook. 
🔴 Ground   : if he takes with d5 i take d2 [EOS]



________________________________________________________________________________

🟢 Predicted:  i take the pawn with my queen. 
🔴 Ground   : now my pawns on d4 and e5 have been replaced by pieces which are powerfully placed. perhaps black should try and exchange them with say 18.... qf6? white will always have positional trumps in this sort of position if he plays carefully, as the black e pawn is now a backward pawn on a semi-open file, (i.e. white can attack it with rook and/or queen down the e file), and the black bishop is bad. it is to improve this bishop that my opponent tries next. [EOS]



________________________________________________________________________________

🟢 Predicted:  black develops his last minor piece. 
🔴 Ground   : this is played with an eye to setting up operations on the longest diagonal. [EOS]



________________________________________________________________________________

🟢 Predicted:  black develops his knight to f6. 
🔴 Ground   : black prepares castling, but with a small positional disadvantage: the a1-h8 diagonal is open. [EOS]



________________________________________________________________________________

🟢 Predicted:  white develops his knight to f3. 
🔴 Ground   : out comes this knight, having a certain idea in mind. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a good move. it opens the diagonal for the bishop, but it isn't going anywhere. 
🔴 Ground   : most likely because this was going to be played next anyway. i have no intentions of castling short, so i may as well continue in storming up the kingside in the meantime. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake, as it allows the rook to move to d1, which would have allowed me to move my queen to d2. 
🔴 Ground   : now i'm in a position where black's intentions will be indicated by any piece that moves next. as far as the central pawns are concerned, if it is clear that d4 is being targeted, both rooks will stay here. if e5 is targetted, one of them must move back. [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake, as it allows me to double my rooks on the c-file. 
🔴 Ground   : i imagine that black must have figured that i sealed my fate of guarding d4 at this point, now that the king completely blocks r2d2 to join the battle. however! [EOS]



________________________________________________________________________________

🟢 Predicted:  this is a mistake. white should have taken the pawn with the queen. 
🔴 Ground   : should that be black's plan, i'll be ready to launch a full assault on f7; i finally sat myself down and saw a working plan to blast open the kingside! now my attack is ready for action. meanwhile, though, i'm pretty sure that given the next set of moves, black had other thoughts in their mind... [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake, as it allows me to regain control of the a file. 
🔴 Ground   : have i finally considered accepting the draw? [EOS]



________________________________________________________________________________

🟢 Predicted:  now the bishop is pinned to the king. 
🔴 Ground   : it cannot save the king. there are too many squares i can threaten at this point. [EOS]



________________________________________________________________________________

🟢 Predicted:  bd2 - fou lenoir is driven back to d2. 
🔴 Ground   : my pieces are getting pushed around [EOS]



________________________________________________________________________________

🟢 Predicted:  black resigns. 
🔴 Ground   : here i expected and feared g4... i think black would have won with that move. and the computer says so. [EOS]



________________________________________________________________________________

🟢 Predicted:  the ruy lopez. 
🔴 Ground   : pins my knight [EOS]



________________________________________________________________________________

🟢 Predicted:  the knight comes back to e4, attacking the pawn on g3. 
🔴 Ground   : i give it to him setting up a trap of my own [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a good move for white. 
🔴 Ground   : placing knight on e5 means i need nc6 [EOS]



________________________________________________________________________________

🟢 Predicted:  i think this is a mistake, as it opens the h-file for my rook. 
🔴 Ground   : now i can let black unleash his bishop, since i will chop it off next move. [EOS]



________________________________________________________________________________

🟢 Predicted:  this is the main line of the philidor variation. 
🔴 Ground   : personaly i would say this is a very passive move, giving me both tempo and space. [EOS]



________________________________________________________________________________

🟢 Predicted:  this is the most common move in the sicilian dragon variation. 
🔴 Ground   : this doesn't seem consistent. the only meaningful idea behind nc6 i thought was to play into a 0-0-0 rather than a bc4 dragon with d5 where the loss of tempo may be made slightly relevant by the fact that i'm in an opening less consistent with my repertoire. [EOS]



________________________________________________________________________________

🟢 Predicted:  this is the start of a natural square for the knight. 
🔴 Ground   : this is just a weird move. it's very rare that black can be thinking about positional movement on the kingside in a dragon without the attack breaking through [EOS]



________________________________________________________________________________

🟢 Predicted:  white develops his knight to d2. 
🔴 Ground   : the q-side pieces bestir themselves at last. [EOS]



________________________________________________________________________________

🟢 Predicted:  the queen is trapped. 
🔴 Ground   : sometimes when the queen does claim the pawn, i might do an exchange such as qxb2-->rb1, and if the pawn on b7 has already moved forward, it provides an opportunity to trap the queen because she can no longer retreat. but in this case, the queen still can get to safety by going to b3 and then retreating to the back rank from there. so instead, my knight comes out to both protect the pawn and threaten the queen. [EOS]



________________________________________________________________________________



# Conclusion

There are many cases where the model makes a great prediction - sometimes it's even better than the ground truth. However, the model also often makes mistakes and outputs wrong predictions. Overall, though, the model has definitely learned and is able to generate comments that are at least somewhat logical.