In [1]:
import chess
import RL_utils

In [2]:
# Checkpoint of our existing CNN model. It is loaded later in the notebook as the
# frozen reference network for self-play blending and the KL term.
ref_model_path = "models/TORCH_250EPOCH_DoubleHead.pth"

## Load Stockfish

In [3]:
from RL_utils import PositionEvaluator

In [4]:
# Path to the Stockfish executable. The file name indicates a Windows x86-64 AVX2
# build, so this path needs changing on other platforms.
stockfish_path = "models/stockfish/stockfish-windows-x86-64-avx2.exe"
# Wrap the engine in RL_utils.PositionEvaluator with a requested ELO of 1400
# (presumably a strength limit -- confirm in RL_utils).
evaluator = PositionEvaluator(stockfish_path, elo_rating=1400)

Stockfish initialized with ELO: 1400


In [5]:
# Build a small sample position from the standard start: 1. e4 h6.
board = chess.Board()
board.push_uci("e2e4")
# Last expression of the cell, so the notebook displays the returned Move object.
board.push_uci("h7h6")

Move.from_uci('h7h6')

In [6]:
# Ask the Stockfish-backed evaluator to score the sample position.
# NOTE(review): units/perspective of the score are defined in RL_utils -- confirm there.
eval_score = evaluator.evaluate_position(board)
print(f"Position evaluation: {eval_score}")

Position evaluation: 0.83


In [7]:
# Query the evaluator for its top candidate moves in the sample position
# (second argument 3 -- presumably the number of moves to return).
best_moves = evaluator.get_best_moves(board, 3)
print(f"Best moves: {best_moves}")

Best moves: [{'Move': 'd2d4', 'Centipawn': 83, 'Mate': None}, {'Move': 'g1f3', 'Centipawn': 71, 'Mate': None}, {'Move': 'b1c3', 'Centipawn': 68, 'Mate': None}]


## Load our trained CNN model

In [None]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model = reference_model.to(device)

In [None]:
# import torch
# dummy_input = torch.randn(1, 19, 8, 8).to(device)
# with torch.no_grad():
#     output = model(dummy_input)
#     print(f"Model output shape: {output.shape}")  # Should be [1, 4288]
#     print("Model loaded successfully!")

AttributeError: 'tuple' object has no attribute 'shape'

## Load the training data

In [None]:
# PGN file of games used as the source of starting positions for self-play.
pgn_file = "games/lichess_db_2016-04.pgn"
# `positions` is the pool that the self-play generators sample start boards from.
# It is built with `+=` so further sources can be appended the same way.
positions = []
positions += RL_utils.extract_varied_positions(pgn_file, num_positions=1000)
print(f"Total positions loaded: {len(positions)}")

Extracted 1000 middle game positions
Total positions loaded: 1000


In [4]:
# Test creating a batch
# if positions:
#     board_tensors, legal_masks, boards = RL_utils.create_training_batch(positions, batch_size=4)
#     print(f"Batch shapes:")
#     print(f"  Board tensors: {board_tensors.shape}")
#     print(f"  Legal masks: {legal_masks.shape}")
#     print(f"  Number of boards: {len(boards)}")
    
#     # Show a sample position
#     print(f"\nSample position FEN: {boards[0].fen()}")

In [5]:
# boards[0]
# eval_score = evaluator.evaluate_position(boards[0])
# print(f"Position evaluation: {eval_score}")

## Training Loop

In [4]:
import random
import tqdm
import torch
import numpy as np
import torch.nn.functional as F

In [None]:
# Training hyperparameters
LEARNING_RATE = 0.001  # learning rate handed to the optimizer
BATCH_SIZE = 64  # number of examples per mini-batch in the training functions
EPOCHS = 5  # passes over the collected examples per training call

# Self-play
NUM_SELF_PLAY_GAMES = 100  # games generated per self-play round
MAX_GAME_MOVES = 200  # cap on moves played in a single self-play game
TEMPERATURE = 1.0  # temperature passed to RL_utils.play_game_with_mcts

# Model saving
MODEL_SAVE_PATH = "models/dual_head_model.pth"  # where the trained state dict is written

In [6]:
def generate_self_play_data(model, device, start_positions, num_games=NUM_SELF_PLAY_GAMES):
    """Generate self-play games with MCTS, tolerating failures of individual games.

    Args:
        model: network handed to ``RL_utils.play_game_with_mcts``.
        device: device handed to ``RL_utils.play_game_with_mcts``.
        start_positions: pool of starting boards; one is drawn at random per game.
            Assumed to be ``chess.Board`` objects (they are copied before play).
        num_games: number of games to attempt.

    Returns:
        A list of ``(game_history, result)`` tuples, one per game that produced a
        non-empty history. Empty when ``start_positions`` is empty.
    """
    data = []

    # Without this guard random.choice raises IndexError on every attempt and the
    # broad handler below would just print the same error num_games times.
    if len(start_positions) == 0:
        print("No start positions available; skipping self-play generation")
        return data

    successful_games = 0

    for i in range(num_games):
        try:
            # Play on a copy so the shared start position is never mutated and can be
            # drawn again, unchanged, by a later game or iteration.
            board = random.choice(start_positions).copy()
            game_history, result = RL_utils.play_game_with_mcts(model, device, board, MAX_GAME_MOVES, TEMPERATURE)

            if len(game_history) > 0:  # Only keep games that produced training data
                data.append((game_history, result))
                successful_games += 1

        except Exception as e:
            # Deliberate best-effort: one broken game must not abort the whole round.
            print(f"Error in game {i}: {e}")
            continue

    print(f"Generated {successful_games} successful games out of {num_games} attempts")
    return data

In [None]:
def train_on_self_play(model, optimizer, game_histories, device):
    """Train the model on self-play games, with gradient clipping.

    Args:
        model: network called as ``model(boards)``; must return
            ``(policy_logits, value_preds)``.
        optimizer: optimizer stepping ``model``'s parameters.
        game_histories: iterable of ``(history, result)`` pairs, where each history
            entry is ``(board_tensor, legal_mask, move_idx, turn)``. Board tensors are
            concatenated along dim 0, so each is expected to carry a leading batch
            dimension; legal masks are stacked.
        device: device the batch tensors are moved to.

    Returns:
        None. Progress is reported via ``print``.
    """
    model.train()

    # Flatten every game into (board, mask, move, value) examples. The game result is
    # negated when ``turn`` is falsy so the value is from the mover's perspective.
    all_examples = [
        (board_tensor, legal_mask, move_idx, result if turn else -result)
        for history, result in game_histories
        for board_tensor, legal_mask, move_idx, turn in history
    ]

    if not all_examples:
        print("No training examples available!")
        return

    print(f"Training on {len(all_examples)} examples")

    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        epoch_losses = []

        # Slices taken with this range are never empty, so no empty-batch check is needed.
        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]

            try:
                # Prepare batch
                boards = torch.cat([ex[0] for ex in batch]).to(device)
                masks = torch.stack([ex[1] for ex in batch]).to(device)
                move_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.long).to(device)
                value_targets = torch.tensor([ex[3] for ex in batch], dtype=torch.float).to(device)

                # Forward pass
                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)

                # Only the combined loss is used here; the per-head losses are ignored.
                loss, _, _ = RL_utils.compute_loss(policy_logits, value_preds, move_targets, value_targets, masks)

                # Skip NaN *and* +/-inf losses: an infinite loss yields infinite
                # gradients, which clipping turns into NaN and the optimizer step then
                # writes into the weights.
                if not torch.isfinite(loss):
                    print("Skipping batch due to non-finite loss")
                    continue

                # Backward pass
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                epoch_losses.append(loss.item())

            except Exception as e:
                # Deliberate best-effort: report the bad batch and keep training.
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue

        if epoch_losses:
            avg_loss = np.mean(epoch_losses)
            print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}")
        else:
            print(f"Epoch {epoch+1}: No valid batches")


In [None]:
import matplotlib.pyplot as plt

def train_on_mcts(reference_model, model, optimizer, game_histories, device):
    """Train the model on MCTS self-play data, tracking KL divergence to a reference model.

    For every batch the KL divergence between the reference model's policy and the
    trained model's policy is computed and handed, together with the current epoch
    index, to ``RL_utils.compute_loss_mcts``. The per-epoch mean KL is plotted at the end.

    Args:
        reference_model: network called as ``reference_model(boards)`` under
            ``torch.no_grad()``; only its policy logits are used. Its train/eval mode
            is not changed here, so the caller should set it.
        model: network being trained; must return ``(policy_logits, value_preds)``.
        optimizer: optimizer stepping ``model``'s parameters.
        game_histories: iterable of ``(history, result)`` pairs, where each history
            entry is ``(board_tensor, policy_vector, turn)``. Board tensors are
            concatenated along dim 0, so each is expected to carry a leading batch
            dimension.
        device: device the batch tensors are moved to.

    Returns:
        None. Progress is printed and a KL-divergence plot is shown.
    """
    model.train()

    # Flatten every game into (board, policy target, value) examples. The game result
    # is negated when ``turn`` is falsy so the value is from the mover's perspective.
    all_examples = []
    for history, result in game_histories:
        for board_tensor, policy_vector, turn in history:
            value = result if turn else -result
            policy_tensor = torch.tensor(policy_vector, dtype=torch.float32)
            all_examples.append((board_tensor, policy_tensor, value))

    if not all_examples:
        print("No training examples available!")
        return

    print(f"Training on {len(all_examples)} examples")

    kl_history = []
    epoch_losses = []

    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        batch_kl = []
        batch_losses = []

        # Slices taken with this range are never empty, so no empty-batch check is needed.
        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]

            try:
                # Prepare batch
                boards = torch.cat([ex[0] for ex in batch]).to(device)
                policy_targets = torch.stack([ex[1] for ex in batch]).to(device)
                value_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.float).to(device)

                # Forward pass
                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)
                with torch.no_grad():
                    ref_policy_logits, _ = reference_model(boards)
                    ref_policy_probs = F.softmax(ref_policy_logits, dim=1)

                # F.kl_div takes log-probabilities as input and probabilities as target.
                policy_log_probs = F.log_softmax(policy_logits, dim=1)
                kl_div = F.kl_div(policy_log_probs, ref_policy_probs, reduction='batchmean')

                # Only the combined loss is used here; the per-head losses are ignored.
                loss, _, _ = RL_utils.compute_loss_mcts(policy_logits, value_preds, policy_targets, value_targets, kl_div, epoch)

                # Skip NaN *and* +/-inf losses: an infinite loss yields infinite
                # gradients, which clipping turns into NaN and the optimizer step then
                # writes into the weights.
                if not torch.isfinite(loss):
                    print("Skipping batch due to non-finite loss")
                    continue

                # Backward pass
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                # Loss and KL are always recorded together, so both lists stay in step.
                batch_losses.append(loss.item())
                batch_kl.append(kl_div.item())

            except Exception as e:
                # Deliberate best-effort: report the bad batch and keep training.
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue

        if batch_losses:
            avg_loss = np.mean(batch_losses)
            avg_kl = np.mean(batch_kl)
            epoch_losses.append(avg_loss)
            kl_history.append(avg_kl)
            print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}, avg KL={avg_kl:.4f}")
        else:
            print(f"Epoch {epoch+1}: No valid batches")

    # Plot KL-divergence history (nothing to draw if every batch was skipped).
    if kl_history:
        plt.figure()
        plt.plot(kl_history, label="KL-divergence")
        plt.xlabel("Epoch")
        plt.ylabel("KL-divergence")
        plt.title("KL-divergence vs Epoch")
        plt.legend()
        plt.show()

In [None]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model.requires_grad_(False)
# model, _ = RL_utils.load_resnet_model()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

# # Main training loop
# for iteration in range(5):  # 5 outer iterations of self-play + training
#     print(f"\n=== Iteration {iteration+1} ===")
    
#     # Generate self-play games
#     self_play_data = generate_self_play_data(model, device, positions)
    
#     # Train model on self-play games
#     train_on_mcts(reference_model, model, optimizer, self_play_data, device)

#     # Save model
#     torch.save({
#         'model_state_dict': model.state_dict(),
#     }, MODEL_SAVE_PATH)

Model loaded from models/TORCH_250EPOCH_DoubleHead.pth

=== Iteration 1 ===


# Updated code with new-style training

In [None]:
from mcts import SimpleMCTS
from chess_utils import move_to_policy_index, board_to_tensor
from RL_utils import blend_policies, policy_index_to_move, get_game_result

def generate_self_play_data_with_blending(model, reference_model, device, start_positions, num_games=100, N_OPENING_MOVES=10):
    """Generate self-play games, blending the MCTS policy with a reference policy early on.

    For the first ``N_OPENING_MOVES`` moves of each game the MCTS visit policy is
    blended (``alpha=0.7``) with the reference model's policy restricted to legal
    moves; afterwards the MCTS policy is used alone. A move is sampled from the
    resulting distribution and played.

    Args:
        model: network driving the MCTS search.
        reference_model: network called as ``reference_model(x)``; only its policy
            logits are used, and only during the opening moves.
        device: device used for MCTS and for the reference forward pass.
        start_positions: pool of ``chess.Board`` start positions; one is drawn at
            random per game and copied, so the pool itself is never modified.
        num_games: number of games to play.
        N_OPENING_MOVES: number of initial moves per game that use temperature 1.0
            and policy blending; later moves use temperature 0.1 and no blending.

    Returns:
        A list of ``(game_history, result)`` tuples. Each history entry is
        ``(board_tensor, policy, turn)`` describing the position *before* the move
        was played: its tensor (with a leading batch dimension of 1), the normalised
        policy the move was sampled from, and the side to move.
    """
    data = []
    for _ in range(num_games):
        # Play on a copy: pushing moves onto the shared board would corrupt
        # ``start_positions`` for every later game and training iteration.
        board = random.choice(start_positions).copy()
        mcts = SimpleMCTS(model, device, num_simulations=400)
        game_history = []
        move_count = 0
        while not board.is_game_over() and move_count < MAX_GAME_MOVES:
            in_opening = move_count < N_OPENING_MOVES
            temperature = 1.0 if in_opening else 0.1
            action_probs, _root = mcts.get_action_probs(board, temperature)

            # Convert the MCTS action_probs dict to a dense policy vector.
            mcts_policy = np.zeros(4288)
            for move, prob in action_probs.items():
                mcts_policy[move_to_policy_index(move)] = prob

            # Snapshot the position the policy refers to *before* the move is pushed,
            # so the stored (state, policy, turn) triple is self-consistent.
            state_tensor = torch.tensor(board_to_tensor(board), dtype=torch.float32).unsqueeze(0)
            side_to_move = board.turn
            legal_indices = [move_to_policy_index(move) for move in board.legal_moves]

            if in_opening:
                # The reference policy is only needed while blending is active.
                with torch.no_grad():
                    cnn_logits, _ = reference_model(state_tensor.to(device))
                    cnn_policy = torch.softmax(cnn_logits, dim=1).cpu().numpy()[0]
                # Mask illegal moves and renormalise over the legal ones.
                mask = np.zeros_like(cnn_policy)
                mask[legal_indices] = 1
                cnn_policy = cnn_policy * mask
                cnn_total = cnn_policy.sum()
                if cnn_total > 0:
                    cnn_policy = cnn_policy / cnn_total
                blended_policy = blend_policies(mcts_policy, cnn_policy, alpha=0.7)
            else:
                blended_policy = mcts_policy

            # np.random.choice requires probabilities that sum to 1; renormalise so
            # rounding or an all-zero reference policy cannot make sampling raise.
            blended_policy = np.asarray(blended_policy, dtype=np.float64)
            total = blended_policy.sum()
            if total > 0:
                blended_policy = blended_policy / total
            else:
                # Degenerate policy: fall back to uniform over legal moves (there is
                # at least one, since the game is not over).
                blended_policy = np.zeros_like(blended_policy)
                blended_policy[legal_indices] = 1.0 / len(legal_indices)

            # Select and play a move from the blended policy.
            move_idx = np.random.choice(len(blended_policy), p=blended_policy)
            move = policy_index_to_move(move_idx, board)
            board.push(move)

            # Store training data for the pre-move position.
            game_history.append((state_tensor, blended_policy, side_to_move))
            move_count += 1

        # Get game result
        result = get_game_result(board)
        data.append((game_history, result))
    return data

In [None]:
# Frozen reference network: loaded from the existing checkpoint, gradients disabled,
# and switched to eval mode so its outputs are deterministic during blending/KL.
reference_model, device = RL_utils.load_resnet_model(ref_model_path)
reference_model.requires_grad_(False)
reference_model.eval()
# Model to be trained; loaded without a checkpoint path (the returned device is ignored).
model, _ = RL_utils.load_resnet_model()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Outer loop: 5 rounds of self-play generation followed by training.
for iteration in range(5):
    print(f"\n=== Iteration {iteration+1} ===")
    # Generate self-play games with policy blending
    self_play_data = generate_self_play_data_with_blending(model, reference_model, device, positions, num_games=NUM_SELF_PLAY_GAMES, N_OPENING_MOVES=10)
    # Train model on self-play games (KL divergence to the reference model is computed
    # on every batch and passed into the loss)
    train_on_mcts(reference_model, model, optimizer, self_play_data, device)
    # Overwrite the same checkpoint after every iteration.
    torch.save({'model_state_dict': model.state_dict()}, MODEL_SAVE_PATH)