In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import chess
from torch.utils.tensorboard import SummaryWriter
import time
import os
import numpy as np
torch.manual_seed(42)

In [None]:
class EarlyStopping:
    """Signal early stopping after several consecutive "converged" epochs.

    An epoch qualifies when BOTH hold:

    * the train/validation gap is small:
      ``abs(train_loss - validation_loss) <= min_delta``;
    * the training loss did not improve on the best value seen in *earlier*
      epochs (``min_train_loss <= train_loss``).

    Once ``tolerance`` consecutive epochs qualify, ``early_stop`` is set to
    True. Nothing in this class resets it afterwards.

    Args:
        tolerance: number of consecutive qualifying epochs required to stop.
        min_delta: largest train/validation gap still counted as "small".
    """

    def __init__(self, tolerance=10, min_delta=10000):
        self.tolerance = tolerance
        self.counter = 0        # length of the current run of qualifying epochs
        self.prev = False       # whether the previous epoch qualified
        self.early_stop = False

        self.min_delta = min_delta
        self.min_train_loss = float('inf')  # best train loss from earlier epochs

    def condition(self, train_loss, validation_loss):
        """Return True if this epoch qualifies (small gap and no train improvement).

        Compares against ``self.min_train_loss`` as it stands when called, so
        it must be evaluated before ``train_loss`` is folded into that minimum.
        """
        small_gap = abs(train_loss - validation_loss) <= self.min_delta
        no_improvement = self.min_train_loss <= train_loss
        return small_gap and no_improvement

    def __call__(self, train_loss, validation_loss):
        """Record one epoch's losses and update ``counter`` / ``early_stop``."""
        # Evaluate BEFORE updating the running minimum. Updating first made
        # ``min_train_loss <= train_loss`` trivially true, which silently
        # disabled the "no improvement" half of the stop condition.
        qualifies = self.condition(train_loss, validation_loss)
        self.min_train_loss = min(self.min_train_loss, train_loss)

        if qualifies:
            # Extend the current run, or start a new one after a break.
            self.counter = self.counter + 1 if self.prev else 1
            self.prev = True
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            self.prev = False

# Quick sanity check of EarlyStopping on a small hand-written loss history:
# feed one (train, validation) pair per epoch and stop as soon as it fires.
early_stopping = EarlyStopping(tolerance=3, min_delta=2)

train_loss = [6, 7, 6.5]
validate_loss = [6, 6, 5]

for i, (epoch_train, epoch_val) in enumerate(zip(train_loss, validate_loss)):
    early_stopping(epoch_train, epoch_val)
    print(f"loss: {epoch_train} : {epoch_val}")
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break



In [None]:
class NNUE(nn.Module):
    """Small fully connected evaluator: input -> hidden -> hidden -> scalar.

    Both hidden layers are followed by a clipped ReLU (clamp to [0, 1]).
    Layer attribute names ``fc1``/``fc2``/``fc3`` are unchanged, so existing
    ``state_dict`` checkpoints still load with the default sizes.

    Args:
        input_size: width of the input feature vector (default 768; presumably
            12 piece types x 64 squares -- confirm against the feature encoder).
        hidden_size: width of both hidden layers (default 8).
    """

    def __init__(self, input_size=768, hidden_size=8):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def clipped_relu(self, x):
        """Clamp activations element-wise to the range [0, 1]."""
        return torch.clamp(x, 0, 1)

    def forward(self, x):
        """Map a ``(..., input_size)`` float tensor to a ``(..., 1)`` output.

        ``nn.Linear`` requires ``x`` to match its weight dtype (float32 by default).
        """
        x = self.clipped_relu(self.fc1(x))
        x = self.clipped_relu(self.fc2(x))
        return self.fc3(x)
    
class ChessDataset(Dataset):
    """Map-style dataset pairing per-position feature vectors with targets.

    Items are returned as ``(feature, target)`` tensors: the feature as
    ``torch.int`` (int32) and the target as ``torch.float32``.

    NOTE(review): ``NNUE.fc1`` is a float ``nn.Linear``, so int32 features
    must presumably be cast (e.g. ``.float()``) in the training loop -- confirm.
    """

    def __init__(self, features, targets):
        self.features, self.targets = features, targets

    def __len__(self):
        # The dataset length is defined by the targets.
        return len(self.targets)

    def __getitem__(self, idx):
        raw_feature, raw_target = self.features[idx], self.targets[idx]
        return (
            torch.tensor(raw_feature, dtype=torch.int),
            torch.tensor(raw_target, dtype=torch.float32),
        )

In [None]:
# path = f'./checkpoint/nnue_3_1978880d_512bs_300es_61e.pth'
# checkpoint = {
#     'epoch': epoch + 1,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'loss': loss
# }
# torch.save(checkpoint, path)

In [None]:
import chess
import chess.svg
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import time  # To see AI thinking time
from model_architecture.minimax import MinimaxNNUE

# Engine under test; the dropdown callback calls chessAI.find_best_move(board).
# NOTE(review): `depth` presumably sets the minimax search depth in plies --
# confirm against model_architecture.minimax.MinimaxNNUE.
chessAI = MinimaxNNUE(depth=3)

# Positions offered in the dropdown, written as (name, FEN) rows.
# Commented-out rows are currently disabled; uncomment a row to enable it.
_TEST_CASE_ROWS = [
    # --- Normal cases ---
    # NOTE(review): the "Opening Move" FEN has no white knight on g1
    # ("RNBQKB1R"); confirm this is intended before re-enabling it.
    # ("Opening Move", "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKB1R w KQkq - 0 1"),
    # ("Pawn Trade", "rnbqkbnr/pppp1ppp/5n2/5p2/5P2/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 2"),
    # ("Knight Development", "rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 2"),

    # --- Special cases ---
    # ("Knight Fork", "rnbqkb1r/pppp1ppp/5n2/5p2/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 0 2"),
    # ("K+P Endgame", "8/8/8/5k2/5p2/5P2/5K2/8 w - - 0 1"),

    # --- Advantage cases ---
    # ("Passed Pawn", "8/7k/8/5p2/5P2/5K2/8/8 w - - 0 1"),
    # ("Trapped King (Potential)", "rnbqkb1r/pppp1ppp/5n2/4p3/4P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 0 2"),

    # --- Pawn structure cases ---
    # ("Doubled Pawns (g file)", "7k/8/8/5pp1/5p1P/5P2/5K2/8 w - - 0 1"),
    # ("Blocked Pawns", "8/7k/8/5p2/5P2/5p2/5K2/8 w - - 0 1"),
    # ("Isolated Pawn (f file)", "8/7k/8/5p2/5P2/5K2/8/8 w - - 0 1"),
    # ("Combined Pawn Issues", "8/7k/8/4pp1p/5P1P/5K2/8/8 w - - 0 1"),

    # --- Checkmate / tactics cases ---
    ("Checkmate in 1 (Scholar's)", "rnbqkbnr/p1pppppp/8/8/1pB1P3/5Q2/PPPP1PPP/RNB1K1NR w KQkq - 0 4"),
    # NOTE(review): White (Qh5, Bc4) appears to have Qxf7# here, i.e. the
    # Scholar's-mate pattern; the "Fried Liver" label may be inaccurate.
    ("Trap (Fried Liver setup)", "r1bqkb1r/pppp1ppp/2n2n2/4p2Q/2B1P3/8/PPPP1PPP/RNB1K1NR w KQkq - 4 4"),
]

# One {"name": ..., "fen": ...} dict per enabled row, in row order.
test_cases = [{"name": name, "fen": fen} for name, fen in _TEST_CASE_ROWS]

def display_svg_board(board, size=300):  # 300px keeps two boards side by side
    """Return ``board`` rendered as SVG markup via ``chess.svg.board``.

    Nothing is displayed here; callers embed the returned markup in HTML.
    """
    svg_markup = chess.svg.board(board, size=size)
    return svg_markup

# Selector listing every enabled test case by name, plus the output widget
# that on_dropdown_change renders into.
test_case_names = [case["name"] for case in test_cases]
dropdown = widgets.Dropdown(
    options=test_case_names,
    description='Test Case:',
    disabled=False,
)
output_area = widgets.Output()

def _run_ai(board):
    """Ask ``chessAI`` for a move on ``board`` and report the wall-clock time.

    Returns whatever ``chessAI.find_best_move`` returns (a move, or a falsy
    value when it found none).
    """
    display(HTML("<p><i>Running AI...</i></p>"))
    start_time = time.perf_counter()  # monotonic clock, suited to durations
    best_move = chessAI.find_best_move(board)
    duration = time.perf_counter() - start_time
    display(HTML(f"<p><i>AI finished in {duration:.2f} seconds.</i></p>"))
    return best_move


def _show_move_result(initial_board, best_move):
    """Render the position before/after ``best_move`` plus any resulting status.

    An illegal move is reported instead of being applied.
    """
    if best_move not in initial_board.legal_moves:
        display(HTML(f"<b style='color:red;'>Error: AI proposed an illegal move: {best_move.uci()}</b>"))
        # Optionally display legal moves for debugging:
        # display(HTML(f"Legal moves were: {', '.join(m.uci() for m in initial_board.legal_moves)}"))
        return

    ai_board = initial_board.copy()  # keep the starting position untouched
    ai_board.push(best_move)

    # Initial and resulting boards side by side in one HTML table.
    board_display = f"""
    <table>
        <tr>
            <td style="text-align: center;"><b>Initial Board:</b><br>{display_svg_board(initial_board)}</td>
            <td style="text-align: center;"><b>AI Move:</b> {initial_board.san(best_move)} ({best_move.uci()})<br><b>Board After AI Move:</b><br>{display_svg_board(ai_board)}</td>
        </tr>
    </table>
    """
    display(HTML(board_display))
    display(HTML(f"<b>FEN After Move:</b> {ai_board.fen()}"))

    if ai_board.is_checkmate():
        display(HTML("<b style='color:red;'>CHECKMATE!</b>"))
    elif ai_board.is_stalemate():
        display(HTML("<b style='color:orange;'>STALEMATE!</b>"))
    elif ai_board.is_insufficient_material():
        display(HTML("<b style='color:grey;'>INSUFFICIENT MATERIAL!</b>"))
    elif ai_board.is_check():
        display(HTML("<b style='color:blue;'>CHECK!</b>"))


def on_dropdown_change(change):
    """Render the selected test case and the AI's reply into ``output_area``.

    Used as a ``dropdown.observe(..., names='value')`` callback; only
    ``change['new']`` (the selected test-case name) is read.
    """
    from html import escape  # exception text / names may contain '<', '&', ...

    with output_area:  # direct all output into the output widget
        clear_output(wait=True)  # replace previous output without flicker

        selected_name = change['new']
        selected_case = next((tc for tc in test_cases if tc["name"] == selected_name), None)
        if selected_case is None:
            print(f"Error: Test case '{selected_name}' not found.")
            return

        fen = selected_case["fen"]
        display(HTML(f"<h3>Test Case: {escape(str(selected_name))}</h3>"))
        display(HTML(f"<b>Initial FEN:</b> {escape(fen)}"))

        try:
            # Parsed inside the try so an invalid FEN is reported like other errors.
            initial_board = chess.Board(fen)
            if initial_board.is_game_over():
                # Nothing to search: report the result instead of running the engine.
                display(HTML(f"<b>Game is already over: {initial_board.result()}</b>"))
            else:
                best_move = _run_ai(initial_board)
                if best_move:
                    _show_move_result(initial_board, best_move)
                else:
                    # Game-over positions (checkmate, stalemate, ...) were handled
                    # above, so the board state does not explain a missing move.
                    display(HTML("<b style='color:orange;'>AI found no move (or decided not to move?).</b>"))
        except Exception as e:
            display(HTML(f"<b style='color:red;'>Error during AI execution: {escape(str(e))}</b>"))

        display(HTML("<hr>"))  # separator

# Re-render on every selection change, show the controls, then render the
# initially selected case once so the output area is populated immediately.
dropdown.observe(on_dropdown_change, names='value')
display(dropdown, output_area)
on_dropdown_change({'new': dropdown.value})