In [1]:
import chess
import RL_utils

In [2]:
# Our existing CNN model
ref_model_path = "models/CNN_ResNet.pth"

## Load stockfish

In [3]:
from RL_utils import PositionEvaluator

ImportError: cannot import name 'PositionEvaluator' from 'RL_utils' (/home/smith3_j@WMGDS.WMG.WARWICK.AC.UK/Documents/jordi-lete/Chess-AI/RL_utils.py)

In [4]:
stockfish_path = "models/stockfish/stockfish-windows-x86-64-avx2.exe"
evaluator = PositionEvaluator(stockfish_path, elo_rating=1400)

Stockfish initialized with ELO: 1400


In [5]:
board = chess.Board()
board.push_uci("e2e4")
board.push_uci("h7h6")

Move.from_uci('h7h6')

In [6]:
eval_score = evaluator.evaluate_position(board)
print(f"Position evaluation: {eval_score}")

Position evaluation: 0.83


In [7]:
best_moves = evaluator.get_best_moves(board, 3)
print(f"Best moves: {best_moves}")

Best moves: [{'Move': 'd2d4', 'Centipawn': 83, 'Mate': None}, {'Move': 'g1f3', 'Centipawn': 71, 'Mate': None}, {'Move': 'b1c3', 'Centipawn': 68, 'Mate': None}]


## Load our CNN trained model

In [3]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model = reference_model.to(device)

In [4]:
# import torch
# dummy_input = torch.randn(1, 19, 8, 8).to(device)
# with torch.no_grad():
#     output = model(dummy_input)
#     print(f"Model output shape: {output.shape}")  # Should be [1, 4288]
#     print("Model loaded successfully!")

## Load the training data

In [5]:
pgn_file = "games/lichess_db_2016-04.pgn"
positions = []
positions += RL_utils.extract_varied_positions(pgn_file, num_positions=1000)
print(f"Total positions loaded: {len(positions)}")

Extracted 1000 varied positions (openings, early, middlegame)
Total positions loaded: 1000


In [6]:
# Test creating a batch
# if positions:
#     board_tensors, legal_masks, boards = RL_utils.create_training_batch(positions, batch_size=4)
#     print(f"Batch shapes:")
#     print(f"  Board tensors: {board_tensors.shape}")
#     print(f"  Legal masks: {legal_masks.shape}")
#     print(f"  Number of boards: {len(boards)}")
    
#     # Show a sample position
#     print(f"\nSample position FEN: {boards[0].fen()}")

In [7]:
# boards[0]
# eval_score = evaluator.evaluate_position(boards[0])
# print(f"Position evaluation: {eval_score}")

## Training Loop

In [8]:
import random
import tqdm
import torch
import numpy as np
import torch.nn.functional as F

In [9]:
# Training Hyperparameters
LEARNING_RATE = 0.0001
BATCH_SIZE = 64
EPOCHS = 2

# Self-play
NUM_SELF_PLAY_GAMES = 100
MAX_GAME_MOVES = 160
TEMPERATURE = 1.0

# Model Saving
MODEL_SAVE_PATH = "models/rlmodel"

In [10]:
def generate_self_play_data(model, device, start_positions, num_games=NUM_SELF_PLAY_GAMES):
    """Generate self-play data with error handling"""
    data = []
    successful_games = 0
    
    for i in range(num_games):
        try:
            board = random.choice(start_positions)
            # game_history, result = RL_utils.play_self_play_game(model, device, board, MAX_GAME_MOVES, TEMPERATURE)
            game_history, result = RL_utils.play_game_with_mcts(model, device, board, MAX_GAME_MOVES, TEMPERATURE)
            
            if len(game_history) > 0:  # Only add if we have valid data
                data.append((game_history, result))
                successful_games += 1
                
        except Exception as e:
            print(f"Error in game {i}: {e}")
            continue
    
    print(f"Generated {successful_games} successful games out of {num_games} attempts")
    return data

In [11]:
def train_on_self_play(model, optimizer, game_histories, device):
    """Train model on self-play data with gradient clipping"""
    model.train()
    
    # Flatten all examples from all games
    all_examples = []
    for history, result in game_histories:
        for board_tensor, legal_mask, move_idx, turn in history:
            # Convert result to value from current player's perspective
            value = result if turn else -result
            all_examples.append((board_tensor, legal_mask, move_idx, value))
    
    if len(all_examples) == 0:
        print("No training examples available!")
        return
    
    print(f"Training on {len(all_examples)} examples")
    
    # Training loop
    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        epoch_losses = []
        
        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]
            if len(batch) == 0:
                continue
            
            try:
                # Prepare batch
                boards = torch.cat([ex[0] for ex in batch]).to(device)
                masks = torch.stack([ex[1] for ex in batch]).to(device)
                move_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.long).to(device)
                value_targets = torch.tensor([ex[3] for ex in batch], dtype=torch.float).to(device)
                
                # Forward pass
                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)
                
                # Compute loss
                loss, p_loss, v_loss = RL_utils.compute_loss(policy_logits, value_preds, move_targets, value_targets, masks)
                
                # Check for NaN
                if torch.isnan(loss):
                    print("Skipping batch due to NaN loss")
                    continue
                
                # Backward pass
                loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                epoch_losses.append(loss.item())
                
            except Exception as e:
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue
        
        if epoch_losses:
            avg_loss = np.mean(epoch_losses)
            print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}")
        else:
            print(f"Epoch {epoch+1}: No valid batches")


In [12]:
import matplotlib.pyplot as plt

def train_on_mcts(reference_model, model, optimizer, game_histories, device):
    """Train model on self-play data with gradient clipping"""
    model.train()
    
    # Flatten all examples from all games
    all_examples = []
    for history, result in game_histories:
        for board_tensor, policy_vector, turn in history:
            # Convert result to value from current player's perspective
            value = result if turn else -result
            policy_tensor = torch.tensor(policy_vector, dtype=torch.float32)
            all_examples.append((board_tensor, policy_tensor, value))
    
    if len(all_examples) == 0:
        print("No training examples available!")
        return
    
    print(f"Training on {len(all_examples)} examples")
    
    kl_history = []
    epoch_losses = []

    # Training loop
    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        batch_kl = []
        batch_losses = []
        avg_kl = None
        
        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]
            if len(batch) == 0:
                continue
            
            try:
                # Prepare batch
                boards = torch.cat([ex[0] for ex in batch]).to(device)
                policy_targets = torch.stack([ex[1] for ex in batch]).to(device)
                value_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.float).to(device)
                
                # Forward pass
                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)
                with torch.no_grad():
                    ref_policy_logits, _ = reference_model(boards)
                    ref_policy_probs = F.softmax(ref_policy_logits, dim=1)

                policy_probs = F.log_softmax(policy_logits, dim=1)
                kl_div = F.kl_div(policy_probs, ref_policy_probs, reduction='batchmean')
                
                # Compute loss
                loss, p_loss, v_loss = RL_utils.compute_loss_mcts(policy_logits, value_preds, policy_targets, value_targets, kl_div, epoch)
                
                # Check for NaN
                if torch.isnan(loss):
                    print("Skipping batch due to NaN loss")
                    continue
                
                # Backward pass
                loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                batch_losses.append(loss.item())
                batch_kl.append(kl_div.item())
                
            except Exception as e:
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue
        
        if batch_kl:
            avg_kl = np.mean(batch_kl)
            kl_history.append(avg_kl)
        if batch_losses:
            avg_loss = np.mean(batch_losses)
            epoch_losses.append(avg_loss)
            if avg_kl is not None:
                print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}, avg KL={avg_kl:.4f}")
            else:
                print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}, avg KL=N/A")
        else:
            print(f"Epoch {epoch+1}: No valid batches")

    # Plot KL-divergence history
    plt.figure()
    plt.plot(kl_history, label="KL-divergence")
    plt.xlabel("Epoch")
    plt.ylabel("KL-divergence")
    plt.title("KL-divergence vs Epoch")
    plt.legend()
    plt.show()

In [13]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model.requires_grad_(False)
# model, _ = RL_utils.load_resnet_model()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

# # Main training loop
# for iteration in range(5):  # 5 outer iterations of self-play + training
#     print(f"\n=== Iteration {iteration+1} ===")
    
#     # Generate self-play games
#     self_play_data = generate_self_play_data(model, device, positions)
    
#     # Train model on self-play games
#     train_on_mcts(reference_model, model, optimizer, self_play_data, device)

#     # Save model
#     torch.save({
#         'model_state_dict': model.state_dict(),
#     }, MODEL_SAVE_PATH)

# Updated Code with new style training

In [14]:
from mcts import SimpleMCTS
from chess_utils import move_to_policy_index, board_to_tensor
from RL_utils import blend_policies, policy_index_to_move, get_game_result

def generate_self_play_data_with_blending(model, reference_model, device, start_positions, num_games=100):
    data = []
    model.eval()
    reference_model.eval()
    for game_idx in range(num_games):
        board = random.choice(start_positions)
        mcts = SimpleMCTS(model, device, num_simulations=400)
        game_history = []
        move_count = 0
        print(f"\n[Game {game_idx+1}/{num_games}] Starting from position (fen={board.fen()})")

        while not board.is_game_over() and move_count < MAX_GAME_MOVES:
            total_ply = (board.fullmove_number - 1) * 2 + (0 if board.turn else 1) # total plies (black + white moves)
            if total_ply < 20:
                temperature = 0.3
                alpha = 0.3 # Mostly CNN
            elif total_ply < 50:
                temperature = 0.5
                alpha = 0.6
            else:
                temperature = 0.2
                alpha = 1.0 # Mostly mcts

            action_probs, root = mcts.get_action_probs(board, temperature)

            # Fallback: if MCTS failed (should be rare), use uniform over legals
            legal_moves = list(board.legal_moves)
            if not action_probs or len(legal_moves) == 0:
                mcts_move_probs = {m: 1.0 / max(1, len(legal_moves)) for m in legal_moves}
            else:
                # ensure we only carry probs for legal moves, in the same order
                mcts_move_probs = {m: action_probs.get(m, 0.0) for m in legal_moves}

            # Get CNN policy
            with torch.no_grad():
                board_tensor = board_to_tensor(board)
                input_tensor = torch.tensor(board_tensor, dtype=torch.float32).unsqueeze(0).to(device)
                cnn_logits, _ = reference_model(input_tensor)
                cnn_full = torch.softmax(cnn_logits, dim=1).squeeze(0).detach().cpu().numpy()

            # Extract the CNN prob for each legal move (via index mapping)
            cnn_move_probs = {}
            for m in legal_moves:
                idx = move_to_policy_index(m)
                p = cnn_full[idx] if idx is not None else 0.0
                cnn_move_probs[m] = float(p)
            # Normalize CNN mass only over legal moves
            s = sum(cnn_move_probs.values())
            if s > 0:
                for m in cnn_move_probs:
                    cnn_move_probs[m] /= s
            else:
                # If CNN assigns zero to all legal moves (mapping/mask oddity),
                # make it uniform over legals instead of all zeros.
                u = 1.0 / len(legal_moves)
                for m in cnn_move_probs:
                    cnn_move_probs[m] = u

            # --- Blend (opening only), then renormalize ---
            blended = {m: alpha * mcts_move_probs[m] + (1 - alpha) * cnn_move_probs[m] for m in legal_moves}

            # Renormalize blended (safety against numerical drift)
            bsum = sum(blended.values())
            if bsum <= 0 or not np.isfinite(bsum):
                # final safety: uniform over legals
                blended = {m: 1.0 / len(legal_moves) for m in legal_moves}
            else:
                for m in blended:
                    blended[m] /= bsum

            # --- Build 4288-dim training target vector from blended over legals ---
            policy_vector = np.zeros(4288, dtype=np.float32)
            for m, p in blended.items():
                idx = move_to_policy_index(m)
                if idx is not None:
                    policy_vector[idx] = p

            # --- STORE TRAINING DATA *BEFORE* MAKING THE MOVE ---
            board_tensor = torch.tensor(board_to_tensor(board), dtype=torch.float32).unsqueeze(0)
            current_turn = board.turn  # who is about to play now
            game_history.append((board_tensor, policy_vector, current_turn))

            # --- Sample move ONLY among legal moves using blended distribution ---
            moves = legal_moves
            probs = np.array([blended[m] for m in moves], dtype=np.float64)
            # Final normalization for np.random.choice requirements
            psum = probs.sum()
            if psum <= 0 or not np.isfinite(psum):
                probs[:] = 1.0 / len(probs)
            else:
                probs /= psum

            chosen = np.random.choice(len(moves), p=probs)
            move = moves[chosen]
            board.push(moves[chosen])
            move_count += 1

            if move_count == 1:
                entropy = -np.sum(probs * np.log(probs + 1e-12))
                print(f"   Move {move_count}: sampled {move}, "
                      f"probs min={probs.min():.4f}, max={probs.max():.4f}, "
                      f"entropy={entropy:.2f}")
            elif move_count % 20 == 0:
                entropy = -np.sum(probs * np.log(probs + 1e-12))
                print(f"   Move {move_count}: sampled {move}, "
                      f"probs min={probs.min():.4f}, max={probs.max():.4f}, "
                      f"entropy={entropy:.2f}, fen={board.fen()}")

        # Get game result
        result = get_game_result(board)
        data.append((game_history, result))
        print(f"[Game {game_idx+1}] Finished in {move_count} moves. Result: {result}")
        print(f"End position (fen={board.fen()})")
    return data

In [None]:
reference_model, device = RL_utils.load_resnet_model(ref_model_path)
reference_model.requires_grad_(False)
reference_model.eval()
model, _ = RL_utils.load_resnet_model()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

for iteration in range(5):
    print(f"\n=== Iteration {iteration+1} ===")
    # Generate self-play games with policy blending
    self_play_data = generate_self_play_data_with_blending(model, reference_model, device, positions, num_games=NUM_SELF_PLAY_GAMES)
    # Train model on self-play games (with KL loss for opening positions)
    train_on_mcts(reference_model, model, optimizer, self_play_data, device)
    save_name = MODEL_SAVE_PATH + "_" + iteration + ".pth"
    try:
        torch.save({'model_state_dict': model.state_dict()}, save_name)
    except:
        torch.save({'model_state_dict': model.state_dict()}, "rl_backup_save.pth")

Model loaded from models/CNN_ResNet.pth
Initialized new model.

=== Iteration 1 ===

[Game 1/100] Starting from position (fen=rnbqkb1r/pp3ppp/4pn2/8/3B4/2N2N2/PPP2PPP/R2QKB1R w KQkq - 0 8)
   Move 1: sampled f1d3, probs min=0.0009, max=0.4733, entropy=1.66
   Move 20: sampled c7d7, probs min=0.0073, max=0.4067, entropy=2.54, fen=4rb1r/1p1qkppp/pBn1p3/7n/1P6/2N1Q2P/P1P2PP1/R4RK1 w - - 0 18
   Move 40: sampled h8f8, probs min=0.0012, max=0.5383, entropy=1.62, fen=4rr2/1p2kp1p/pB2p1p1/6Pn/1P1n1P2/P1P4P/5R2/2Q1R1K1 w - - 1 28
   Move 60: sampled g8a8, probs min=0.0015, max=0.4187, entropy=1.74, fen=r7/4r2p/pp1kp1p1/6Pn/1P1B1P2/P1P4P/4RR2/6K1 w - - 0 38
   Move 80: sampled h7c7, probs min=0.0091, max=0.9176, entropy=0.39, fen=rBk5/2r5/p3p1p1/1pR3Pp/1P5P/P1P5/7K/1R2n3 w - - 4 48
   Move 100: sampled d3b4, probs min=0.0007, max=0.5581, entropy=0.85, fen=r1k5/B7/p5p1/1p1R2Pp/1n2p2P/P1P5/6K1/8 w - - 0 58
   Move 120: sampled a8a7, probs min=0.0018, max=0.6387, entropy=1.32, fen=1r6/k7/p5p1/6Pp/