In [None]:
# # Updated train.py
# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader, Dataset
# import h5py
# import numpy as np
# from model import ChessNet
# import os

# class ChessDataset(Dataset):
#     def __init__(self, h5_path):
#         if not os.path.exists(h5_path):
#             raise FileNotFoundError(f"HDF5 file {h5_path} not found!")
            
#         with h5py.File(h5_path, 'r') as hf:
#             self.inputs = hf['inputs'][:]
#             self.policy = hf['policy'][:]
#             self.value = hf['value'][:]
            
#             if len(self.inputs) == 0:
#                 raise ValueError("Dataset is empty!")

#     def __len__(self):
#         return len(self.inputs)

#     def __getitem__(self, idx):
#         return (
#             torch.tensor(self.inputs[idx], dtype=torch.float32),
#             torch.tensor(self.policy[idx], dtype=torch.long),
#             torch.tensor(self.value[idx], dtype=torch.float32)
#         )

# def train_model(h5_path, num_epochs=10, batch_size=128):
#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
#     # Initialize model
#     model = ChessNet(num_blocks=6, channels=128).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
#     # Loss functions
#     policy_loss = nn.CrossEntropyLoss()
#     value_loss = nn.MSELoss()
    
#     # Data loading
#     try:
#         dataset = ChessDataset(h5_path)
#         print(f"Loaded dataset with {len(dataset)} samples")
#     except Exception as e:
#         print(f"Dataset error: {str(e)}")
#         return
    
#     if len(dataset) == 0:
#         print("Error: Dataset contains no samples!")
#         return
    
#     loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
#     # Training loop
#     for epoch in range(num_epochs):
#         total_loss = 0.0
#         for inputs, policies, values in loader:
#             inputs = inputs.to(device)
#             policies = policies.to(device)
#             values = values.to(device)
            
#             optimizer.zero_grad()
            
#             # Forward pass
#             policy_pred, value_pred = model(inputs)
            
#             # Calculate losses
#             p_loss = policy_loss(policy_pred, policies)
#             v_loss = value_loss(value_pred.squeeze(), values)
#             loss = p_loss + v_loss
            
#             # Backprop
#             loss.backward()
#             optimizer.step()
            
#             total_loss += loss.item()
        
#         print(f'Epoch {epoch+1}/{num_epochs} Loss: {total_loss/len(loader):.4f}')
    
#     # Save model
#     torch.save(model.state_dict(), 'chess_model.pth')

# if __name__ == '__main__':
#     # First convert PGN to HDF5
#     from pgn_converter import convert_pgn
    
#     try:
#         convert_pgn('master_games.pgn', 'chess_data.h5')
#     except Exception as e:
#         print(f"Conversion failed: {str(e)}")
#         exit(1)
    
#     # Then train the model
#     train_model('chess_data.h5')


Reading games...
Processing 25 games...
Conversion failed: name 'process_game' is not defined
Dataset error: Dataset is empty!


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import h5py
import numpy as np
import os
import chess
import chess.pgn
from tqdm.notebook import tqdm
import torch.nn.functional as F

# First, define the model
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        return F.relu(x)

class ChessNet(nn.Module):
    def __init__(self, num_blocks=6, channels=128):
        super().__init__()
        # Input block
        self.input_block = nn.Sequential(
            nn.Conv2d(18, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        
        # Residual tower
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(num_blocks)]
        )
        
        # Policy head
        self.policy_head = nn.Sequential(
            nn.Conv2d(channels, 2, 1),
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2*8*8, 4672)
        )
        
        # Value head
        self.value_head = nn.Sequential(
            nn.Conv2d(channels, 1, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8*8, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)  # NHWC to NCHW
        x = self.input_block(x)
        x = self.res_blocks(x)
        
        policy = self.policy_head(x)
        value = self.value_head(x)
        
        return policy, value

# PGN conversion functions
def board_to_input_planes(board):
    """
    Convert a chess board to input planes (18 channels, 8x8)
    6 pieces * 2 colors * (1 for piece positions + 1 for attacks) = 24 planes
    But we'll simplify to just 12 piece position planes + 6 attack planes
    """
    piece_types = [chess.PAWN, chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN, chess.KING]
    colors = [chess.WHITE, chess.BLACK]
    
    # Initialize input planes (18 planes, 8x8 board)
    planes = np.zeros((8, 8, 18), dtype=np.float32)
    
    # Piece positions (12 planes: 6 piece types * 2 colors)
    for i, piece_type in enumerate(piece_types):
        for j, color in enumerate(colors):
            for square in chess.SquareSet(board.pieces(piece_type, color)):
                row, col = divmod(square, 8)
                planes[row, col, i + j*6] = 1.0
    
    # Attacks (6 planes: 3 for white attacks, 3 for black attacks)
    # We'll use 3 levels of attack: 1 attacker, 2 attackers, 3+ attackers
    for row in range(8):
        for col in range(8):
            square = row * 8 + col
            
            # White attacks
            attackers = board.attackers(chess.WHITE, square)
            num_attackers = len(list(attackers))
            if num_attackers >= 1:
                planes[row, col, 12] = 1.0
            if num_attackers >= 2:
                planes[row, col, 13] = 1.0
            if num_attackers >= 3:
                planes[row, col, 14] = 1.0
                
            # Black attacks
            attackers = board.attackers(chess.BLACK, square)
            num_attackers = len(list(attackers))
            if num_attackers >= 1:
                planes[row, col, 15] = 1.0
            if num_attackers >= 2:
                planes[row, col, 16] = 1.0
            if num_attackers >= 3:
                planes[row, col, 17] = 1.0
    
    return planes

def move_to_policy_index(move):
    """
    Convert a move to a policy index (one-hot vector of size 4672)
    We use 8x8x73 encoding:
    - 56 queen moves (7 in each direction)
    - 8 knight moves
    - 9 underpromotions (3 piece types * 3 directions)
    Total: 73 moves per square = 73 * 64 = 4672 possible moves
    """
    from_square = move.from_square
    to_square = move.to_square
    
    from_rank, from_file = divmod(from_square, 8)
    to_rank, to_file = divmod(to_square, 8)
    
    # Calculate direction
    rank_diff = to_rank - from_rank
    file_diff = to_file - from_file
    
    # Queen moves (straight lines and diagonals)
    if (rank_diff == 0 or file_diff == 0 or abs(rank_diff) == abs(file_diff)):
        # Determine direction
        if rank_diff > 0 and file_diff == 0:  # North
            direction = 0
        elif rank_diff > 0 and file_diff > 0:  # Northeast
            direction = 1
        elif rank_diff == 0 and file_diff > 0:  # East
            direction = 2
        elif rank_diff < 0 and file_diff > 0:  # Southeast
            direction = 3
        elif rank_diff < 0 and file_diff == 0:  # South
            direction = 4
        elif rank_diff < 0 and file_diff < 0:  # Southwest
            direction = 5
        elif rank_diff == 0 and file_diff < 0:  # West
            direction = 6
        elif rank_diff > 0 and file_diff < 0:  # Northwest
            direction = 7
        
        # Distance (1-7)
        distance = max(abs(rank_diff), abs(file_diff))
        
        # Index within the move encoding
        move_index = direction * 7 + (distance - 1)
    
    # Knight moves
    elif (abs(rank_diff), abs(file_diff)) in [(1, 2), (2, 1)]:
        # Convert knight move to index (8 possible knight moves)
        if (rank_diff, file_diff) == (2, 1):
            knight_index = 0
        elif (rank_diff, file_diff) == (1, 2):
            knight_index = 1
        elif (rank_diff, file_diff) == (-1, 2):
            knight_index = 2
        elif (rank_diff, file_diff) == (-2, 1):
            knight_index = 3
        elif (rank_diff, file_diff) == (-2, -1):
            knight_index = 4
        elif (rank_diff, file_diff) == (-1, -2):
            knight_index = 5
        elif (rank_diff, file_diff) == (1, -2):
            knight_index = 6
        elif (rank_diff, file_diff) == (2, -1):
            knight_index = 7
        
        # Knight moves start after queen moves
        move_index = 56 + knight_index
    
    # Underpromotions (excluding queen promotion which is counted as a queen move)
    elif move.promotion is not None and move.promotion != chess.QUEEN:
        # Direction (straight, left diagonal, right diagonal)
        if file_diff == 0:
            direction = 0
        elif file_diff < 0:
            direction = 1
        else:
            direction = 2
        
        # Piece type (knight, bishop, rook) - 1 (to get 0, 1, 2)
        piece_idx = [chess.KNIGHT, chess.BISHOP, chess.ROOK].index(move.promotion)
        
        # Underpromotions start after queen and knight moves
        move_index = 56 + 8 + piece_idx * 3 + direction
    
    # Any move we can't categorize (shouldn't happen with legal moves)
    else:
        # Default to the first move
        move_index = 0
    
    # Final policy index: 73 move types per square
    return from_square * 73 + move_index

def process_game(game):
    """Process a game, returning (input_planes, policy, value) for training positions"""
    inputs = []
    policies = []
    values = []
    
    board = game.board()
    result = game.headers.get("Result", "*")
    
    # Determine game outcome
    if result == "1-0":
        white_win = 1.0
    elif result == "0-1":
        white_win = -1.0
    else:  # Draw or unfinished
        white_win = 0.0
    
    for move_num, move in enumerate(game.mainline_moves()):
        # Skip first 5 moves (10 half-moves) to focus on middlegame positions
        if move_num < 10:
            board.push(move)
            continue
            
        # Extract features before the move is made
        input_planes = board_to_input_planes(board)
        
        # Extract policy (the actual move that was played)
        policy_index = move_to_policy_index(move)

        # Value: perspective of the player to move
        value = white_win if board.turn == chess.WHITE else -white_win
        
        # Store the training example
        inputs.append(input_planes)
        policies.append(policy_index)
        values.append(value)
        
        # Make the move on the board
        board.push(move)
        
        # Limit samples per game
        if len(inputs) >= 10:
            break
    
    return inputs, policies, values

def convert_pgn(pgn_file, output_h5):
    """Convert PGN file to HDF5 format for training"""
    if not os.path.exists(pgn_file):
        raise FileNotFoundError(f"PGN file {pgn_file} not found!")
    
    print(f"Reading games...")
    with open(pgn_file) as f:
        # Count games first (to show progress)
        game_count = 0
        pos = f.tell()
        while True:
            game = chess.pgn.read_game(f)
            if game is None:
                break
            game_count += 1
            
        # Reset file position
        f.seek(pos)
        print(f"Processing {game_count} games...")
        
        all_inputs = []
        all_policies = []
        all_values = []
        
        # Process each game
        for _ in tqdm(range(game_count)):
            game = chess.pgn.read_game(f)
            if game is None:
                break
                
            try:
                inputs, policies, values = process_game(game)
                all_inputs.extend(inputs)
                all_policies.extend(policies)
                all_values.extend(values) 
            except Exception as e:
                print(f"Error processing game: {str(e)}")
                continue
    
    # Convert to numpy arrays
    all_inputs = np.array(all_inputs, dtype=np.float32)
    all_policies = np.array(all_policies, dtype=np.int64)
    all_values = np.array(all_values, dtype=np.float32)
    
    print(f"Total positions: {len(all_inputs)}")
    
    # Save to HDF5
    with h5py.File(output_h5, 'w') as hf:
        hf.create_dataset('inputs', data=all_inputs)
        hf.create_dataset('policy', data=all_policies)
        hf.create_dataset('value', data=all_values)
    
    print(f"Data saved to {output_h5}")

# Dataset class for training
class ChessDataset(Dataset):
    def __init__(self, h5_path):
        if not os.path.exists(h5_path):
            raise FileNotFoundError(f"HDF5 file {h5_path} not found!")
            
        with h5py.File(h5_path, 'r') as hf:
            self.inputs = hf['inputs'][:]
            self.policy = hf['policy'][:]
            self.value = hf['value'][:]
            
            if len(self.inputs) == 0:
                raise ValueError("Dataset is empty!")

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return (
            torch.tensor(self.inputs[idx], dtype=torch.float32),
            torch.tensor(self.policy[idx], dtype=torch.long),
            torch.tensor(self.value[idx], dtype=torch.float32)
        )

def train_model(h5_path, model_save_path='chess_model.pth', num_epochs=10, batch_size=128, learning_rate=0.001):
    import torch.nn.functional as F  # Import here to ensure it's available
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Initialize model
    model = ChessNet(num_blocks=6, channels=128).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Loss functions
    policy_loss = nn.CrossEntropyLoss()
    value_loss = nn.MSELoss()
    
    # Data loading
    try:
        dataset = ChessDataset(h5_path)
        print(f"Loaded dataset with {len(dataset)} samples")
    except Exception as e:
        print(f"Dataset error: {str(e)}")
        return
    
    if len(dataset) == 0:
        print("Error: Dataset contains no samples!")
        return
    
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        policy_accuracy = 0.0
        value_mse = 0.0
        
        for inputs, policies, values in tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs = inputs.to(device)
            policies = policies.to(device)
            values = values.to(device)
            
            optimizer.zero_grad()
            
            # Forward pass
            policy_pred, value_pred = model(inputs)
            
            # Calculate losses
            p_loss = policy_loss(policy_pred, policies)
            v_loss = value_loss(value_pred.squeeze(), values)
            loss = p_loss + v_loss
            
            # Backprop
            loss.backward()
            optimizer.step()
            
            # Track metrics
            total_loss += loss.item()
            
            # Policy accuracy (top-1)
            _, predicted = torch.max(policy_pred, 1)
            policy_accuracy += (predicted == policies).sum().item() / len(policies)
            
            # Value MSE
            value_mse += v_loss.item()
        
        # Calculate epoch metrics
        avg_loss = total_loss / len(loader)
        avg_policy_acc = policy_accuracy / len(loader)
        avg_value_mse = value_mse / len(loader)
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Loss: {avg_loss:.4f}')
        print(f'  Policy Accuracy: {avg_policy_acc:.4f}')
        print(f'  Value MSE: {avg_value_mse:.4f}')
    
    # Save model
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

# Configuration variables (modify these as needed)
pgn_file = 'master_games.pgn'  # Path to your PGN file
h5_file = 'chess_data.h5'      # Path to save/load HDF5 file
model_file = 'chess_model.pth' # Path to save model
num_epochs = 10               # Number of epochs to train
batch_size = 128              # Batch size for training
learning_rate = 0.001         # Learning rate

# Run the conversion (uncomment to run)
# convert_pgn(pgn_file, h5_file)

# Train the model (uncomment to run)
train_model(h5_file, model_file, num_epochs, batch_size, learning_rate)

Using device: cuda
Loaded dataset with 247 samples


Epoch 1/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 1/10:
  Loss: 9.4915
  Policy Accuracy: 0.0000
  Value MSE: 1.0789


Epoch 2/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 2/10:
  Loss: 8.9740
  Policy Accuracy: 0.0120
  Value MSE: 1.0008


Epoch 3/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 3/10:
  Loss: 8.5641
  Policy Accuracy: 0.0573
  Value MSE: 0.9927


Epoch 4/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 4/10:
  Loss: 8.2078
  Policy Accuracy: 0.1429
  Value MSE: 0.9905


Epoch 5/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 5/10:
  Loss: 7.8487
  Policy Accuracy: 0.3222
  Value MSE: 0.9845


Epoch 6/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 6/10:
  Loss: 7.4369
  Policy Accuracy: 0.4649
  Value MSE: 0.9850


Epoch 7/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 7/10:
  Loss: 7.0095
  Policy Accuracy: 0.6312
  Value MSE: 0.9603


Epoch 8/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 8/10:
  Loss: 6.6224
  Policy Accuracy: 0.7439
  Value MSE: 0.9403


Epoch 9/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 9/10:
  Loss: 6.2432
  Policy Accuracy: 0.7868
  Value MSE: 0.9106


Epoch 10/10:   0%|          | 0/2 [00:00<?, ?it/s]

Epoch 10/10:
  Loss: 5.8657
  Policy Accuracy: 0.8496
  Value MSE: 0.8605
Model saved to chess_model.pth
