In [1]:
import chess
import RL_utils

In [2]:
# Checkpoint of our existing supervised CNN (ResNet) model.
# Further down, this same file is loaded twice: once as the frozen reference
# model and once as the starting weights of the model being trained.
ref_model_path = "models/CNN_ResNet.pth"

## Load Stockfish

In [3]:
from RL_utils import PositionEvaluator

ImportError: cannot import name 'PositionEvaluator' from 'RL_utils' (/home/smith3_j@WMGDS.WMG.WARWICK.AC.UK/Documents/jordi-lete/Chess-AI/RL_utils.py)

In [4]:
# Path to the Stockfish binary handed to PositionEvaluator.
# NOTE(review): this is a Windows executable name; on another OS the path must
# point at the matching Stockfish build - confirm before running.
stockfish_path = "models/stockfish/stockfish-windows-x86-64-avx2.exe"
# NOTE(review): PositionEvaluator is imported from RL_utils in the previous cell,
# whose recorded output is an ImportError - verify the class still exists there.
evaluator = PositionEvaluator(stockfish_path, elo_rating=1400)

Stockfish initialized with ELO: 1400


In [5]:
# Build a small test position from the standard start: 1. e4 h6 (White to move).
board = chess.Board()
board.push_uci("e2e4")
# Last expression of the cell, so the notebook displays the returned Move.
board.push_uci("h7h6")

Move.from_uci('h7h6')

In [6]:
# Score the test position with the Stockfish-backed evaluator.
# NOTE(review): the recorded value (0.83) next to the 83-centipawn best move
# suggests the score is expressed in pawns - confirm in PositionEvaluator.
eval_score = evaluator.evaluate_position(board)
print(f"Position evaluation: {eval_score}")

Position evaluation: 0.83


In [7]:
# Ask the evaluator for its top candidate moves in the same position.
# The second argument is presumably the number of moves to return (the recorded
# output lists three entries with 'Move', 'Centipawn' and 'Mate' keys).
best_moves = evaluator.get_best_moves(board, 3)
print(f"Best moves: {best_moves}")

Best moves: [{'Move': 'd2d4', 'Centipawn': 83, 'Mate': None}, {'Move': 'g1f3', 'Centipawn': 71, 'Mate': None}, {'Move': 'b1c3', 'Centipawn': 68, 'Mate': None}]


## Load our CNN trained model

In [3]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model = reference_model.to(device)

In [4]:
# import torch
# dummy_input = torch.randn(1, 19, 8, 8).to(device)
# with torch.no_grad():
#     output = model(dummy_input)
#     print(f"Model output shape: {output.shape}")  # Should be [1, 4288]
#     print("Model loaded successfully!")

## Load the training data

In [3]:
# Pool of start positions for self-play, extracted from a Lichess PGN dump.
pgn_file = "games/lichess_db_2016-04.pgn"
positions = list(RL_utils.extract_varied_positions(pgn_file, num_positions=1000))
print(f"Total positions loaded: {len(positions)}")

Extracted 1000 varied positions (openings, early, middlegame)
Total positions loaded: 1000


In [4]:
# Test creating a batch
# if positions:
#     board_tensors, legal_masks, boards = RL_utils.create_training_batch(positions, batch_size=4)
#     print(f"Batch shapes:")
#     print(f"  Board tensors: {board_tensors.shape}")
#     print(f"  Legal masks: {legal_masks.shape}")
#     print(f"  Number of boards: {len(boards)}")
    
#     # Show a sample position
#     print(f"\nSample position FEN: {boards[0].fen()}")

In [5]:
# boards[0]
# eval_score = evaluator.evaluate_position(boards[0])
# print(f"Position evaluation: {eval_score}")

## Training Loop

In [6]:
import random
import tqdm
import torch
import numpy as np
import torch.nn.functional as F

In [7]:
# Training Hyperparameters
LEARNING_RATE = 0.0001  # learning rate passed to the optimizer
BATCH_SIZE = 64         # examples per gradient step in the training functions
EPOCHS = 2              # passes over the examples in train_on_self_play / train_on_mcts

# Self-play
NUM_SELF_PLAY_GAMES = 100  # games generated per outer training iteration
# The next two are only read by generate_self_play_data, which forwards them to
# RL_utils.play_game_with_mcts; generate_self_play_data_with_blending has its
# own max_game_moves / temperature schedule.
MAX_GAME_MOVES = 512
TEMPERATURE = 1.0

# Model Saving
# Checkpoint path prefix; the training loop appends "_<iteration>.pth".
MODEL_SAVE_PATH = "models/rlmodel_buffer"

In [8]:
def generate_self_play_data(model, device, start_positions, num_games=NUM_SELF_PLAY_GAMES):
    """Play MCTS self-play games and collect their histories.

    Each game starts from a position drawn at random from `start_positions`
    and is played by `RL_utils.play_game_with_mcts` using the module-level
    `MAX_GAME_MOVES` and `TEMPERATURE` settings. A game that raises is
    reported and skipped so that one failure does not abort the whole batch.

    Args:
        model: network forwarded to `RL_utils.play_game_with_mcts`.
        device: device forwarded to `RL_utils.play_game_with_mcts`.
        start_positions: sequence of board objects exposing `.copy()`
            (the notebook passes the `positions` pool of chess boards).
        num_games: number of games to attempt.

    Returns:
        List of `(game_history, result)` tuples, one per game that finished
        with a non-empty history. Empty when `start_positions` is empty.
    """
    if len(start_positions) == 0:
        print("No start positions available - skipping self-play")
        return []

    data = []

    for i in range(num_games):
        try:
            # Play on a copy so the shared pool of start positions is never
            # mutated by the moves pushed during the game.
            board = random.choice(start_positions).copy()
            # An older variant, RL_utils.play_self_play_game, took the same arguments.
            game_history, result = RL_utils.play_game_with_mcts(model, device, board, MAX_GAME_MOVES, TEMPERATURE)

            if len(game_history) > 0:  # Only keep games that produced training data
                data.append((game_history, result))

        except Exception as e:
            # Best effort: report the failed game and carry on with the next one.
            print(f"Error in game {i}: {e}")
            continue

    print(f"Generated {len(data)} successful games out of {num_games} attempts")
    return data

In [9]:
def train_on_self_play(model, optimizer, game_histories, device):
    """Run gradient updates on moves recorded during self-play.

    Every ply in `game_histories` becomes one example
    `(board_tensor, legal_mask, move_idx, value)`, where `value` is the game
    result, negated when the stored `turn` flag is falsy. The examples are
    reshuffled and consumed in mini-batches of `BATCH_SIZE` for `EPOCHS`
    passes. Gradients are clipped to a max norm of 1.0; batches that give a
    NaN loss or raise are reported and skipped.

    Returns None; `model` and `optimizer` are updated in place.
    """
    model.train()

    # One flat list of examples across all games.
    all_examples = [
        (board_tensor, legal_mask, move_idx, result if turn else -result)
        for history, result in game_histories
        for board_tensor, legal_mask, move_idx, turn in history
    ]

    if not all_examples:
        print("No training examples available!")
        return

    print(f"Training on {len(all_examples)} examples")

    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        epoch_losses = []

        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]
            if not batch:
                continue

            try:
                # Split the batch column-wise, then move each column to the device.
                board_col, mask_col, move_col, value_col = zip(*batch)
                boards = torch.cat(board_col).to(device)
                masks = torch.stack(mask_col).to(device)
                move_targets = torch.tensor(move_col, dtype=torch.long).to(device)
                value_targets = torch.tensor(value_col, dtype=torch.float).to(device)

                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)

                loss, _p_loss, _v_loss = RL_utils.compute_loss(
                    policy_logits, value_preds, move_targets, value_targets, masks
                )

                # A NaN loss would poison the weights, so drop the batch.
                if torch.isnan(loss):
                    print("Skipping batch due to NaN loss")
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_losses.append(loss.item())

            except Exception as e:
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue

        if not epoch_losses:
            print(f"Epoch {epoch+1}: No valid batches")
            continue

        avg_loss = np.mean(epoch_losses)
        print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}")


In [10]:
import matplotlib.pyplot as plt

def train_on_mcts(reference_model, model, optimizer, game_histories, device):
    """Train `model` on MCTS policy targets with a KL penalty towards `reference_model`.

    Every ply in `game_histories` becomes one example
    `(board_tensor, policy_tensor, value)`, where `value` is the game result,
    negated when the stored `turn` flag is falsy. The examples are reshuffled
    and consumed in mini-batches of `BATCH_SIZE` for `EPOCHS` passes. For each
    batch the KL divergence between the reference policy and the model policy
    is computed and handed to `RL_utils.compute_loss_mcts` together with the
    epoch index. Gradients are clipped to a max norm of 1.0. Batches with a
    non-finite loss, or that raise, are reported and skipped.

    Args:
        reference_model: frozen network used only to produce the KL anchor;
            it is switched to eval mode here.
        model: network being trained (switched to train mode).
        optimizer: optimizer over `model`'s parameters.
        game_histories: iterable of `(history, result)` with
            `history = [(board_tensor, policy_vector, turn), ...]`.
        device: device the batches are moved to.

    Returns:
        None. Prints per-epoch averages and, when at least one batch
        succeeded, plots the per-epoch mean KL divergence.
    """
    model.train()

    # Flatten all examples from all games
    all_examples = []
    for history, result in game_histories:
        for board_tensor, policy_vector, turn in history:
            # Flip the sign of the result when `turn` is falsy
            value = result if turn else -result
            policy_tensor = torch.tensor(policy_vector, dtype=torch.float32)
            all_examples.append((board_tensor, policy_tensor, value))

    if len(all_examples) == 0:
        print("No training examples available!")
        return

    # The reference network is a fixed anchor: keep it in inference mode so
    # dropout / normalisation layers do not change its output while training.
    reference_model.eval()

    print(f"Training on {len(all_examples)} examples")

    kl_history = []

    # Training loop
    for epoch in range(EPOCHS):
        random.shuffle(all_examples)
        batch_kl = []
        batch_losses = []
        avg_kl = None

        for i in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[i:i+BATCH_SIZE]
            if len(batch) == 0:
                continue

            try:
                # Prepare batch
                boards = torch.cat([ex[0] for ex in batch]).to(device)
                policy_targets = torch.stack([ex[1] for ex in batch]).to(device)
                value_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.float).to(device)

                # Forward pass
                optimizer.zero_grad()
                policy_logits, value_preds = model(boards)
                with torch.no_grad():
                    ref_policy_logits, _ = reference_model(boards)
                    ref_policy_probs = F.softmax(ref_policy_logits, dim=1)

                # F.kl_div(input, target) expects log-probabilities as input and
                # probabilities as target, and computes
                # sum target * (log target - input), i.e. KL(reference || model).
                policy_log_probs = F.log_softmax(policy_logits, dim=1)
                kl_div = F.kl_div(policy_log_probs, ref_policy_probs, reduction='batchmean')

                # train_on_buffer in this notebook unpacks five values from the
                # same function, so take the first three and tolerate extras
                # instead of failing every batch with an unpacking error.
                loss, p_loss, v_loss, *_extra = RL_utils.compute_loss_mcts(
                    policy_logits, value_preds, policy_targets, value_targets, kl_div, epoch
                )

                # A NaN/Inf loss would poison the weights, so drop the batch.
                if not torch.isfinite(loss):
                    print("Skipping batch due to NaN/Inf loss")
                    continue

                # Backward pass
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                batch_losses.append(loss.item())
                batch_kl.append(kl_div.item())

            except Exception as e:
                print(f"Error in batch {i//BATCH_SIZE}: {e}")
                continue

        if batch_kl:
            avg_kl = np.mean(batch_kl)
            kl_history.append(avg_kl)
        if batch_losses:
            avg_loss = np.mean(batch_losses)
            if avg_kl is not None:
                print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}, avg KL={avg_kl:.4f}")
            else:
                print(f"Epoch {epoch+1}: avg loss={avg_loss:.4f}, avg KL=N/A")
        else:
            print(f"Epoch {epoch+1}: No valid batches")

    # Plot KL-divergence history (skipped when no batch succeeded, since an
    # empty figure carries no information)
    if kl_history:
        plt.figure()
        plt.plot(kl_history, label="KL-divergence")
        plt.xlabel("Epoch")
        plt.ylabel("KL-divergence")
        plt.title("KL-divergence vs Epoch")
        plt.legend()
        plt.show()

In [11]:
from RL_utils import compute_loss_mcts

def train_on_buffer(model, reference_model, optimizer, replay_buffer, device,
                    examples_per_epoch=8000, epochs=2, batch_size=64,
                    kl_schedule_epoch_offset=0):
    """
    Train `model` on examples sampled from `replay_buffer`, with a KL term
    towards the frozen `reference_model`.

    Args:
        model: network being trained (switched to train mode).
        reference_model: frozen network providing the KL anchor (switched to
            eval mode here).
        optimizer: optimizer over `model`'s parameters.
        replay_buffer: object supporting `len()` and
            `sample_examples(n)`; each sampled example is indexed as
            `(board_tensor, policy_vector, value)`.
        device: device the batches are moved to.
        examples_per_epoch: how many examples to request from the buffer per epoch.
        epochs: how many epochs to run; a fresh sample is drawn for each.
        batch_size: mini-batch size.
        kl_schedule_epoch_offset: added to the epoch index passed to
            `compute_loss_mcts` (for an epoch-dependent KL weight).

    Returns:
        None. Prints per-epoch averages of the total, policy and value losses,
        the KL divergence and its weighted contribution.
    """
    if len(replay_buffer) == 0:
        print("Replay buffer empty - skipping training")
        return

    model.train()
    # The reference network is a fixed anchor: keep it in inference mode so
    # dropout / normalisation layers do not change its output while training.
    reference_model.eval()
    for epoch in range(epochs):
        # draw a fresh sample each epoch (no no-repeat guarantee across epochs)
        examples = replay_buffer.sample_examples(examples_per_epoch)
        if len(examples) == 0:
            print("No examples sampled - skipping epoch")
            continue

        # shuffle examples
        random.shuffle(examples)
        epoch_losses = []
        policy_loss_vals = []
        value_loss_vals = []
        kl_vals = []
        kl_contrib_vals = []
        for i in range(0, len(examples), batch_size):
            batch = examples[i:i+batch_size]
            if len(batch) == 0:
                continue

            # Prepare batch
            # boards are concatenated along dim 0 - assumes each stored tensor
            # already carries a leading batch dimension of 1 (TODO confirm)
            boards = torch.cat([ex[0] for ex in batch]).to(device)
            policy_targets = torch.stack([torch.tensor(ex[1], dtype=torch.float32) for ex in batch]).to(device)
            value_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.float32).to(device)

            optimizer.zero_grad()
            policy_logits, value_preds = model(boards)

            # reference policy for KL (no gradient flows into the reference)
            with torch.no_grad():
                ref_policy_logits, _ = reference_model(boards)
                ref_policy_probs = F.softmax(ref_policy_logits, dim=1)

            # F.kl_div(input, target) expects log-probabilities as input and
            # probabilities as target, and computes
            # sum target * (log target - input). With the model's log-probs as
            # input and the reference probs as target this is
            # KL(reference || model): the expectation is taken under the
            # reference policy, not under the model.
            policy_log_probs = F.log_softmax(policy_logits, dim=1)
            kl_div = F.kl_div(policy_log_probs, ref_policy_probs, reduction='batchmean')

            # compute loss - assumes compute_loss_mcts has signature:
            # compute_loss_mcts(policy_logits, value_preds, policy_targets, value_targets, kl_div, epoch_index)
            # and returns (loss, policy_loss, value_loss, kl_beta, kl_div)
            loss, p_loss, v_loss, kl_beta, kl_div_tensor = compute_loss_mcts(policy_logits, value_preds, policy_targets, value_targets, kl_div, epoch + kl_schedule_epoch_offset)

            if not torch.isfinite(loss):
                print("Skipping batch due to NaN/Inf loss")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            kl_value = kl_div_tensor.item()
            epoch_losses.append(loss.item())
            # policy loss: should decrease if the model is matching the blended MCTS/CNN soft targets
            policy_loss_vals.append(p_loss.item())
            # value loss: shows whether value prediction improves; small relative to policy
            # means the network is mostly being driven by the policy term
            value_loss_vals.append(v_loss.item())
            # raw divergence between reference and model
            kl_vals.append(kl_value)
            # weighted KL contribution to the total loss; if this dominates, the model is
            # being held close to the reference CNN. float() keeps the list plain numbers
            # for np.mean even if kl_beta comes back as a tensor.
            kl_contrib_vals.append(float(kl_beta) * kl_value)

        if epoch_losses:
            print(f"Train epoch {epoch+1}/{epochs}: avg total loss={np.mean(epoch_losses):.4f}, "
                  f"policy={np.mean(policy_loss_vals):.4f}, value={np.mean(value_loss_vals):.4f}, "
                  f"avg KL={np.mean(kl_vals):.4f}, avg KL*beta={np.mean(kl_contrib_vals):.4f}, "
                  f"examples={len(examples)}")
        else:
            print(f"Train epoch {epoch+1}/{epochs}: no valid batches")

In [12]:
# reference_model, device = RL_utils.load_model(ref_model_path)
# reference_model.requires_grad_(False)
# model, _ = RL_utils.load_resnet_model()
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

# # Main training loop
# for iteration in range(5):  # 5 outer iterations of self-play + training
#     print(f"\n=== Iteration {iteration+1} ===")
    
#     # Generate self-play games
#     self_play_data = generate_self_play_data(model, device, positions)
    
#     # Train model on self-play games
#     train_on_mcts(reference_model, model, optimizer, self_play_data, device)

#     # Save model
#     torch.save({
#         'model_state_dict': model.state_dict(),
#     }, MODEL_SAVE_PATH)

# Updated code: blended self-play with replay-buffer training

In [13]:
from mcts import SimpleMCTS
from chess_utils import move_to_policy_index, board_to_tensor
from RL_utils import policy_index_to_move, get_game_result

def generate_self_play_data_with_blending(
    model,
    reference_model,
    device,
    start_positions,
    num_games=100,
    num_simulations=400,
    max_game_moves=200,
    # blending / temperature schedule params
    opening_ply_cutoff=20,
    midgame_ply_cutoff=50,
    alpha_opening=0.2,
    alpha_mid=0.6,
    # resignation params (safe defaults)
    resign_threshold=-0.8,    # v_curr <= this is considered very losing
    resign_min_plies=20,       # don't resign before this many plies
    resign_consec=4            # require this many consecutive low-value preds
):
    """
    Generate self-play games with blended MCTS + CNN policies and resign logic.

    For every move the function (1) reads the value head of `model` for the
    resignation check, (2) gets the `reference_model` policy over the legal
    moves, (3) runs MCTS with `model`, (4) mixes the two distributions as
    `alpha * mcts + (1 - alpha) * cnn`, stores the mix as the training target
    and samples the move to play from it.

    Args:
        model: network used by MCTS and for the value-based resignation check.
        reference_model: network providing the CNN policy that is blended in.
        device: device on which both networks are evaluated.
        start_positions: sequence of boards; one is drawn at random and copied
            for each game.
        num_games: number of games to play.
        num_simulations: simulations per move, passed to SimpleMCTS.
        max_game_moves: maximum number of moves played per game in this function.
        opening_ply_cutoff: plies below this use the opening schedule
            (base temperature 0.8, base alpha `alpha_opening`) and are the only
            ones eligible for the CNN-only shortcut.
        midgame_ply_cutoff: plies below this (and >= the opening cutoff) use
            base temperature 0.7 and base alpha `alpha_mid`; later plies use
            base temperature 0.2 and base alpha 1.0 (MCTS only).
        alpha_opening: base weight of the MCTS policy in the opening.
        alpha_mid: base weight of the MCTS policy in the middlegame.
        resign_threshold: side-to-move value at or below which a position
            counts towards resignation.
        resign_min_plies: plies before which low values are not counted.
        resign_consec: number of consecutive counted positions that triggers
            resignation.

    Returns:
        List of (game_history, result) where
        game_history = [(board_tensor, policy_vector, turn), ...].
        Result is from WHITE's POV for resignations: +1 white win, -1 black win.
        Otherwise it is whatever get_game_result(board) returns (presumably the
        same convention with 0 for a draw - confirm in RL_utils).

    Side effects: switches both networks to eval mode and prints progress logs.
    """

    model.eval()
    reference_model.eval()
    data = []

    resign_count = 0

    for game_idx in range(num_games):
        # copy starting position so we don't mutate the pool
        board = random.choice(start_positions).copy()
        mcts = SimpleMCTS(model, device, num_simulations=num_simulations)
        game_history = []
        move_count = 0
        consec_low = 0
        early_ended = False

        print(f"\n[Game {game_idx+1}/{num_games}] Starting from position (fen={board.fen()})")

        while not board.is_game_over() and move_count < max_game_moves:
            # Ply index of the game itself (derived from the FEN move counter),
            # not the number of moves played since the sampled start position.
            total_ply = (board.fullmove_number - 1) * 2 + (0 if board.turn == chess.WHITE else 1)

            # --- Evaluate value head for resignation decision (use current RL model) ---
            # NOTE(review): the value output is interpreted here as White's point of view,
            # while train_on_self_play / train_on_mcts in this notebook build value targets
            # as `result if turn else -result` (side-to-move view). Confirm which convention
            # the network and the replay buffer actually use.
            with torch.no_grad():
                bt = torch.tensor(board_to_tensor(board), dtype=torch.float32, device=device).unsqueeze(0)
                _, v_white = model(bt)                     # single-element tensor (read via .item())
                v_white = float(v_white.item())           # assumed in [-1,1], white's POV - TODO confirm

            # current-player perspective: +1 = good for player to move
            v_curr = v_white if board.turn == chess.WHITE else -v_white

            # consecutive low-value check; the counter is reset by any position that does not qualify
            if total_ply >= resign_min_plies and v_curr <= resign_threshold:
                consec_low += 1
            else:
                consec_low = 0

            if consec_low >= resign_consec:
                # current player resigns -> opponent wins
                result = -1.0 if board.turn == chess.WHITE else 1.0  # white POV
                resign_count += 1
                print(f"[Game {game_idx+1}] Resignation by {'White' if board.turn==chess.WHITE else 'Black'} "
                      f"at ply {total_ply}, v_curr={v_curr:.3f}, consec_low={consec_low}")
                data.append((game_history, result))
                early_ended = True
                break

            # legal moves (quick check)
            legal_moves = list(board.legal_moves)
            if len(legal_moves) == 0:
                break  # terminal (shouldn't happen normally)

            # --- CNN (reference) policy for legal moves (compute FIRST so we can adapt blending/temperature) ---
            # NOTE(review): board_to_tensor(board) is recomputed here and again when the
            # training example is stored; the position has not changed in between.
            with torch.no_grad():
                bt_ref = torch.tensor(board_to_tensor(board), dtype=torch.float32, device=device).unsqueeze(0)
                cnn_logits, _ = reference_model(bt_ref)
                cnn_full = torch.softmax(cnn_logits, dim=1).squeeze(0).detach().cpu().numpy()

            # Extract CNN probs for legal moves (moves without a policy index get 0)
            cnn_move_probs = {}
            for m in legal_moves:
                idx = move_to_policy_index(m)
                p = float(cnn_full[idx]) if idx is not None else 0.0
                cnn_move_probs[m] = p

            # Normalize CNN mass only over legal moves (safety); uniform if the mass is zero
            s = sum(cnn_move_probs.values())
            if s > 0:
                for m in cnn_move_probs:
                    cnn_move_probs[m] /= s
            else:
                u = 1.0 / len(legal_moves)
                for m in cnn_move_probs:
                    cnn_move_probs[m] = u

            # --- Compute CNN uncertainty (entropy) over legal moves ---
            # cnn_move_probs was normalised just above, so the re-normalisation below
            # only guards against floating-point drift.
            eps = 1e-12
            legal_probs = np.array([cnn_move_probs[m] for m in legal_moves], dtype=np.float64)
            lp_sum = legal_probs.sum()
            if lp_sum > 0:
                legal_probs /= lp_sum
            else:
                legal_probs[:] = 1.0 / len(legal_probs)

            entropy = -np.sum(legal_probs * np.log(legal_probs + eps))
            max_entropy = np.log(len(legal_probs)) if len(legal_probs) > 0 else 1.0
            # With a single legal move max_entropy is log(1) = 0, so norm_entropy falls back to 1.0
            norm_entropy = float(entropy / max_entropy) if max_entropy > 0 else 1.0  # 0=confident,1=uncertain
            cnn_pmax = float(legal_probs.max())

            # --- schedule base_alpha and base_temperature from ply (your existing schedule) ---
            if total_ply < opening_ply_cutoff:
                base_temperature = 0.8
                base_alpha = alpha_opening
            elif total_ply < midgame_ply_cutoff:
                base_temperature = 0.7
                base_alpha = alpha_mid
            else:
                # late game: base weight is entirely on MCTS
                base_temperature = 0.2
                base_alpha = 1.0

            # --- adapt alpha: move from base_alpha toward 1.0 proportional to CNN uncertainty
            # (alpha is the weight of the MCTS policy in the blend below)
            alpha = base_alpha + norm_entropy * (1.0 - base_alpha)
            alpha = float(np.clip(alpha, 0.0, 1.0))

            # --- adapt temperature: increase when CNN uncertain (so we explore more) ---
            temp_max = 1.0
            temperature = float(base_temperature + norm_entropy * (temp_max - base_temperature))
            temperature = max(0.01, min(temperature, temp_max))

            # Thresholds for skipping the blend and trusting the CNN policy alone
            pure_cnn_pmax_thresh   = 0.90
            pure_cnn_entropy_thresh = 0.06   # normalized entropy threshold (0..1)
            pure_cnn_value_agree_thresh = -0.6  # require model value not strongly negative

            # compute whether we can trust CNN-only (opening, confident, and low entropy)
            candidate_pure_cnn = (cnn_pmax >= pure_cnn_pmax_thresh and norm_entropy <= pure_cnn_entropy_thresh and total_ply < opening_ply_cutoff)

            # Same expression as v_curr above: the RL model's value for the side to move
            v_curr_rl = v_white if board.turn == chess.WHITE else -v_white
            model_agrees = (v_curr_rl >= pure_cnn_value_agree_thresh)

            # if CNN very confident in opening, use CNN-only policy
            use_pure_cnn = candidate_pure_cnn and model_agrees

            # --- Now run MCTS using the (possibly adapted) temperature when converting visits -> probs ---
            # NOTE(review): the search runs even when use_pure_cnn is True, in which case
            # action_probs is not used for the blend; `root` is never used.
            action_probs, root = mcts.get_action_probs(board, temperature)

            # Fallback MCTS uniform if needed
            if not action_probs:
                mcts_move_probs = {m: 1.0 / len(legal_moves) for m in legal_moves}
            else:
                mcts_move_probs = {m: action_probs.get(m, 0.0) for m in legal_moves}

            # --- Blend (cnn vs mcts) using adaptive alpha (or pure CNN if flagged) ---
            if use_pure_cnn:
                blended = {m: cnn_move_probs[m] for m in legal_moves}
            else:
                blended = {m: alpha * mcts_move_probs.get(m, 0.0) + (1 - alpha) * cnn_move_probs.get(m, 0.0)
                           for m in legal_moves}

            # Renormalize blended (safety); uniform if the total is zero or not finite
            bsum = sum(blended.values())
            if bsum <= 0 or not np.isfinite(bsum):
                blended = {m: 1.0 / len(legal_moves) for m in legal_moves}
            else:
                for m in blended:
                    blended[m] /= bsum

            # Build policy vector (4288-dim)
            # NOTE(review): a move whose index is None is dropped here, so the stored
            # target can sum to less than 1 in that case.
            policy_vector = np.zeros(4288, dtype=np.float32)
            for m, p in blended.items():
                idx = move_to_policy_index(m)
                if idx is not None:
                    policy_vector[idx] = p

            # STORE training example (board BEFORE move)
            board_tensor = torch.tensor(board_to_tensor(board), dtype=torch.float32).unsqueeze(0)  # CPU ok
            current_turn = board.turn
            game_history.append((board_tensor, policy_vector, current_turn))

            # Sample and make the move among legal moves
            moves = legal_moves
            probs = np.array([blended[m] for m in moves], dtype=np.float64)
            psum = probs.sum()
            if psum <= 0 or not np.isfinite(psum):
                probs[:] = 1.0 / len(probs)
            else:
                probs /= psum

            chosen_idx = np.random.choice(len(moves), p=probs)
            chosen_move = moves[chosen_idx]
            board.push(chosen_move)
            move_count += 1

            # periodic logging (now includes entropy/alpha/temp)
            # entropy_val is the entropy of the sampling distribution; norm_entropy in the
            # second line is the CNN-only entropy computed earlier. The logged alpha is the
            # scheduled value even when use_pure_cnn bypassed the blend.
            if move_count == 1 or move_count % 20 == 0:
                entropy_val = -np.sum(probs * np.log(probs + 1e-12))
                print(f"   Move {move_count}: sampled {chosen_move}, "
                      f"probs min={probs.min():.4f}, max={probs.max():.4f}, entropy={entropy_val:.2f}, fen={board.fen()}")
                print(f"      cnn_pmax={cnn_pmax:.3f}, norm_entropy={norm_entropy:.3f}, alpha={alpha:.3f}, temperature={temperature:.3f}, use_pure_cnn={use_pure_cnn}")

        # only append final result if we didn't already append due to early resignation
        # (this also covers games cut off by max_game_moves)
        if not early_ended:
            result = get_game_result(board)
            data.append((game_history, result))
            print(f"[Game {game_idx+1}] Finished in {move_count} moves. Result: {result}")
            print(f"End position (fen={board.fen()})")

    print(f"Total resignations in this batch: {resign_count}/{num_games}")
    return data

In [None]:
from replay_buffer import ReplayBuffer

# Frozen reference (anchor) network: the pretrained CNN, never updated.
reference_model, device = RL_utils.load_resnet_model(ref_model_path)
reference_model.requires_grad_(False)
reference_model.eval()

# Trainable RL network, warm-started from the same checkpoint (the recorded
# output shows the checkpoint being loaded twice).
model, _ = RL_utils.load_resnet_model(ref_model_path)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

replay_buffer = ReplayBuffer(capacity=12000)

for iteration in range(5):
    print(f"\n=== Iteration {iteration+1} ===")
    # Generate self-play games with policy blending
    self_play_data = generate_self_play_data_with_blending(model, reference_model, device, positions, num_games=NUM_SELF_PLAY_GAMES)
    replay_buffer.add_games(self_play_data)
    # Train using samples from the buffer
    train_on_buffer(model, reference_model, optimizer, replay_buffer,
                    device,
                    examples_per_epoch=8000,
                    epochs=EPOCHS,
                    batch_size=BATCH_SIZE)
    # Older alternative that trains only on this iteration's games:
    # train_on_mcts(reference_model, model, optimizer, self_play_data, device)

    # Checkpoint after every iteration. If the regular path cannot be written
    # (e.g. the target directory is missing), say why and fall back to a file
    # in the working directory so the iteration's weights are not lost.
    save_name = MODEL_SAVE_PATH + "_" + str(iteration) + ".pth"
    checkpoint = {'model_state_dict': model.state_dict()}
    try:
        torch.save(checkpoint, save_name)
    except Exception as exc:
        print(f"Could not save checkpoint to {save_name}: {exc}; writing rl_backup_save.pth instead")
        torch.save(checkpoint, "rl_backup_save.pth")

Model loaded from models/CNN_ResNet.pth
Model loaded from models/CNN_ResNet.pth

=== Iteration 1 ===

[Game 1/100] Starting from position (fen=rnbqk1nr/pp3ppp/4p3/2b5/4p3/8/PPPN1PPP/RNBQKB1R w KQkq - 0 6)
   Move 1: sampled d2e4, probs min=0.0000, max=1.0000, entropy=0.00, fen=rnbqk1nr/pp3ppp/4p3/2b5/4N3/8/PPP2PPP/RNBQKB1R b KQkq - 0 6
      cnn_pmax=1.000, norm_entropy=0.000, alpha=0.200, temperature=0.800, use_pure_cnn=True
   Move 20: sampled b7b5, probs min=0.0002, max=0.8295, entropy=0.68, fen=r1b2rk1/4nppp/4p3/1pN2n2/2p5/2P5/PP2KPPP/R1B4R w - - 0 16
      cnn_pmax=0.924, norm_entropy=0.081, alpha=0.632, temperature=0.724, use_pure_cnn=False
   Move 40: sampled b5c4, probs min=0.0002, max=0.6931, entropy=1.04, fen=5rk1/1b3npp/8/P2npp2/2p5/2P3B1/R4PPP/3K3R w - - 0 26
      cnn_pmax=1.000, norm_entropy=0.000, alpha=0.600, temperature=0.700, use_pure_cnn=False
   Move 60: sampled e8e7, probs min=0.0000, max=0.9990, entropy=0.01, fen=6k1/4r1p1/P7/7p/2p5/5P2/K4P1P/2R5 w - - 0 36
      