In [3]:
import os
import sys
import torch
import random
import torch.optim as optim
import chess
from models.models import MoveScorer
import ast
from alter_move_prob_train.alter_move_prob_nn import AlterMoveProbNN
from common.board_information import phase_of_game

# Initialise the AlterMoveProbNN model and load its best saved weights.
model = AlterMoveProbNN()

# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; load_state_dict then copies the values into the model's own
# (CPU-constructed) parameters, so nothing else changes when a GPU is present.
model.load_state_dict(
    torch.load(
        "alter_move_prob_train/data/alter_move_prob_nn_best.pth",
        map_location="cpu",
        weights_only=True,
    )
)

# Print the model parameters
print("Model parameters:")
for name, param in model.named_parameters():
    # .item() only works on single-element tensors; fall back to printing the
    # whole tensor so a vector/matrix parameter does not crash this listing.
    if param.numel() == 1:
        print(f"{name}: {param.item():.4f}")
    else:
        print(f"{name}: {param.detach()}")

# Create a random chess position to use as the example board
def create_random_position(min_moves=5, max_moves=20, rng=None):
    """Return a board reached by playing random legal moves from the start position.

    Args:
        min_moves: Smallest number of random plies to attempt (inclusive).
        max_moves: Largest number of random plies to attempt (inclusive).
        rng: Object providing ``randint`` and ``choice`` (for example a
            ``random.Random`` instance). Defaults to the global ``random``
            module, so ``random.seed`` controls the result exactly as before.

    Returns:
        A ``chess.Board``. Fewer plies than drawn are played if the side to
        move has no legal moves or ``board.is_game_over()`` reports the game
        has ended.
    """
    rng = random if rng is None else rng

    # Start with a standard board
    board = chess.Board()

    # Draw how many random plies to play
    num_moves = rng.randint(min_moves, max_moves)

    for _ in range(num_moves):
        legal_moves = list(board.legal_moves)
        if not legal_moves:  # No legal moves (checkmate or stalemate)
            break

        # Play a uniformly chosen legal move
        board.push(rng.choice(legal_moves))

        # Stop if the game is over
        if board.is_game_over():
            break

    return board

# Seed the global RNG so the randomly sampled position is reproducible.
random.seed(42)

# Sample a position a few random plies into the game and show it.
board = create_random_position()
print("\nExample board position:")
print(board)
print(f"Game phase: {phase_of_game(board)}")

# Use the midgame MoveScorer (piece-selector and piece-to weight files) to
# produce a realistic move -> probability dictionary as input for the model.
midgame_scorer = MoveScorer(
    "models/model_weights/piece_selector_midgame_weights.pth",
    "models/model_weights/piece_to_midgame_weights.pth",
)
_, move_dic = midgame_scorer.get_move_dic(board, san=False, top=100)

# Coerce any non-dict mapping returned by the scorer into a plain dict of floats.
if not isinstance(move_dic, dict):
    move_dic = {uci: float(score) for uci, score in move_dic.items()}

# This standalone example has no preceding positions.
prev_board = prev_prev_board = None

# Apply the alteration model; it returns the adjusted dictionary and a log.
altered_move_dic, log = model(move_dic, board, prev_board, prev_prev_board)

# Show both dictionaries with moves rendered in SAN (keys are parsed as UCI).
print("\nOriginal move dictionary:")
for move, prob in move_dic.items():
    print(f"{board.san(chess.Move.from_uci(move))}: {prob:.4f}")

print("\nAltered move dictionary:")
for move, prob in altered_move_dic.items():
    prob_value = prob.item() if isinstance(prob, torch.Tensor) else float(prob)
    print(f"{board.san(chess.Move.from_uci(move))}: {prob_value:.4f}")

print("\nLog of alterations:")
print(log)

# Create an optimizer over the model's learnable parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Clear any gradients left over from earlier steps
optimizer.zero_grad()

# Calculate a negative log-likelihood loss for one sampled "true" move.
# In a real training scenario the true move would be the move actually played.
true_move = random.choice(list(move_dic.keys()))
if true_move in altered_move_dic:
    true_move_prob = altered_move_dic[true_move]
    # backward() needs a tensor that requires grad. A plain number, or a tensor
    # already detached from the graph, is wrapped in a fresh leaf tensor: that
    # lets backward() run, but such a leaf has no path to the model parameters.
    if not (isinstance(true_move_prob, torch.Tensor) and true_move_prob.requires_grad):
        true_move_prob = torch.tensor(
            float(true_move_prob), dtype=torch.float, requires_grad=True
        )
    # Clamp away from 0 so a zero probability gives a large finite loss rather
    # than inf, whose backward pass would write inf/nan into the parameters.
    loss = -torch.log(true_move_prob.clamp_min(1e-12))
else:
    # If the true move is not in the altered move dictionary, assign a high loss.
    # NOTE: this constant is not connected to the model, so it yields no gradients.
    loss = torch.tensor(10.0, dtype=torch.float, requires_grad=True)

print(f"\nLoss for this example: {loss.item():.4f}")

# Backward pass
loss.backward()

# Print gradients before update, and say so explicitly if none reached the model
print("\nGradients before update:")
params_with_grad = 0
for name, param in model.named_parameters():
    if param.grad is not None:
        params_with_grad += 1
        print(f"{name}: {param.grad.item():.6f}")
if params_with_grad == 0:
    # The optimizer skips parameters whose .grad is None, so the step is a no-op.
    print("(no model parameter received a gradient; the optimisation step will not change them)")

# Update parameters
optimizer.step()

# Print parameters after update
print("\nParameters after optimisation step:")
for name, param in model.named_parameters():
    print(f"{name}: {param.item():.4f}")



Model parameters:
weird_move_sd_opening: -0.1913
weird_move_sd_midgame: -0.3328
weird_move_sd_endgame: -0.0657
protect_king_sf: 0.7614
capture_en_pris_sf: 0.7933
break_pin_sf: 1.2306
capture_sf: 1.1687
capture_sf_king_danger: 1.1475
capturable_sf: 0.8171
solo_factor_sf: 1.3324
threatened_lvl_diff_sf: 0.4485
check_sf: 1.2022
takeback_sf: 2.6469
new_threatened_sf: 1.6287
exchange_sf: 1.4049
exchange_k_danger_sf: 1.1235
passed_pawn_end_sf: 3.3780
repeat_sf: 0.1000
interesting_move_threshold: -1.0406

Example board position:
r n . q k b n r
. p p . p p p .
p . . p . . . .
. . . . . . . p
. . . . . . b .
. . . . . P . N
P P P P P . P P
R N B Q K B . R
Game phase: opening

Original move dictionary:
fxg4: 0.3994
d3: 0.1553
d4: 0.1074
e3: 0.0767
e4: 0.0620
g3: 0.0581
c3: 0.0318
Kf2: 0.0264
Nf2: 0.0162
Nc3: 0.0153
c4: 0.0105
Nf4: 0.0039
Ng5: 0.0030
Ng1: 0.0026
b3: 0.0025
Na3: 0.0023
b4: 0.0012
a4: 0.0008
Rg1: 0.0007
a3: 0.0005
f4: 0.0001

Altered move dictionary:
fxg4: 0.4294
d3: 0.1355
d4: 0.0