In [14]:
import os
import re
from typing import List

import chess
import dotenv
import dspy
import mlflow

# Read variables from a local .env file, then build the language model and
# register it as DSPy's configured LM
dotenv.load_dotenv()
lm = dspy.LM(
    'openai/gpt-4.1',
    api_key=os.getenv('OPENAI_API_KEY'),
    temperature=1.0,
    max_tokens=6000,
    cache=False,
)
dspy.configure(lm=lm)

# Send experiment tracking to the MLflow server on localhost and switch on
# MLflow's DSPy autologging
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("DSPy-Chess-Puzzles-Blog")
mlflow.dspy.autolog()

In [15]:
# Load and filter the dataset to only puzzles with exactly 2 moves
from datasets import load_dataset

# Fetch a pinned revision of the puzzle set, then keep only the puzzles whose
# move list has exactly two entries
raw_puzzles = load_dataset(
    "Lichess/chess-puzzles",
    split="train",
    revision="8d4ee87",
)
dataset = raw_puzzles.filter(lambda ex: len(ex['Moves'].split()) == 2)

# Seeded shuffle; the first 1000 rows become the working sample
sampled = dataset.shuffle(seed=42).select(range(1000))

# Seeded 80/20 partition of the sample into train and test
split = sampled.train_test_split(test_size=0.2, seed=42)
train_set, test_set = split['train'], split['test']

In [16]:
# Signature for the puzzle task: three inputs (opponent's last move, FEN board,
# legal UCI moves) and two outputs (free-text reasoning, the chosen move).
# NOTE(review): DSPy is understood to use a Signature's docstring and field
# descriptions as prompt text for the LM, so rewording them may change model
# behaviour — confirm against the DSPy docs before editing.
class SolveChessPuzzle(dspy.Signature):
    """
    Given the opponent's last move, the board state (FEN), and legal 
    moves (UCI), reason step-by-step to determine the best move and 
    output it.
    """
    # --- Inputs ---
    last_move: str = dspy.InputField(
        description="The opponent's most recent move"
    )
    board: str = dspy.InputField(
        description="The current state of the board in FEN notation"
    )
    legal_moves: List[str] = dspy.InputField(
        description="A list of legal moves in UCI notation"
    )
    # --- Outputs ---
    reasoning: str = dspy.OutputField(
        description="A detailed explanation of the best move to play"
    )
    # The metric in this notebook compares this field with the expected move
    solution: str = dspy.OutputField(
        description="The best move to play in UCI notation (e.g., 'e2e4')"
    )

class ChessSolver(dspy.Module):
    """DSPy module that answers one puzzle using ChainOfThought over SolveChessPuzzle."""

    def __init__(self):
        super().__init__()
        # Chain-of-thought predictor bound to the puzzle signature
        self.predictor = dspy.ChainOfThought(SolveChessPuzzle)

    def forward(self, last_move: str, board: str, legal_moves: List[str]):
        """Pass the three puzzle inputs to the predictor and return its prediction as-is."""
        puzzle_inputs = dict(
            last_move=last_move,
            board=board,
            legal_moves=legal_moves,
        )
        return self.predictor(**puzzle_inputs)
    
# Unoptimized solver instance: evaluated directly as the baseline and passed as
# the `student` to the optimizers in later cells
solver = ChessSolver()

In [17]:
# Function to preprocess the dataset into the correct format
def preprocess_dataset(dataset):
    """
    Convert raw puzzle records into DSPy examples.

    Each record must provide a 'FEN' string and a space-separated 'Moves'
    string. Only records with exactly two moves are used: the first move is
    played on the board (the opponent's move) and the second is kept as the
    expected answer.

    Args:
        dataset: Iterable of mapping-like records with 'FEN' and 'Moves' keys.

    Returns:
        List of dspy.Example objects whose inputs are `last_move`, `board`
        (the FEN after the opponent's move) and `legal_moves` (UCI strings),
        with `expected_move` as the label. Records are skipped when they do
        not have exactly two moves, when the FEN or the opponent's move cannot
        be applied, or when the resulting position has no legal moves.
    """
    examples = []

    for record in dataset:
        fen = record['FEN']
        moves = record['Moves'].split()

        if len(moves) != 2:
            continue  # Only two-move puzzles (opponent move + answer) are supported

        opponent_move, expected_move = moves

        # Guard only the parsing steps, so a ValueError raised anywhere else
        # is not silently mistaken for bad puzzle data.
        try:
            board = chess.Board(fen)
            # Reach the position in which the solution has to be found
            board.push_uci(opponent_move)
        except ValueError:
            # Malformed FEN, or an unparsable/illegal opponent move
            continue

        legal_moves = [move.uci() for move in board.legal_moves]
        if not legal_moves:
            continue  # Nothing can be played from this position

        examples.append(
            dspy.Example(
                last_move=opponent_move,
                board=board.fen(),
                legal_moves=legal_moves,
                expected_move=expected_move,
            ).with_inputs('last_move', 'board', 'legal_moves')
        )

    return examples

# Convert both splits into lists of DSPy examples (train first, then test)
train_examples, test_examples = (
    preprocess_dataset(subset) for subset in (train_set, test_set)
)

In [18]:
# Metric that checks the predicted move against the puzzle's expected move.
#
# The pattern below describes one UCI-style move in lower-cased text: a
# from-square, an optional capture ('x') or long-algebraic ('-') separator, a
# to-square and an optional promotion piece (optionally written '=q').
# The lookarounds keep it from matching inside a longer alphanumeric word.
_UCI_MOVE_RE = re.compile(
    r"(?<![a-z0-9])([a-h][1-8])[x-]?([a-h][1-8])(?:=?([qrbn]))?(?![a-z0-9])"
)


def _extract_uci_move(text):
    """Return the first move found in `text`, normalised to plain UCI (e.g. 'e2e4', 'e7e8q')."""
    lowered = text.strip().lower()

    # Drop capture and check/checkmate markers, as the metric always has
    bare = lowered.replace('x', '').replace('+', '').replace('#', '')

    # Prefer the whole answer when it is exactly one move; otherwise take the
    # first move mentioned anywhere in the text (quotes, trailing punctuation
    # and surrounding words are ignored).
    match = _UCI_MOVE_RE.fullmatch(bare) or _UCI_MOVE_RE.search(lowered)
    if match is None:
        return bare

    return ''.join(part for part in match.groups() if part)


def validate_chess_move(example, pred, trace=None):
    """
    Check if the predicted solution matches the expected move.

    The comparison is case-insensitive. The prediction may carry capture or
    check markers, surrounding quotes/punctuation, or extra words; the first
    UCI-style move found in it (including a promotion suffix such as
    'e7e8q') is the one compared.

    Args:
        example: The ground truth example; its `expected_move` is the answer
        pred: The prediction from the model; its `solution` is scored
        trace: Optional trace information

    Returns:
        When `trace` is None: 1.0 if the prediction is correct, 0.0 otherwise.
        When `trace` is given: the same result as a bool.
        A missing or None `solution` counts as incorrect.
    """
    expected = example.expected_move.strip().lower()

    # A prediction without a usable `solution` is wrong, not an error
    raw_solution = getattr(pred, 'solution', None)
    if raw_solution is None:
        is_correct = False
    else:
        is_correct = _extract_uci_move(str(raw_solution)) == expected

    if trace is not None:
        return is_correct

    return float(is_correct)

In [19]:
# Shared evaluation harness: scores a program on the test examples with the
# move-matching metric; reused for every solver variant below
THREADS = 4
evaluator = dspy.Evaluate(
    metric=validate_chess_move,
    devset=test_examples,
    display_table=5,
    display_progress=True,
    num_threads=THREADS,
)

In [20]:
# Score the unoptimized solver on the test examples inside its own MLflow run
with mlflow.start_run(
    run_name="Chess-Solver-Baseline",
    tags={"model": f"{lm.model}"},
):
    results = evaluator(solver)

    # Record the evaluator's return value as the run's accuracy, then show it
    mlflow.log_metric("accuracy", results)
    print(f"Baseline Accuracy: {results}")

  0%|          | 0/200 [00:00<?, ?it/s]

Average Metric: 45.00 / 200 (22.5%): 100%|██████████| 200/200 [08:19<00:00,  2.50s/it]

2025/05/18 17:09:28 INFO dspy.evaluate.evaluate: Average Metric: 45.0 / 200 (22.5%)





Unnamed: 0,last_move,board,legal_moves,expected_move,reasoning,solution,validate_chess_move
0,d3b3,8/1pp3pk/p5bp/4QN2/2P1P1P1/1q5P/1P3K2/8 w - - 0 32,"['f5g7', 'f5e7', 'f5h6', 'f5d6', 'f5h4', 'f5d4', 'f5g3', 'f5e3', '...",e5g7,Black is threatening with a passed pawn on the queenside and a dan...,f5g7,
1,f7f6,r1bqkb1r/pp1pn1p1/2n1pp1p/2p5/2P1N1P1/4P3/PP1P1PBP/R1BQK1NR w KQkq...,"['e4f6', 'e4d6', 'e4g5', 'e4c5', 'e4g3', 'e4c3', 'g2h3', 'g2f3', '...",e4d6,"Black has just played ...f6, weakening their kingside and allowing...",e4f6,
2,e7e6,rn1q1bkr/pp4pp/2p1p3/8/3PP1n1/1QN5/PP3PPP/R1B1K2R w KQ - 0 11,"['c3d5', 'c3b5', 'c3a4', 'c3e2', 'c3d1', 'c3b1', 'b3b7', 'b3e6', '...",b3e6,Black's last move e7e6 opens up the light-squared bishop but also ...,c3d5,
3,e7f7,4r1k1/pbQ2Rb1/1p6/1Pn5/P1P5/8/5PPP/6K1 b - - 0 27,"['g8h8', 'g8h7', 'e8f8', 'e8d8', 'e8c8', 'e8b8', 'e8a8', 'e8e7', '...",e8e1,"White just played Qxf7+, placing the Black king in check. The boar...",g8h8,
4,a1f1,5r2/pp4k1/2npB1p1/4p3/1P4Q1/2PP4/6PP/5R1K b - - 0 23,"['f8h8', 'f8g8', 'f8e8', 'f8d8', 'f8c8', 'f8b8', 'f8a8', 'f8f7', '...",f8f1,Let's break down the current position and the last move: - Last mo...,g6g5,


Baseline Accuracy: 22.5
🏃 View run Chess-Solver-Baseline at: http://127.0.0.1:5000/#/experiments/995750488067348082/runs/d896e64e12d641109e833e47535135c0
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/995750488067348082


In [21]:
def prepare_examples_for_labeled_few_shot(original_examples):
    """
    Re-key preprocessed examples so their fields line up with SolveChessPuzzle.

    The three input fields are copied unchanged, `expected_move` is exposed
    under the signature's `solution` name, and `reasoning` is filled with a
    fixed placeholder sentence.

    Args:
        original_examples: List of examples from preprocess_dataset

    Returns:
        List of new examples carrying `last_move`, `board`, `legal_moves`,
        `solution` and `reasoning`, with the first three marked as inputs
    """
    input_names = ('last_move', 'board', 'legal_moves')

    return [
        dspy.Example(
            # Inputs are carried over as they are
            last_move=source.last_move,
            board=source.board,
            legal_moves=source.legal_moves,
            # Outputs use the signature's field names
            solution=source.expected_move,
            reasoning="Find the best move in this position.",
        ).with_inputs(*input_names)
        for source in original_examples
    ]

In [22]:
from dspy.teleprompt import LabeledFewShot

# Re-key the training examples so their labels use the signature's field names
train_examples_for_lfs = prepare_examples_for_labeled_few_shot(train_examples)

# LabeledFewShot optimizer configured with k=8
labeled_fewshot_optimizer = LabeledFewShot(k=8)

# Compile the baseline solver against the re-keyed training examples
optimized_solver = labeled_fewshot_optimizer.compile(
    trainset=train_examples_for_lfs,
    student=solver,
)

In [23]:
# Score the LabeledFewShot-compiled solver with the same evaluator and metric
with mlflow.start_run(
    run_name="Chess-Solver-LabeledFewShot-Optimized",
    tags={"model": f"{lm.model}"},
):
    optimized_results = evaluator(optimized_solver)

    # Record the evaluator's return value as the run's accuracy
    mlflow.log_metric("accuracy", optimized_results)

    # Show the score and its difference from the baseline `results`
    print(f"Optimized Accuracy: {optimized_results}")
    print(f"Improvement: {optimized_results - results}")

Average Metric: 35.00 / 101 (34.7%):  50%|█████     | 101/200 [10:41<10:12,  6.19s/it]



Average Metric: 75.00 / 200 (37.5%): 100%|██████████| 200/200 [22:09<00:00,  6.65s/it]

2025/05/18 17:31:37 INFO dspy.evaluate.evaluate: Average Metric: 75.0 / 200 (37.5%)





Unnamed: 0,last_move,board,legal_moves,expected_move,reasoning,solution,validate_chess_move
0,d3b3,8/1pp3pk/p5bp/4QN2/2P1P1P1/1q5P/1P3K2/8 w - - 0 32,"['f5g7', 'f5e7', 'f5h6', 'f5d6', 'f5h4', 'f5d4', 'f5g3', 'f5e3', '...",e5g7,"This position is highly tactical, and Black just played ...Qxb3, s...",f5g7,
1,f7f6,r1bqkb1r/pp1pn1p1/2n1pp1p/2p5/2P1N1P1/4P3/PP1P1PBP/R1BQK1NR w KQkq...,"['e4f6', 'e4d6', 'e4g5', 'e4c5', 'e4g3', 'e4c3', 'g2h3', 'g2f3', '...",e4d6,"The opponent has just played ...f7-f6, blocking their dark-square ...",e4f6,
2,e7e6,rn1q1bkr/pp4pp/2p1p3/8/3PP1n1/1QN5/PP3PPP/R1B1K2R w KQ - 0 11,"['c3d5', 'c3b5', 'c3a4', 'c3e2', 'c3d1', 'c3b1', 'b3b7', 'b3e6', '...",b3e6,"The position is fairly open, and both sides still have all their p...",d4d5,
3,e7f7,4r1k1/pbQ2Rb1/1p6/1Pn5/P1P5/8/5PPP/6K1 b - - 0 27,"['g8h8', 'g8h7', 'e8f8', 'e8d8', 'e8c8', 'e8b8', 'e8a8', 'e8e7', '...",e8e1,"In this position, White is threatening checkmate with Rxg7+ and Qx...",e8e7,
4,a1f1,5r2/pp4k1/2npB1p1/4p3/1P4Q1/2PP4/6PP/5R1K b - - 0 23,"['f8h8', 'f8g8', 'f8e8', 'f8d8', 'f8c8', 'f8b8', 'f8a8', 'f8f7', '...",f8f1,The White queen on g4 and bishop on e6 are exerting strong pressur...,f8f1,✔️ [1.000]


Optimized Accuracy: 37.5
Improvement: 15.0
🏃 View run Chess-Solver-LabeledFewShot-Optimized at: http://127.0.0.1:5000/#/experiments/995750488067348082/runs/49912045edd34511a7b7afb327f5ee48
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/995750488067348082


In [24]:
from dspy.teleprompt import BootstrapFewShot

# BootstrapFewShot settings. Parameter meanings below follow the DSPy docs as
# understood by the reviewer — verify against the installed version.
bootstrap_fewshot_optimizer = BootstrapFewShot(
    metric=validate_chess_move,   # judges which bootstrapped attempts are kept
    max_bootstrapped_demos=4,     # cap on demos generated by running the program
    max_labeled_demos=16,         # cap on demos taken directly from the trainset
    max_rounds=1,                 # bootstrapping rounds (can be increased)
    max_errors=5,                 # presumably errors tolerated before aborting — confirm
)

# NOTE(review): `train_examples` label the answer as `expected_move`, while the
# signature's outputs are `reasoning`/`solution`; demos taken straight from the
# trainset may therefore carry no output fields — confirm this is intended.
bootstrap_optimized_solver = bootstrap_fewshot_optimizer.compile(
    trainset=train_examples,
    student=solver,
)

  2%|▏         | 15/800 [02:26<2:08:02,  9.79s/it]

Bootstrapped 4 full traces after 15 examples for up to 1 rounds, amounting to 15 attempts.





In [25]:
# Score the BootstrapFewShot-compiled solver with the same evaluator and metric
with mlflow.start_run(
    run_name="Chess-Solver-BootstrapFewShot-Optimized",
    tags={"model": f"{lm.model}"},
):
    bootstrap_results = evaluator(bootstrap_optimized_solver)

    # Record the evaluator's return value as the run's accuracy
    mlflow.log_metric("accuracy", bootstrap_results)

    # Show the score and its difference from the baseline `results`
    print(f"BootstrapFewShot Optimized Accuracy: {bootstrap_results}")
    print(f"Improvement over baseline: {bootstrap_results - results}")

Average Metric: 82.00 / 200 (41.0%): 100%|██████████| 200/200 [20:08<00:00,  6.04s/it]

2025/05/18 17:54:13 INFO dspy.evaluate.evaluate: Average Metric: 82.0 / 200 (41.0%)





Unnamed: 0,last_move,board,legal_moves,expected_move,reasoning,solution,validate_chess_move
0,d3b3,8/1pp3pk/p5bp/4QN2/2P1P1P1/1q5P/1P3K2/8 w - - 0 32,"['f5g7', 'f5e7', 'f5h6', 'f5d6', 'f5h4', 'f5d4', 'f5g3', 'f5e3', '...",e5g7,"Black has just played ...Qxb3 (d3b3), capturing on b3 with the que...",e5g7,✔️ [1.000]
1,f7f6,r1bqkb1r/pp1pn1p1/2n1pp1p/2p5/2P1N1P1/4P3/PP1P1PBP/R1BQK1NR w KQkq...,"['e4f6', 'e4d6', 'e4g5', 'e4c5', 'e4g3', 'e4c3', 'g2h3', 'g2f3', '...",e4d6,Black's last move f7-f6 weakens the light squares around their kin...,e4f6,
2,e7e6,rn1q1bkr/pp4pp/2p1p3/8/3PP1n1/1QN5/PP3PPP/R1B1K2R w KQ - 0 11,"['c3d5', 'c3b5', 'c3a4', 'c3e2', 'c3d1', 'c3b1', 'b3b7', 'b3e6', '...",b3e6,"Black has just played 11...e6, attacking the White queen and cutti...",b3b7,
3,e7f7,4r1k1/pbQ2Rb1/1p6/1Pn5/P1P5/8/5PPP/6K1 b - - 0 27,"['g8h8', 'g8h7', 'e8f8', 'e8d8', 'e8c8', 'e8b8', 'e8a8', 'e8e7', '...",e8e1,"White's last move was Qc7-f7+, delivering a check to the Black kin...",g8h8,
4,a1f1,5r2/pp4k1/2npB1p1/4p3/1P4Q1/2PP4/6PP/5R1K b - - 0 23,"['f8h8', 'f8g8', 'f8e8', 'f8d8', 'f8c8', 'f8b8', 'f8a8', 'f8f7', '...",f8f1,"White's last move was Ra1-f1, doubling up on the f-file and possib...",f8f1,✔️ [1.000]


BootstrapFewShot Optimized Accuracy: 41.0
Improvement over baseline: 18.5
🏃 View run Chess-Solver-BootstrapFewShot-Optimized at: http://127.0.0.1:5000/#/experiments/995750488067348082/runs/af858f386bd84621822c31b2ceb593ac
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/995750488067348082
