# NNUE for Chess AI (Sparse Feature Version)

This notebook implements an **Efficiently Updatable Neural Network (NNUE)** to evaluate chess positions using a sparse feature representation. The model is trained on a dataset of chess positions in FEN format and their corresponding evaluations, leveraging PyTorch for neural network operations. The goal is to provide a clean, readable, and well-documented implementation with an interactive UI for parameter adjustment.

Reference documentation: [NNUE PyTorch Docs](https://github.com/official-stockfish/nnue-pytorch/blob/master/docs/nnue.md)

## 1. Setup

In [4]:
# %pip install chess

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split
import chess
from torch.utils.tensorboard import SummaryWriter
import ipywidgets as widgets
from IPython.display import display
import time
import os
# Set random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x1f77e656310>

## 3. Data Loading

We load the chess dataset from a CSV file containing FEN strings and their evaluations.

In [9]:
# Load the CSV data (adjust the path as needed)
df = pd.read_csv('../assets/chess-data/fen/train.csv')
print('Original shape:', df.shape)

# Limit the dataset size for faster experimentation (optional)
df = df[:10]
print('Used shape:', df.shape)

# Extract FENs and evaluations
fens = df['FEN'].tolist()
targets = df['Evaluation'].tolist()

Original shape: (1979383, 2)
Used shape: (10, 2)


## 4. Data Preprocessing

We convert FEN strings into sparse feature indices representing piece positions for white and black, along with the side to move.

In [10]:
def get_feature_index(color, piece_type, square):
    """Calculate a unique feature index for a piece on a square.
    
    Args:
        color (int): 0 for white, 1 for black.
        piece_type (int): 0 (pawn) to 5 (king).
        square (int): 0 to 63 (board square index).
    
    Returns:
        int: Feature index (0-767).
    """
    return (color * 6 + piece_type) * 64 + square

def preprocess_position(fen):
    """Convert a FEN string to sparse feature indices and side-to-move indicator.
    
    Args:
        fen (str): FEN string representing a chess position.
    
    Returns:
        tuple: (white_features, black_features, side_to_move)
            - white_features: List of feature indices for white pieces (0-383).
            - black_features: List of feature indices for black pieces (384-767).
            - side_to_move: 1 if white to move, 0 if black to move.
    """
    board = chess.Board(fen)
    features = []
    
    # Extract features for each piece on the board
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            color = 0 if piece.color == chess.WHITE else 1
            piece_type = piece.piece_type - 1  # Map 1-6 (pawn-king) to 0-5
            feature_index = get_feature_index(color, piece_type, square)
            features.append(feature_index)
    
    # Split features by color
    features_white = [f for f in features if f < 384]  # White pieces: indices 0-383
    features_black = [f for f in features if f >= 384]  # Black pieces: indices 384-767
    
    # Side to move indicator
    side_to_move = 1 if board.turn == chess.WHITE else 0
    
    return features_white, features_black, side_to_move

### Explanation
- **Feature Indexing**: Each piece on a square is assigned a unique index from 0 to 767. White pieces occupy 0-383, and black pieces occupy 384-767.
- **Preprocessing**: The `preprocess_position` function parses a FEN string, identifies pieces on the board, computes their feature indices, and separates them by color. It also determines the side to move.

## 5. Dataset and DataLoader

We define a custom dataset class and a collate function to handle variable-length feature lists in batches.

In [38]:
class ChessDataset(Dataset):
    """Custom dataset for chess positions and evaluations."""
    def __init__(self, fens, targets):
        self.fens = fens
        self.targets = targets
    
    def __len__(self):
        return len(self.fens)
    
    def __getitem__(self, idx):
        fen = self.fens[idx]
        target = self.targets[idx]
        fw, fb, stm = preprocess_position(fen)
        return fw, fb, stm, target

def collate_fn(batch):
    """Collate function to process batches with variable-length feature lists.
    
    Args:
        batch: List of (features_white, features_black, side_to_move, target) tuples.
    
    Returns:
        tuple: Tensors for white features, white offsets, black features, black offsets,
               side to move, and targets.
    """
    features_white, features_black, side_to_move, targets = [], [], [], []
    for fw, fb, stm, t in batch:
        features_white.extend(fw)
        features_black.extend(fb)
        side_to_move.append(stm)
        targets.append(t)
    
    # Convert lists to tensors
    fw_tensor = torch.tensor(features_white, dtype=torch.long)
    fb_tensor = torch.tensor(features_black, dtype=torch.long)
    offsets_white = torch.tensor([0] + [len(fw) for fw, _, _, _ in batch[:-1]], dtype=torch.long).cumsum(0)
    offsets_black = torch.tensor([0] + [len(fb) for _, fb, _, _ in batch[:-1]], dtype=torch.long).cumsum(0)
    stm_tensor = torch.tensor(side_to_move, dtype=torch.long)
    target_tensor = torch.tensor(targets, dtype=torch.float)

    print(batch)
    return fw_tensor, offsets_white, fb_tensor, offsets_black, stm_tensor, target_tensor

### Explanation
- **ChessDataset**: A custom PyTorch dataset that returns preprocessed features (white and black indices, side to move) and the evaluation target for each FEN string.
- **collate_fn**: Combines variable-length feature lists into tensors suitable for batch processing. Offsets track the start of each position’s features within the concatenated tensors.

## 6. Data Splitting and Loading

In [39]:
# Extract parameter values from sliders
BATCH_SIZE = 3
NUM_EPOCHS = 500

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(
    fens, targets, test_size=0.2, random_state=42
)
# Create datasets
train_dataset = ChessDataset(X_train, y_train)
test_dataset = ChessDataset(X_test, y_test)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=False)

### Explanation
- **Splitting**: The dataset is divided into 80% training and 20% testing sets with a fixed random seed for reproducibility.
- **DataLoaders**: Batches are created with the size specified by the UI slider. The training loader shuffles data for better generalization.

## 7. Model Definition

We define the NNUE model using an `EmbeddingBag` layer for sparse features and fully connected layers for evaluation.

In [10]:
class NNUE(nn.Module):
    """Efficiently Updatable Neural Network for chess position evaluation.
    
    Args:
        num_features (int): Total number of possible features (768).
        hidden_size (int): Size of the embedding output (256).
        hidden_size2 (int): Size of the first hidden layer (32).
        hidden_size3 (int): Size of the second hidden layer (32).
    """
    def __init__(self, num_features=768, hidden_size=256, hidden_size2=32, hidden_size3=32):
        super(NNUE, self).__init__()
        self.embedding = nn.EmbeddingBag(num_features, hidden_size, mode='sum')
        self.fc1 = nn.Linear(2 * hidden_size, hidden_size2)
        self.fc2 = nn.Linear(hidden_size2, hidden_size3)
        self.fc3 = nn.Linear(hidden_size3, 1)
    
    def forward(self, features_white, offsets_white, features_black, offsets_black, side_to_move):
        """Forward pass to compute the evaluation score.
        
        Args:
            features_white (Tensor): Indices of white piece features.
            offsets_white (Tensor): Offsets for white features in the batch.
            features_black (Tensor): Indices of black piece features.
            offsets_black (Tensor): Offsets for black features in the batch.
            side_to_move (Tensor): 1 if white to move, 0 if black to move.
        
        Returns:
            Tensor: Predicted evaluation score for each position in the batch.
        """
        # Sum embeddings for white and black pieces
        white_sum = self.embedding(features_white, offsets_white)  # Shape: (batch_size, hidden_size)
        black_sum = self.embedding(features_black, offsets_black)  # Shape: (batch_size, hidden_size)
        
        # Assign "us" and "them" based on side to move
        side_to_move = side_to_move.bool()
        us_sum = torch.where(side_to_move[:, None], white_sum, black_sum)
        them_sum = torch.where(side_to_move[:, None], black_sum, white_sum)
        
        # Concatenate features and pass through fully connected layers
        input_vector = torch.cat([us_sum, them_sum], dim=1)  # Shape: (batch_size, 2 * hidden_size)
        x = torch.clamp(self.fc1(input_vector), 0, 1)  # Clipped ReLU: [0, 1]
        x = torch.clamp(self.fc2(x), 0, 1)
        output = self.fc3(x)  # Linear output
        return output

### Explanation
- **EmbeddingBag**: Efficiently computes the sum of embeddings for active features (piece positions) per side.
- **Forward Pass**: Combines white and black embeddings based on the side to move, concatenates them, and processes them through three fully connected layers with clipped ReLU activations (except the final layer).

## 8. Training Setup

In [None]:
# Initialize the model, optimizer, loss function, and TensorBoard writer
model = NNUE()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
writer = SummaryWriter('runs/nnue_2_experiment')
os.makedirs('checkpoint',exist_ok=True)

## 9. Training Loop and Evaluation

In [None]:
# Training loop with validation and parameter monitoring
best_val_loss = float('inf')  # For tracking the best validation loss
patience = 10  # For early stopping
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    
    # Training phase
    model.train()  # Set model to training mode
    train_loss = 0.0
    for batch in train_loader:
        fw, ow, fb, ob, stm, targets = batch
        optimizer.zero_grad()  # Clear previous gradients
        outputs = model(fw, ow, fb, ob, stm)  # Forward pass
        loss = criterion(outputs.squeeze(), targets)  # Compute loss
        loss.backward()  # Backpropagation
        
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()  # Update weights
        train_loss += loss.item()  # Accumulate loss
    
    # Compute and log average training loss
    avg_train_loss = train_loss / len(train_loader)
    writer.add_scalar('Loss/train', avg_train_loss, epoch)
    
    # Validation phase
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:  # Assuming test_loader is your validation set
            fw, ow, fb, ob, stm, targets = batch
            outputs = model(fw, ow, fb, ob, stm)
            loss = criterion(outputs.squeeze(), targets)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(test_loader)
    writer.add_scalar('Loss/val', avg_val_loss, epoch)
    
    # Log model parameters and gradients (every 5 epochs)
    if (epoch + 1) % 5 == 0:
        for name, param in model.named_parameters():
            writer.add_histogram(name, param, epoch)
            if param.grad is not None:
                writer.add_histogram(f"{name}/grad", param.grad, epoch)
    
    # Save model checkpoint if validation loss improves
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), './checkpoint/nnue_2_512bs_2000epochs.pth')
        patience_counter = 0  # Reset patience
    else:
        patience_counter += 1
    
    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break
    
    end_time = time.time()
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Time: {round(end_time - start_time)} s")

writer.close()

In [3]:
import torch
import torch.nn as nn
import chess

# Define the NNUE model class (same as in your notebook)
class NNUE(nn.Module):
    def __init__(self, num_features=768, hidden_size=256, hidden_size2=32, hidden_size3=32):
        super(NNUE, self).__init__()
        self.embedding = nn.EmbeddingBag(num_features, hidden_size, mode='sum')
        self.fc1 = nn.Linear(2 * hidden_size, hidden_size2)
        self.fc2 = nn.Linear(hidden_size2, hidden_size3)
        self.fc3 = nn.Linear(hidden_size3, 1)
    
    def forward(self, features_white, offsets_white, features_black, offsets_black, side_to_move):
        white_sum = self.embedding(features_white, offsets_white)
        black_sum = self.embedding(features_black, offsets_black)
        
        side_to_move = side_to_move.bool()
        us_sum = torch.where(side_to_move[:, None], white_sum, black_sum)
        them_sum = torch.where(side_to_move[:, None], black_sum, white_sum)
        
        input_vector = torch.cat([us_sum, them_sum], dim=1)
        x = torch.clamp(self.fc1(input_vector), 0, 1)
        x = torch.clamp(self.fc2(x), 0, 1)
        output = self.fc3(x)
        return output

# Feature index calculation
def get_feature_index(color, piece_type, square):
    return (color * 6 + piece_type) * 64 + square

# Preprocess FEN to sparse features
def preprocess_position(fen):
    board = chess.Board(fen)
    features = []
    
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece is not None:
            color = 0 if piece.color == chess.WHITE else 1
            piece_type = piece.piece_type - 1  # Map 1-6 (pawn-king) to 0-5
            feature_index = get_feature_index(color, piece_type, square)
            features.append(feature_index)
    
    features_white = [f for f in features if f < 384]
    features_black = [f for f in features if f >= 384]
    side_to_move = 1 if board.turn == chess.WHITE else 0
    
    return features_white, features_black, side_to_move

# Load the saved model
model = NNUE()
model.load_state_dict(torch.load('./checkpoint/nnue_2_512bs_2000epochs.pth', weights_only=True))
model.eval()

# Example FEN (starting position)
example_fen = "rnbqkbnr/pppppppp/5n2/8/8/5N2/PPPPPPPP/RNBQKB1R w KQkq - 0 1"
print(f"Evaluating position: {example_fen}")

# Preprocess the FEN
features_white, features_black, side_to_move = preprocess_position(example_fen)
def collate_fn(batch):
    """Collate function to process batches with variable-length feature lists.
    
    Args:
        batch: List of (features_white, features_black, side_to_move, target) tuples.
    
    Returns:
        tuple: Tensors for white features, white offsets, black features, black offsets,
               side to move, and targets.
    """
    features_white, features_black, side_to_move, targets = [], [], [], []
    for fw, fb, stm, t in batch:
        features_white.extend(fw)
        features_black.extend(fb)
        side_to_move.append(stm)
        targets.append(t)
    
    # Convert lists to tensors
    fw_tensor = torch.tensor(features_white, dtype=torch.long)
    fb_tensor = torch.tensor(features_black, dtype=torch.long)
    offsets_white = torch.tensor([0] + [len(fw) for fw, _, _, _ in batch[:-1]], dtype=torch.long).cumsum(0)
    offsets_black = torch.tensor([0] + [len(fb) for _, fb, _, _ in batch[:-1]], dtype=torch.long).cumsum(0)
    stm_tensor = torch.tensor(side_to_move, dtype=torch.long)
    target_tensor = torch.tensor(targets, dtype=torch.float)
    
    return fw_tensor, offsets_white, fb_tensor, offsets_black, stm_tensor, target_tensor

# Convert to tensors (batch size of 1)
fw_tensor = torch.tensor(features_white, dtype=torch.long).unsqueeze(0)  # Add batch dimension
fb_tensor = torch.tensor(features_black, dtype=torch.long).unsqueeze(0)
offsets_white = torch.tensor([0]+[17], dtype=torch.long)  # Single position, offset is 0
offsets_black = torch.tensor([0]+[18], dtype=torch.long)
stm_tensor = torch.tensor([side_to_move], dtype=torch.long)

# Run the model
with torch.no_grad():
    output = model(fw_tensor, offsets_white, fb_tensor, offsets_black, stm_tensor)

# Print the result
evaluation = output.item()
print(f"Model evaluation: {evaluation:.4f}")

# Interpret the result
if evaluation > 0:
    print("Positive score indicates an advantage for White.")
elif evaluation < 0:
    print("Negative score indicates an advantage for Black.")
else:
    print("Score of 0 indicates a balanced position.")


Evaluating position: rnbqkbnr/pppppppp/5n2/8/8/5N2/PPPPPPPP/RNBQKB1R w KQkq - 0 1


ValueError: if input is 2D, then offsets has to be None, as input is treated is a mini-batch of fixed length sequences. However, found offsets of type <class 'torch.Tensor'>

## 10. TensorBoard Visualization

You can visualize the training progress using TensorBoard.

In [15]:
# # Load TensorBoard extension and display logs
# %load_ext tensorboard
# %tensorboard --logdir runs --port 8001