In [None]:
# Data source: "35 Million Chess Games" dataset on Kaggle
# https://www.kaggle.com/datasets/milesh1/35-million-chess-games?resource=download

In [1]:
import pandas as pd
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from transformers import BertTokenizer, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset
from transformers import BertForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm
E0000 00:00:1733095010.735303      13 common_lib.cc:798] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:479
E1201 23:16:50.774956573      13 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-12-01T23:16:50.774939574+00:00", grpc_status:2}


In [2]:
# Step 1: Dataset Parsing and Preprocessing
def parse_dataset(file_path, max_games=1000, progress_every=100):
    """
    Parse the dataset to extract metadata and moves.

    Only lines containing ``###`` are treated as games; all other lines are
    skipped. On a game line, the text before the first ``###`` holds
    whitespace-separated metadata fields (game id, date, result, white Elo,
    black Elo, number of moves, ...) and the text after it holds the moves.

    Args:
        file_path: path to the text dump.
        max_games: maximum number of games to read; ``None`` reads the whole file.
        progress_every: print the running game count every this many games;
            ``0``/``None`` disables progress output.

    Returns:
        DataFrame with columns game_id, date, result, white_elo, black_elo,
        num_moves, moves (one row per game; moves kept as the raw string).
    """
    columns = ["game_id", "date", "result", "white_elo", "black_elo", "num_moves", "moves"]

    def _optional_int(token):
        # Missing Elo ratings appear as the literal string "None".
        return None if token == "None" else int(token)

    games = []
    with open(file_path, 'r') as f:
        for line in f:
            if "###" not in line:
                continue
            # Check the cap before parsing so at most `max_games` games are kept.
            if max_games is not None and len(games) >= max_games:
                break
            if progress_every and len(games) % progress_every == 0:
                print(len(games))
            # Split only on the first separator: everything after it is the move list.
            header, moves = line.strip().split("###", 1)
            metadata = header.split()
            games.append({
                "game_id": int(metadata[0]),
                "date": metadata[1],
                "result": metadata[2],
                "white_elo": _optional_int(metadata[3]),
                "black_elo": _optional_int(metadata[4]),
                "num_moves": int(metadata[5]),
                "moves": moves.strip(),
            })
    # Passing `columns` keeps the schema even when no games were read.
    return pd.DataFrame(games, columns=columns)

def preprocess_moves(move_sequence):
    """
    Extract moves from the sequence.

    Each move in the raw text is prefixed with its colour and move number
    (e.g. ``"W1.e4 B1.e5 W10.O-O"``); the prefix is stripped and the SAN move
    kept, giving ``["e4", "e5", "O-O"]``. Empty/None input returns ``[]``.
    """
    if not move_sequence:
        return []
    # Prefix "[WB]<number>." is matched but not captured (works for any number
    # of digits). The captured SAN body allows pieces, files a-h, ranks (and 0
    # for "0-0" castling), captures, castling "O-O", promotion "=" and "+"/"#".
    return re.findall(r'[WB]\d+\.([KQRBNOa-h0-8x=+#-]+)', move_sequence)

def create_sequences(moves, result):
    """
    Turn one game into supervised (prefix, next move, result) training triples.

    For every position ``k`` from 1 to ``len(moves) - 1`` the triple is
    ``(" ".join(moves[:k]), moves[k], label)``, so games with fewer than two
    moves yield an empty list. ``label`` encodes the game outcome:
    1 for "1-0" (White win), 0 for "0-1" (Black win), and 2 for any other
    result string (treated as a draw).
    """
    if result == "1-0":
        label = 1
    elif result == "0-1":
        label = 0
    else:
        label = 2

    return [(" ".join(moves[:k]), moves[k], label) for k in range(1, len(moves))]

# Load dataset
file_path = "/kaggle/input/chessgptdata/all_with_filtered_anotations_since1998.txt"
df = parse_dataset(file_path)

# Preprocess move sequences: each row's 'moves' becomes a list of move strings
df['moves'] = df['moves'].apply(preprocess_moves)

# Generate input-output pairs, flattening the per-game triples into one list.
# Zipping the two needed columns avoids building a Series per row as iterrows() does.
sequences = [
    triple
    for game_moves, game_result in zip(df['moves'], df['result'])
    for triple in create_sequences(game_moves, game_result)
]

seq_df = pd.DataFrame(sequences, columns=["input_sequence", "next_move", "result_label"])
# Downstream cells size the model head from seq_df; fail early if nothing was produced.
assert len(seq_df) > 0, "no training pairs were generated; check file_path and parsing"

# Save processed data
seq_df.to_csv("processed_chess_data.csv", index=False)



0
100
200
300
400
500
600
700
800
900
1000


In [3]:
# Step 2: Dataset Class for PyTorch
class ChessDataset(Dataset):
    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data.iloc[index]
        input_seq = item['input_sequence']
        next_move = item['next_move']  # Keep as a string
        result_label = int(item['result_label'])
    
        # Tokenize input sequence
        encoding = self.tokenizer(
            input_seq,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )
    
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'labels': self.tokenizer.convert_tokens_to_ids(next_move),  # Convert move to token ID
            'result_label': torch.tensor(result_label, dtype=torch.long)
        }



In [4]:

# Step 3: Train-Test Split
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LEN = 50

train_data, val_data = train_test_split(seq_df, test_size=0.1, random_state=42)
train_dataset = ChessDataset(train_data, tokenizer, MAX_LEN)
val_dataset = ChessDataset(val_data, tokenizer, MAX_LEN)

# Step 4: Custom Model
class ChessPredictionModel(torch.nn.Module):
    def __init__(self, num_moves, num_results):
        super(ChessPredictionModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=num_moves,
            output_hidden_states=True  # Enable hidden states
        )
        self.result_classifier = torch.nn.Linear(768, num_results)
        self.loss_fn = torch.nn.CrossEntropyLoss()  # Use CrossEntropyLoss for classification tasks

    def forward(self, input_ids, attention_mask, labels=None, result_label=None):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        logits_moves = bert_output.logits

        # Ensure hidden states are available
        if bert_output.hidden_states is not None:
            pooled_output = bert_output.hidden_states[-1][:, 0]
            logits_result = self.result_classifier(pooled_output)
        else:
            raise ValueError("Hidden states are not available in the BERT output.")

        output = {
            "logits_moves": logits_moves,
            "logits_result": logits_result,
        }

        if labels is not None and result_label is not None:
            # Compute the loss for both predictions
            move_loss = self.loss_fn(logits_moves, labels)
            result_loss = self.loss_fn(logits_result, result_label)
            total_loss = move_loss + result_loss
            output["loss"] = total_loss

        return output



# Instantiate the model
num_moves = len(seq_df['next_move'].unique())
num_results = 3  # White win, Black win, Draw
model = ChessPredictionModel(num_moves=num_moves, num_results=num_results)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Step 5: Training with Checkpointing
# Evaluate on a step schedule, checkpoint every 500 steps (keeping the 3 newest),
# and reload the best checkpoint when training finishes.
training_args = TrainingArguments(
    output_dir='./checkpoints',
    logging_dir='./logs',
    # optimisation schedule
    num_train_epochs=5,
    learning_rate=5e-5,
    warmup_steps=500,
    weight_decay=0.01,
    # batch sizes per device
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # logging / evaluation / checkpointing
    logging_steps=50,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # confirm against the installed version before upgrading.
    evaluation_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    load_best_model_at_end=True
)

# Define evaluation metrics
def compute_metrics(pred):
    """
    Accuracy of both heads, for use as Trainer's ``compute_metrics`` hook.

    ``pred.predictions`` is unpacked as (move logits, result logits) and
    ``pred.label_ids`` as (move labels, result labels), in that order.
    Returns ``{"move_accuracy": ..., "result_accuracy": ...}``.
    """
    move_logits, result_logits = pred.predictions
    move_labels, result_labels = pred.label_ids

    # Highest-scoring class along the last axis, converted back to numpy for sklearn.
    move_pred_ids = torch.tensor(move_logits).argmax(dim=-1).numpy()
    result_pred_ids = torch.tensor(result_logits).argmax(dim=-1).numpy()

    return {
        "move_accuracy": accuracy_score(move_labels, move_pred_ids),
        "result_accuracy": accuracy_score(result_labels, result_pred_ids),
    }

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Start a fresh run: with resume_from_checkpoint=False no earlier checkpoint is
# loaded, while new checkpoints are still written per `training_args`.
trainer.train(resume_from_checkpoint=False)


E0000 00:00:1733095039.257597      13 common_lib.cc:818] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:483


Step,Training Loss,Validation Loss


In [None]:
# Step 6: Save Model
# ChessPredictionModel is a plain torch.nn.Module and has no save_pretrained();
# Trainer.save_model() saves the wrapped model's weights (state dict) instead.
trainer.save_model("./chess_model")
tokenizer.save_pretrained("./chess_model")

# Step 7: Evaluation
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

# Step 8: Test Prediction
def predict_next_move(model, tokenizer, input_moves, id_to_move=None, max_len=None):
    """
    Predict the next move and the game result for a move prefix.

    Args:
        model: a model whose forward accepts ``input_ids``/``attention_mask``
            and returns a dict with ``"logits_moves"`` and ``"logits_result"``.
        tokenizer: the tokenizer used during training.
        input_moves: move prefix in the training input format
            (space-joined moves, e.g. ``"e4 e5 Nf3"``).
        id_to_move: optional index -> move mapping (dict or sequence); when
            given, the predicted move is returned as a move string.
        max_len: pad/truncate length; defaults to the notebook's ``MAX_LEN``.

    Returns:
        ``(predicted_move, predicted_result)``: the move class index (or move
        string when ``id_to_move`` is given) and the result class index.
    """
    model.eval()
    # Feed inputs on the same device as the model (Trainer may have moved it).
    device = next(model.parameters()).device
    encoding = tokenizer(
        input_moves,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN if max_len is None else max_len
    )

    with torch.no_grad():
        # The model returns a dict; tuple-unpacking it would yield its keys.
        output = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device),
        )
    predicted_move = torch.argmax(output['logits_moves'], dim=-1).item()
    predicted_result = torch.argmax(output['logits_result'], dim=-1).item()

    if id_to_move is not None:
        predicted_move = id_to_move[predicted_move]
    return predicted_move, predicted_result

In [None]:

# Test Example
# Training inputs are space-joined moves with the "W1."/"B1." prefixes stripped
# by preprocess_moves, so convert the raw notation the same way before predicting.
raw_test_moves = "W1.e4 B1.e5 W2.Nf3"
test_moves = " ".join(preprocess_moves(raw_test_moves))
predicted_move, predicted_result = predict_next_move(model, tokenizer, test_moves)
# Same label encoding as create_sequences: 1 = "1-0", 0 = "0-1", 2 = anything else.
result_names = {1: "White win", 0: "Black win", 2: "Draw"}
print(f"Predicted Move: {predicted_move}, Predicted Result: {result_names.get(predicted_result, predicted_result)}")
