In [1]:
import random
import itertools

import chat as chat
import data as data
import utility as util

In [2]:
# Load the training challenges. Per the loader's contract noted by the original author,
# load_challenge_moves_csv also post-processes the CSV so that the move / win-probability
# columns hold lists rather than strings.
train_df = data.loader.load_challenge_moves_csv(
    "data/chess_challenges_train_10k.csv",
    shuffle=True,
)

# Endless row stream: itertools.cycle replays the rows from the start once they run out,
# so consumers can keep calling next() across several evaluation runs.
train_iterator = itertools.cycle(
    train_df.iterrows()
)

In [None]:
def evaluate_chess_model(ollama_session, train_iterator, board_representation, move_representation, with_piece, max_iters=None, max_timeout=30, verbose=False):
    """Score how well a model lists the legal moves of a randomly chosen piece.

    For every position pulled from ``train_iterator`` a piece is selected with
    ``util.get_random_piece_and_position``, the model is prompted for that piece's
    legal moves, and the extracted answer is compared against ``util.get_legal_moves``.
    A progress line is printed per attempt and a summary is printed at the end.

    Args:
        ollama_session: Object exposing
            ``chat_baseline(system_prompt=..., user_prompt=..., timeout=...)`` which
            returns ``(response, runtime_results)``. ``runtime_results`` is indexed with
            ``"generated_tokens"`` and ``"generation_duration"``.
        train_iterator: Iterator yielding ``(index, row)`` pairs; ``row["FEN"]`` is used
            as the board.
        board_representation: Board encoding forwarded to the ``chat`` prompt builders.
        move_representation: Move notation forwarded to the prompt builders and to
            ``util.get_legal_moves``.
        with_piece: Forwarded to ``chat.generate_chess_system_prompt``.
        max_iters: Upper bound on the number of positions to evaluate. ``None`` means
            "until ``train_iterator`` is exhausted" -- do not combine ``None`` with an
            endless iterator such as ``itertools.cycle``.
        max_timeout: Forwarded as ``timeout`` to ``chat_baseline`` (presumably seconds --
            confirm against the ``chat`` module).
        verbose: When true, also print the prompt, raw response, runtime results and
            extracted moves for each attempt.

    Returns:
        dict: ``num_attempts`` plus the three ``error_*`` counters (ints), and one list
        entry per successfully scored attempt under ``num_predicted``, ``num_actual``,
        ``num_correct``, ``num_illegal``, ``num_missed``, ``precision``, ``recall`` and
        ``f1_score``.

    Notes:
        Exceptions raised while processing a position are caught, printed and tallied;
        they never propagate. The printed averages divide by the number of attempts, so
        a failed attempt counts as a score of 0.
    """
    evaluation_results = {
        "num_attempts": 0,
        "num_predicted": [],
        "num_actual": [],
        "num_correct": [],
        "num_illegal": [],
        "num_missed": [],
        "precision": [],
        "recall": [],
        "f1_score": [],
        "error_timeout": 0,
        "error_generation": 0,
        "error_extraction": 0,
    }
    system_prompt = chat.generate_chess_system_prompt(board_representation, move_representation, with_piece)

    # range(None) raises TypeError, so the unbounded case gets its own counter and a
    # placeholder for the "total" shown in progress lines.
    attempt_indices = itertools.count() if max_iters is None else range(max_iters)
    total_label = "?" if max_iters is None else max_iters

    for attempt_idx in attempt_indices:
        progress = f"[{attempt_idx+1:<4}/{total_label:<4}]"

        # Fetch outside the broad handler below: running out of data ends the
        # evaluation instead of being tallied as a failed attempt.
        try:
            sample = next(train_iterator)
        except StopIteration:
            break

        evaluation_results["num_attempts"] += 1
        try:
            _, row = sample
            board = row["FEN"]
            piece_letter, piece_position = util.get_random_piece_and_position(board)
            legal_moves = util.get_legal_moves(board, piece_position, move_representation)
            prompt = chat.format_prompt_for_legal_move(board, board_representation, piece_letter, piece_position, move_representation)

            # Generate a response from the model
            response, runtime_results = ollama_session.chat_baseline(system_prompt=system_prompt, user_prompt=prompt, timeout=max_timeout)
            moves = chat.extract_legal_moves(response)
            if verbose:
                print(f"{'-'*100}\nPrompt:\n{prompt}\n\nResponse:\n{response}\n\nRuntime Results:\n{runtime_results}\n{'-'*100}\n")
                print(moves)

            correct_moves, missed_moves, illegal_moves = util.compare_moves_and_legal_moves(moves, legal_moves)

            # Per-attempt metrics; every ratio falls back to 0 when its denominator is 0.
            num_predicted = len(moves)
            num_actual = len(legal_moves)
            num_correct = len(correct_moves)
            num_illegal = len(illegal_moves)
            num_missed = len(missed_moves)
            precision = num_correct / num_predicted if num_predicted > 0 else 0
            recall = num_correct / num_actual if num_actual > 0 else 0
            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

            evaluation_results["num_predicted"].append(num_predicted)
            evaluation_results["num_actual"].append(num_actual)
            evaluation_results["num_correct"].append(num_correct)
            evaluation_results["num_illegal"].append(num_illegal)
            evaluation_results["num_missed"].append(num_missed)
            evaluation_results["precision"].append(precision)
            evaluation_results["recall"].append(recall)
            evaluation_results["f1_score"].append(f1_score)

            # A zero duration must not turn an already-scored attempt into an error.
            generation_duration = runtime_results["generation_duration"]
            tps = runtime_results["generated_tokens"] / generation_duration if generation_duration else 0.0
            print(f"{progress} Move: {moves} | Correct Ratio: {num_correct} / {num_actual} | Illegal Ratio: {num_illegal} / {num_predicted} | TPS: {tps:.2f}")

        except Exception as exc:
            print(f"{progress} {type(exc).__name__}: {exc}")
            # Exact-type matching is kept deliberately: the inheritance relationships
            # between the chat error classes are not visible from this notebook.
            error_type = type(exc)
            if error_type is chat.TimeoutError:
                evaluation_results["error_timeout"] += 1
            elif error_type is chat.GenerationError:
                evaluation_results["error_generation"] += 1
            elif error_type is chat.ExtractionError:
                evaluation_results["error_extraction"] += 1
            else:
                print(f"Unknown Error: {exc}")

    # Aggregate statistics over all attempts.
    total_attempts = evaluation_results["num_attempts"]

    def _average_over_attempts(values):
        # Failed attempts add nothing to the sum but stay in the denominator (score 0);
        # with no attempts at all the average is reported as 0.0 rather than raising.
        return sum(values) / total_attempts if total_attempts else 0.0

    avg_precision = _average_over_attempts(evaluation_results["precision"])
    avg_recall = _average_over_attempts(evaluation_results["recall"])
    avg_f1_score = _average_over_attempts(evaluation_results["f1_score"])

    print(f"\n{'='*60}\nEvaluation Summary:")
    print(f"Total Attempts: {total_attempts}")
    print(f"Avg Precision: {avg_precision:.4f}")
    print(f"Avg Recall: {avg_recall:.4f}")
    print(f"Avg F1 Score: {avg_f1_score:.4f}")

    return evaluation_results

In [4]:
# Sweep every (board representation, move notation) pair against one Ollama model.
# IMPORTANT: Make sure to call `ollama serve` in your terminal to start the Ollama server
model_name = "deepseek-r1:8b"       # {deepseek-r1:1.5b, deepseek-r1:7b}

board_types = ["FEN", "FEN_spaces", "FEN_dots", "FEN_spaced_dots", "desc", "grid"]
move_representations = ["PGN", "UCI"]

# Metrics of every configuration, keyed by (board_representation, move_rep), so that
# earlier runs stay available after `results` is rebound by the next one.
all_results = {}

# itertools.product walks the pairs in the same order as a board-outer / move-inner nested loop.
for board_representation, move_rep in itertools.product(board_types, move_representations):
    # A new session is created for each configuration.
    ollama_session = chat.OllamaSession(model=model_name, use_cuda=False, board_representation=board_representation)

    results = evaluate_chess_model(
        ollama_session=ollama_session,
        train_iterator=train_iterator,
        board_representation=board_representation,
        move_representation=move_rep,
        with_piece=True,
        max_iters=1,
        max_timeout=120,
        verbose=True
    )
    all_results[(board_representation, move_rep)] = results

[1   /1   ] TimeoutError: The chat request exceeded the timeout limit (120 seconds).

Evaluation Summary:
Total Attempts: 1
Avg Precision: 0.0000
Avg Recall: 0.0000
Avg F1 Score: 0.0000
[1   /1   ] TimeoutError: The chat request exceeded the timeout limit (120 seconds).

Evaluation Summary:
Total Attempts: 1
Avg Precision: 0.0000
Avg Recall: 0.0000
Avg F1 Score: 0.0000
[1   /1   ] TimeoutError: The chat request exceeded the timeout limit (120 seconds).

Evaluation Summary:
Total Attempts: 1
Avg Precision: 0.0000
Avg Recall: 0.0000
Avg F1 Score: 0.0000
[1   /1   ] TimeoutError: The chat request exceeded the timeout limit (120 seconds).

Evaluation Summary:
Total Attempts: 1
Avg Precision: 0.0000
Avg Recall: 0.0000
Avg F1 Score: 0.0000
[1   /1   ] TimeoutError: The chat request exceeded the timeout limit (120 seconds).

Evaluation Summary:
Total Attempts: 1
Avg Precision: 0.0000
Avg Recall: 0.0000
Avg F1 Score: 0.0000
[1   /1   ] TimeoutError: The chat request exceeded the timeout limit 