In [None]:
import chess
import torch
import dataset
from model import NNUE, load_checkpoint
from position import board_to_tensor

def mirror_board(board: chess.Board) -> chess.Board:
    """Return a color-flipped copy of *board*.

    The position is mirrored vertically (rank 1 <-> rank 8) and every piece
    changes color, so the result is the same position seen from the other
    side. ``chess.Board.mirror`` also mirrors the castling rights and the
    en passant square and swaps the side to move. The previous
    clear-and-repopulate approach silently dropped castling rights and the
    en passant square, so the "mirrored" position was not equivalent to the
    original.

    The input board is not modified; the move stack is not carried over.
    """
    return board.mirror()

test_set = "/home/seb/git/Stalemater2000/nnue/data/head_50k.jsonl.zst"

model = NNUE()
load_checkpoint("/home/seb/git/Stalemater2000/nnue/output/exp_2025-03-19_22-27-05/nnue_0001-5.pt", model)
# Put the network in inference mode before evaluating. torch.no_grad() only
# disables gradient tracking; layers such as dropout / batch-norm (if NNUE has
# any) still behave in training mode unless eval() is called.
# Assumes NNUE is a torch.nn.Module — TODO confirm.
model.eval()

def evaluate_board(board):
    """Score a single position with the module-level ``model``.

    The board is encoded with ``board_to_tensor``, cast to float, and given a
    leading batch dimension of size 1. It is passed to the model together with
    a (1, 1) boolean tensor that is True when Black is to move. The forward
    pass runs under ``torch.no_grad()``. The model's single output value is
    returned as a Python number.
    """
    features = board_to_tensor(board).float().unsqueeze(0)
    side_flag = torch.tensor([[board.turn == chess.BLACK]])
    with torch.no_grad():
        output = model(features, side_flag)
    return output.item()

# Color-symmetry check over the first 101 positions of the test set
# (indices 0..100). Each position and its color-flipped mirror are scored,
# and both scores are printed together with their sum. The sum is presumably
# expected to be ~0 for a well-behaved model (as in the recorded output).
for i, position in enumerate(dataset.iterate_full_dataset(test_set)):
    if i >= 101:
        break

    board = chess.Board(position.fen)
    flipped = mirror_board(board)

    score_original = evaluate_board(board)
    score_flipped = evaluate_board(flipped)
    print(score_original, score_flipped, score_original + score_flipped)
            


Checkpoint loaded from /home/seb/git/Stalemater2000/nnue/output/exp_2025-03-19_22-27-05/nnue_0001-5.pt, no optimizer
31.469472885131836 -31.469472885131836 0.0
-424.92333984375 424.92333984375 0.0
-49.85052490234375 49.85052490234375 0.0
62.218177795410156 -62.218177795410156 0.0
2679.70654296875 -2679.70654296875 0.0
52.663543701171875 -52.663543701171875 0.0
-316.1607360839844 316.1607360839844 0.0
-21.753002166748047 21.753002166748047 0.0
9.862377166748047 -9.862377166748047 0.0
428.3424072265625 -428.3424072265625 0.0
1815.1990966796875 -1815.1990966796875 0.0
-214.9868621826172 214.9868621826172 0.0
4.7056565284729 -4.7056565284729 0.0
-240.4815673828125 240.4815673828125 0.0
171.32977294921875 -171.32977294921875 0.0
-192.5454559326172 192.5454559326172 0.0
516.3751220703125 -516.3751220703125 0.0
-317.9339904785156 317.9339904785156 0.0
67.3759765625 -67.3759765625 0.0
14.144407272338867 -14.144407272338867 0.0
271.7540283203125 -271.7540283203125 0.0
-31.323286056518555 31.323