In [1]:
import chess
import RL_utils

In [2]:
# Checkpoint file of our existing CNN model. The path is relative, so it is
# resolved against the notebook's working directory; note the space in the name.
model_path = "models/TORCH_250EPOCH 1.pth"

## Load Stockfish

In [3]:
from RL_utils import PositionEvaluator

In [4]:
# Location of the Stockfish binary (a Windows AVX2 build, judging by the file
# name) — point this at another executable when running on a different platform.
stockfish_path = "models/stockfish/stockfish-windows-x86-64-avx2.exe"
# Wrap the engine in the project's PositionEvaluator, requesting ELO 1400.
evaluator = PositionEvaluator(stockfish_path, elo_rating=1400)

Stockfish initialized with ELO: 1400


In [6]:
# Build a small test position: the standard start position followed by the
# moves e2e4 and h7h6. The last push is the cell's final expression, so the
# notebook displays the Move it returns.
board = chess.Board()
board.push_uci("e2e4")
board.push_uci("h7h6")

Move.from_uci('h7h6')

In [6]:
# Score the test position with the evaluator and print the result.
# NOTE(review): the printed 0.83 lines up with the 83-centipawn top move shown
# below, so the score is presumably in pawns — confirm in RL_utils.
eval_score = evaluator.evaluate_position(board)
print(f"Position evaluation: {eval_score}")

Position evaluation: 0.83


In [7]:
# Ask the evaluator for its top 3 candidate moves in the same test position.
best_moves = evaluator.get_best_moves(board, 3)
print(f"Best moves: {best_moves}")

Best moves: [{'Move': 'd2d4', 'Centipawn': 83, 'Mate': None}, {'Move': 'g1f3', 'Centipawn': 71, 'Mate': None}, {'Move': 'b1c3', 'Centipawn': 68, 'Mate': None}]


## Load our trained CNN model

In [5]:
# load_model hands back a (model, device) pair; move the network onto that
# device explicitly before using it.
model, device = RL_utils.load_model(model_path)
model = model.to(device)

In [6]:
import torch
# Smoke test: run one random (1, 19, 8, 8) input through the network without
# tracking gradients and report the shape of what comes back.
# NOTE(review): this cell reads `output.shape`, i.e. it expects a single tensor,
# while train_on_self_play further down unpacks `model(boards)` into two values
# — confirm which checkpoint/architecture each cell is meant to run against.
dummy_input = torch.randn(1, 19, 8, 8).to(device)
with torch.no_grad():
    output = model(dummy_input)
    print(f"Model output shape: {output.shape}")  # Should be [1, 4288]
    print("Model loaded successfully!")

Model output shape: torch.Size([1, 4288])
Model loaded successfully!


## Load the training data

In [None]:
# Request 1000 middlegame positions from the Lichess PGN dump. No evaluator is
# handed to the extractor for this run.
pgn_file = "games/lichess_db_2016-04.pgn"
positions = list(
    RL_utils.extract_middlegame_positions(
        pgn_file,
        evaluator=None,
        num_positions=1000,
    )
)
print(f"Total positions loaded: {len(positions)}")

In [None]:
# Test creating a batch
# NOTE(review): this check is currently disabled (every line is commented out),
# so the output shown under the cell must come from an earlier run. Either
# re-enable it or remove the cell together with its stale output.
# if positions:
#     board_tensors, legal_masks, boards = RL_utils.create_training_batch(positions, batch_size=4)
#     print(f"Batch shapes:")
#     print(f"  Board tensors: {board_tensors.shape}")
#     print(f"  Legal masks: {legal_masks.shape}")
#     print(f"  Number of boards: {len(boards)}")
    
#     # Show a sample position
#     print(f"\nSample position FEN: {boards[0].fen()}")

Batch shapes:
  Board tensors: torch.Size([4, 19, 8, 8])
  Legal masks: torch.Size([4, 4288])
  Number of boards: 4

Sample position FEN: r3brk1/1p3ppp/p1p1pn2/2n5/2PN4/2N1P1P1/P3QPBP/R4RK1 w - - 4 17


In [None]:
# NOTE(review): disabled. These lines need `boards`, which is only created by
# the commented-out batch test in the previous cell.
# boards[0]
# eval_score = evaluator.evaluate_position(boards[0])
# print(f"Position evaluation: {eval_score}")

Position evaluation: 8.1


## Training Loop

In [10]:
import random
import tqdm

In [None]:
# Training Hyperparameters
LEARNING_RATE = 0.02  # step size handed to the SGD optimizer
BATCH_SIZE = 64       # examples per gradient step
EPOCHS = 5            # passes over the collected examples per training call

# Self-play
NUM_SELF_PLAY_GAMES = 50  # games generated per self-play round
MAX_GAME_MOVES = 200      # forwarded to RL_utils.play_self_play_game
TEMPERATURE = 1.0         # forwarded to RL_utils.play_self_play_game

# Model Saving
# Relative path, like every other model path in this notebook. The previous
# value began with "/" and therefore pointed at the filesystem root.
MODEL_SAVE_PATH = "models/dual_head_model.pt"

In [None]:
def generate_self_play_data(model, device, start_positions, num_games=NUM_SELF_PLAY_GAMES):
    """Play `num_games` self-play games and collect what they produced.

    Every game starts from an entry of `start_positions` picked uniformly at
    random and is played by `RL_utils.play_self_play_game` using the
    notebook-level `MAX_GAME_MOVES` and `TEMPERATURE` settings.

    The network is switched to eval mode while games are generated, because
    `train_on_self_play` leaves it in train mode and the next round of
    self-play would otherwise run with training-only layer behaviour
    (dropout, batch-norm statistic updates). The mode the network had on
    entry is restored before returning.

    Args:
        model: the torch module used to choose moves.
        device: device handed through to `play_self_play_game`.
        start_positions: non-empty sequence of starting positions.
        num_games: number of games to play.

    Returns:
        A list with one `(game_history, result)` tuple per game.

    Raises:
        IndexError: if `start_positions` is empty and `num_games` > 0.
    """
    was_training = model.training
    model.eval()
    try:
        data = []
        for _ in range(num_games):
            # NOTE(review): the chosen entry is passed on as-is. If
            # play_self_play_game mutates it, later games drawing the same
            # entry would start from the mutated position — confirm that the
            # helper works on a copy.
            board = random.choice(start_positions)
            game_history, result = RL_utils.play_self_play_game(
                model, device, board, MAX_GAME_MOVES, TEMPERATURE
            )
            data.append((game_history, result))
        return data
    finally:
        # Hand the network back in whatever mode the caller had it in.
        model.train(was_training)

In [None]:
def train_on_self_play(model, optimizer, game_histories, device):
    """Train `model` on the examples recorded during self-play.

    Each entry of `game_histories` is a `(history, result)` pair, where
    `history` yields `(board_tensor, legal_mask, move_idx, turn)` tuples. The
    value target of an example is `result` when `turn` is truthy and `-result`
    otherwise. All examples are pooled, reshuffled at the start of each of the
    `EPOCHS` epochs and consumed in slices of `BATCH_SIZE`.

    `model(boards)` must return `(policy_logits, value_preds)`; the loss comes
    from `RL_utils.compute_loss`.

    Differences from the earlier version of this cell:
      * the caller's `game_histories` list is no longer shuffled in place (the
        pooled examples are shuffled every epoch anyway);
      * an empty example pool is reported and skipped instead of failing with
        a NameError when the epoch summary is printed;
      * the epoch summary shows the mean over the epoch's batches rather than
        the values of the last batch only.

    Args:
        model: torch module to optimise; it is put into train mode.
        optimizer: optimizer bound to `model`'s parameters.
        game_histories: iterable of `(history, result)` pairs as described.
        device: device the batches are moved to before the forward pass.

    Returns:
        None.
    """
    model.train()

    # Flatten all examples from all games
    all_examples = []
    for history, result in game_histories:
        for board_tensor, legal_mask, move_idx, turn in history:
            value = result if turn else -result  # flip perspective
            all_examples.append((board_tensor, legal_mask, move_idx, value))

    if not all_examples:
        print("No self-play examples to train on; skipping training.")
        return

    for epoch in range(EPOCHS):
        random.shuffle(all_examples)

        loss_sum = 0.0
        policy_sum = 0.0
        value_sum = 0.0
        num_batches = 0

        for start in range(0, len(all_examples), BATCH_SIZE):
            batch = all_examples[start:start + BATCH_SIZE]

            # Board tensors are concatenated while masks are stacked, exactly
            # as before. NOTE(review): this presumably means each board tensor
            # already carries a leading batch dimension — confirm in RL_utils.
            boards = torch.cat([ex[0] for ex in batch]).to(device)
            masks = torch.stack([ex[1] for ex in batch]).to(device)
            move_targets = torch.tensor([ex[2] for ex in batch], dtype=torch.long).to(device)
            value_targets = torch.tensor([ex[3] for ex in batch], dtype=torch.float).to(device)

            optimizer.zero_grad()
            policy_logits, value_preds = model(boards)
            loss, p_loss, v_loss = RL_utils.compute_loss(
                policy_logits, value_preds, move_targets, value_targets, masks
            )
            loss.backward()
            optimizer.step()

            loss_sum += loss.item()
            policy_sum += p_loss.item()
            value_sum += v_loss.item()
            num_batches += 1

        # The pool is non-empty here, so at least one batch ran this epoch.
        mean_loss = loss_sum / num_batches
        mean_policy = policy_sum / num_batches
        mean_value = value_sum / num_batches
        print(f"Epoch {epoch+1}: total loss={mean_loss:.4f}, policy={mean_policy:.4f}, value={mean_value:.4f}")


In [None]:
import os

# Reload the pretrained network and make sure its weights are on `device`:
# train_on_self_play moves every batch there before the forward pass, so a
# model left on another device would fail with a device mismatch. This is done
# before building the optimizer so it is bound to the moved parameters.
model, device = RL_utils.load_model(model_path)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

NUM_ITERATIONS = 5  # outer rounds of self-play + training

# Create the checkpoint directory up front so a bad save location is reported
# before any self-play time is spent, not after the first full iteration.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)

# Main training loop
for iteration in range(NUM_ITERATIONS):
    print(f"\n=== Iteration {iteration+1} ===")

    # Generate self-play games
    self_play_data = generate_self_play_data(model, device, positions)

    # Train model on self-play games
    train_on_self_play(model, optimizer, self_play_data, device)

    # Save model. The same file is overwritten after every iteration, so an
    # interrupted run still leaves the most recent weights on disk.
    torch.save({
        'model_state_dict': model.state_dict(),
    }, MODEL_SAVE_PATH)