In [1]:
import os
import re
from typing import List

import chess
import dotenv
import dspy
import mlflow
from datasets import load_dataset

# Setup environment and LLM (reads OPENAI_API_KEY and any other settings from .env)
dotenv.load_dotenv()
lm = dspy.LM(
    'openai/gpt-4.1',
    api_key=os.getenv('OPENAI_API_KEY'),
    temperature=1.0,
    max_tokens=6000,
    # NOTE(review): with cache=True, repeated identical prompts are presumably served
    # from DSPy's cache, so re-runs reproduce earlier completions despite temperature=1.0.
    cache=True,
)
dspy.configure(lm=lm)

# Set up MLflow for experiment tracking. An explicit set_tracking_uri() call takes
# precedence over the MLFLOW_TRACKING_URI environment variable, so honour that variable
# here (including a value loaded from .env) and only fall back to a local server.
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000"))
mlflow.set_experiment("DSPy-Chess-Puzzles-Blog")
mlflow.dspy.autolog()

In [2]:
# Load the Lichess puzzle set and keep only puzzles with exactly 2 moves:
# the opponent's move followed by the single move we ask the model to find.
N_SAMPLES = 1000  # size of the random sample drawn from the filtered puzzles
TEST_SIZE = 0.2   # fraction of the sample held out for evaluation
SEED = 42         # shared seed for shuffling and splitting

# `input_columns` passes only the "Moves" string to the predicate instead of the
# whole row, avoiding materialising every column for each of the puzzles.
dataset = load_dataset("Lichess/chess-puzzles", split="train").filter(
    lambda moves: len(moves.split()) == 2,
    input_columns="Moves",
)

# Shuffle and take a random sample (clamped in case fewer puzzles survive the filter)
sampled = dataset.shuffle(seed=SEED).select(range(min(N_SAMPLES, len(dataset))))

# Split into train/test (80/20 with the defaults above)
split = sampled.train_test_split(test_size=TEST_SIZE, seed=SEED)
train_set = split['train']
test_set = split['test']

In [None]:
# Task signature for the chess solver: three inputs describing the position, and two
# outputs (an explanation and the chosen move).
# NOTE(review): deliberately no class docstring here. For a dspy.Signature the
# docstring is used as the task instructions, so adding one would change the prompt
# (and invalidate the recorded baseline). Field order and descriptions also feed the prompt.
class SolveChessPuzzle(dspy.Signature):
    # The opponent's move, in UCI form as taken from the dataset's "Moves" column.
    last_move: str = dspy.InputField(
        description="The opponent's most recent move"
    )
    # FEN of the position *after* the opponent's move has been played.
    board: str = dspy.InputField(
        description="The current state of the board in FEN notation"
    )
    # Pre-computed legal moves for the side to move, presumably so the model picks
    # from valid UCI moves rather than generating them itself.
    legal_moves: List[str] = dspy.InputField(
        description="A list of legal moves in UCI notation"
    )
    # NOTE(review): dspy.ChainOfThought (used by ChessSolver) also adds a `reasoning`
    # output field of its own; declaring it here as well looks redundant — confirm
    # which description ends up in the prompt before removing either.
    reasoning: str = dspy.OutputField(
        description="A detailed explanation of the best move to play"
    )
    # The move that validate_chess_move compares against the example's expected_move.
    solution: str = dspy.OutputField(
        description="The best move to play in UCI notation (e.g., 'e2e4')"
    )

    
class ChessSolver(dspy.Module):
    """Puzzle solver that wraps one ChainOfThought predictor over SolveChessPuzzle."""

    def __init__(self):
        super().__init__()
        # Kept under the name `predictor` so existing references to solver.predictor still work.
        self.predictor = dspy.ChainOfThought(SolveChessPuzzle)

    def forward(self, last_move: str, board: str, legal_moves: List[str]):
        """Run the predictor on one position and hand back its prediction untouched."""
        return self.predictor(last_move=last_move, board=board, legal_moves=legal_moves)


# A single, unoptimised solver instance used for the baseline evaluation
solver = ChessSolver()

In [4]:
# Function to preprocess the dataset into the correct format
def preprocess_dataset(dataset):
    """Convert raw Lichess puzzle rows into DSPy examples.

    For each row, the first UCI move in ``Moves`` (the opponent's move) is played
    on the ``FEN`` position. The resulting position, its legal moves and the
    second move (the expected answer) become one ``dspy.Example`` whose inputs
    are ``last_move``, ``board`` and ``legal_moves``.

    A row is skipped when:
      * ``Moves`` does not contain exactly two moves,
      * the FEN cannot be parsed or the opponent's move cannot be played
        (both surface as ``ValueError``), or
      * the resulting position has no legal moves.

    Args:
        dataset: An iterable of mappings with ``'FEN'`` and ``'Moves'`` keys.

    Returns:
        list of ``dspy.Example``, one per usable row, in input order.
    """
    examples = []

    for row in dataset:
        moves = row['Moves'].split()
        if len(moves) != 2:
            continue  # Skip if not exactly 2 moves

        opponent_move, expected_move = moves

        # Only the parsing steps live in the try: previously an invalid FEN raised
        # outside it, while code that cannot raise ValueError sat inside it.
        try:
            board = chess.Board(row['FEN'])
            board.push_uci(opponent_move)
        except ValueError:
            continue  # Skip invalid positions or moves

        legal_moves = [move.uci() for move in board.legal_moves]
        if not legal_moves:
            continue  # Skip positions with no legal moves

        examples.append(
            dspy.Example(
                last_move=opponent_move,
                board=board.fen(),
                legal_moves=legal_moves,
                expected_move=expected_move,
            ).with_inputs('last_move', 'board', 'legal_moves')
        )

    return examples

# Transform datasets
train_examples = preprocess_dataset(train_set)
test_examples = preprocess_dataset(test_set)

In [5]:
# Improved metric to validate the solution
# A UCI move: from-square, to-square, optional promotion piece, not glued to other
# letters/digits (so words like "head" or "dead" are never mistaken for moves).
_UCI_MOVE_PATTERN = re.compile(r"(?<![a-z0-9])([a-h][1-8][a-h][1-8][qrbn]?)(?![a-z0-9])")


def validate_chess_move(example, pred, trace=None):
    """
    Check if the predicted solution matches the expected move.

    The model's ``solution`` text is lower-cased and stripped of SAN-style
    decorations (``x`` captures, ``+``/``#`` check marks, ``-`` separators and
    ``=`` promotion markers). The first well-formed UCI move found in it, such as
    ``e2e4`` or ``e7e8q``, is then compared with the expected move. Surrounding
    quotes or punctuation (``'e2e4'``, ``e2e4.``) are therefore tolerated.

    Args:
        example: The ground truth example; its ``expected_move`` (UCI) is read.
        pred: The prediction from the model; its ``solution`` is read. A missing
            or empty solution counts as incorrect.
        trace: Optional trace information.

    Returns:
        A bool when ``trace`` is not None; otherwise 1.0 if the prediction is
        correct and 0.0 if not.
    """
    expected = example.expected_move.strip().lower()

    solution = getattr(pred, "solution", None)
    if not solution:
        # Unparseable/empty model output is a wrong answer, not a crash.
        return False if trace is not None else 0.0

    text = str(solution).lower()
    # Normalise SAN-ish spellings: "e5xg7" / "e2-e4" / "e7e8=q" / "e5g7+" -> plain UCI.
    cleaned = re.sub(r"[x+#=\-]", "", text)

    match = _UCI_MOVE_PATTERN.search(cleaned)
    # Without a recognisable move, compare the whole answer (which can then only
    # match if it literally equals the expected move).
    predicted = match.group(1) if match else cleaned.strip()

    is_correct = predicted == expected
    return is_correct if trace is not None else float(is_correct)

In [6]:
# Build the evaluator that scores a program on the held-out test examples
THREADS = 4  # number of parallel worker threads used by dspy.Evaluate
evaluator = dspy.Evaluate(
    metric=validate_chess_move,
    devset=test_examples,
    display_progress=True,
    display_table=5,  # render a 5-row sample of per-example results
    num_threads=THREADS,
)

In [7]:
# Run evaluation on the baseline model
with mlflow.start_run(run_name="Chess-Solver-Baseline", tags={"model": str(lm.model)}):
    results = evaluator(solver)

    # Depending on the DSPy version, Evaluate returns either a plain number or a
    # result object exposing `.score`; normalise so log_metric always gets a float.
    baseline_score = float(getattr(results, "score", results))

    # Log accuracy as a metric (a percentage, e.g. 29.0 means 29% correct)
    mlflow.log_metric("accuracy", baseline_score)

    # Print results
    print(f"Baseline Accuracy: {baseline_score}")

Average Metric: 58.00 / 200 (29.0%): 100%|██████████| 200/200 [00:02<00:00, 88.65it/s]

2025/04/24 23:48:02 INFO dspy.evaluate.evaluate: Average Metric: 58.0 / 200 (29.0%)





Unnamed: 0,last_move,board,legal_moves,expected_move,reasoning,solution,validate_chess_move
0,d3b3,8/1pp3pk/p5bp/4QN2/2P1P1P1/1q5P/1P3K2/8 w - - 0 32,"['f5g7', 'f5e7', 'f5h6', 'f5d6', 'f5h4', 'f5d4', 'f5g3', 'f5e3', '...",e5g7,The black queen on b3 and dark-squared bishop on g6 are both aimin...,e5g7,✔️ [1.000]
1,f7f6,r1bqkb1r/pp1pn1p1/2n1pp1p/2p5/2P1N1P1/4P3/PP1P1PBP/R1BQK1NR w KQkq...,"['e4f6', 'e4d6', 'e4g5', 'e4c5', 'e4g3', 'e4c3', 'g2h3', 'g2f3', '...",e4d6,"After Black's last move (f7f6), the f6 pawn is undefended and Whit...",e4f6,
2,e7e6,rn1q1bkr/pp4pp/2p1p3/8/3PP1n1/1QN5/PP3PPP/R1B1K2R w KQ - 0 11,"['c3d5', 'c3b5', 'c3a4', 'c3e2', 'c3d1', 'c3b1', 'b3b7', 'b3e6', '...",b3e6,"Black has just played ...e6, reinforcing the d5 square and opening...",c3d5,
3,e7f7,4r1k1/pbQ2Rb1/1p6/1Pn5/P1P5/8/5PPP/6K1 b - - 0 27,"['g8h8', 'g8h7', 'e8f8', 'e8d8', 'e8c8', 'e8b8', 'e8a8', 'e8e7', '...",e8e1,"White has just played Qc7-f7, threatening mate on g7. The white qu...",g7f8,
4,a1f1,5r2/pp4k1/2npB1p1/4p3/1P4Q1/2PP4/6PP/5R1K b - - 0 23,"['f8h8', 'f8g8', 'f8e8', 'f8d8', 'f8c8', 'f8b8', 'f8a8', 'f8f7', '...",f8f1,"White's last move was Rf1, doubling up pressure against Black's f-...",f8h8,


Baseline Accuracy: 29.0
🏃 View run Chess-Solver-Baseline at: http://127.0.0.1:5000/#/experiments/995750488067348082/runs/4d81a3ff7c26420289babfc083472f2e
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/995750488067348082
