In [1]:
import random
import itertools

import chat as chat
import data as data
import utility as util

In [2]:
# Import our data -- the load_challenge_moves_csv function also processes the csv load so move / win probability are lists rather than strings
train_df = data.loader.load_challenge_moves_csv("data/chess_challenges_train_10k.csv", shuffle=True)

train_iterator = itertools.cycle(train_df.iterrows())

In [None]:
def evaluate_chess_model(ollama_session, train_iterator, system_prompt, board_representation, move_representation, with_piece, max_iters=None, max_timeout=30, verbose=False):
    evaluation_results = {
        "num_attempts": 0,
        "legal_move_ranks": [],
        "num_legal_moves": 0,
        "error_illegal_move": 0,
        "error_timeout": 0,
        "error_generation": 0,
        "error_extraction": 0,
    }

    for iter in range(max_iters):
        evaluation_results["num_attempts"] += 1
        try:
            _, row = next(train_iterator)
            board = row["FEN"]
            legal_moves = row["Move"]

            piece_letter, piece_position = util.get_random_piece_and_position(board)

            legal_piece_moves = util.get_legal_moves(board, piece_position, move_representation)

            #pass proper argumements: board, board_representation, move_type...
            # if with piece send piece_letter
            prompt = chat.format_prompt_for_legal_move(board, board_representation, piece_letter, piece_position, move_representation)
            
            # Generate a response from the model
            response, runtime_results = ollama_session.chat_baseline(system_prompt = system_prompt, user_prompt=prompt, timeout=max_timeout)

            if verbose:
                print(f"{'-'*100}\nPrompt:\n{prompt}\n\nResponse:\n{response}\n\nRuntime Results:\n{runtime_results}\n{'-'*100}\n")
                util.visualize_board_ipynb(row["FEN"])

            moves = chat.extract_legal_moves(response)

            correct_moves, illegal_moves = util.compare_moves_and_legal_moves(moves, legal_moves)
            
            #update evaluation results

            
            # If move in legal moves print probability / move ranking
            # evaluation_results["num_legal_moves"] += 1
            tps = runtime_results["generated_tokens"] / runtime_results["generation_duration"]
            print(f"[{iter+1:<4}/{max_iters:<4}] Move: {moves} | Correct Ratio: {len(correct_moves)} / {len(legal_moves)} | Illegal Raio: {len(illegal_moves)} / {len(moves)} | TPS: {tps:.2f}")
            # evaluation_results['legal_move_ranks'].append((move_idx+1)/len(legal_moves))

        except Exception as e:
            print(f"[{iter+1:<4}/{max_iters:<4}] {type(e).__name__}: {e}")
            if type(e) == chat.IllegalMoveError:
                evaluation_results["error_illegal_move"] += 1
            elif type(e) == chat.TimeoutError:
                evaluation_results["error_timeout"] += 1
            elif type(e) == chat.GenerationError:
                evaluation_results["error_generation"] += 1
            elif type(e) == chat.ExtractionError:
                evaluation_results["error_extraction"] += 1
            else:
                print(f"Unknown Error: {e}")

    # At end print out results
    avg_rank = sum(evaluation_results['legal_move_ranks'])/len(evaluation_results['legal_move_ranks']) if len(evaluation_results['legal_move_ranks']) else 0
    print(f"\n{'='*60}\nAverage Legal Move Score (Rank / Total Moves):\n{avg_rank:.4f}\n")
    print(f"Evaluation Results:\n")
    for key, value in evaluation_results.items():
        if key != "legal_move_ranks":
            print(f"{key}: {value}")

In [5]:
# Create a new Ollama Session to allow us to chat w/ various models
# IMPORTANT: Make sure to call `ollama serve` in your terminal to start the Ollama server
# [Lucas]: I'm personally getting ~70TPS on 1.5b and ~7TPS on 7b on my laptop. Most responses are between 1000-2000 tokens.
model_name = "deepseek-r1:1.5b"       # {deepseek-r1:1.5b, deepseek-r1:7b}
board_rep = "grid" # {grid, desc, FEN}
ollama_session = chat.OllamaSession(model=model_name, use_cuda=False, board_representation= board_rep)

evaluate_chess_model(
    ollama_session = ollama_session, 
    train_iterator = train_iterator, 
    board_representation = board_rep,
    max_iters = 1,
    max_timeout = 200,
    verbose = True
)

[1   /1   ] KeyError: 'message'
Unknown Error: 'message'

Average Legal Move Score (Rank / Total Moves):
0.0000

Evaluation Results:

num_attempts: 1
num_legal_moves: 0
error_illegal_move: 0
error_timeout: 0
error_generation: 0
error_extraction: 0
