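# REINFORCE

Step-by-step walkthrough of one REINFORCE batch on chess endgame positions: sample episodes with the current policy, compute returns, and take a single policy-gradient step.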

In [1]:
# system
import sys 
sys.path.insert(0, '../../../')

# utils
import numpy as np
import torch  
import logging
from tqdm import tqdm 
from chessrl.utils.load_config import load_config
from chessrl.utils.fen_parsing import parse_fen_cached
from typing import List, Tuple, Dict


import os
config_path = os.path.join('./', 'config.json')
config = load_config(config_path)

# Configure root logging at the level given by config['log_level'].
log_format = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=config['log_level'], format=log_format)
logger = logging.getLogger(__name__)

# chess
from chessrl import Env, SyzygyDefender
from chessrl import chess_py as cp
from chessrl.algorithms.policy_gradient.policy import Policy
from chessrl.utils.move_idx import build_move_mappings

2025-08-28 10:18:11,611 - INFO - Loading config file...
2025-08-28 10:18:11,853 - INFO - Loading config file...


In [2]:
# system
# NOTE(review): the first cell inserts '../../../' on sys.path while this one
# inserts '../../'; presumably only one of them is the intended project root
# relative to this notebook — TODO confirm which and drop the other.
import sys 
sys.path.insert(0, '../../')

# utils
import logging
import os
import argparse
import numpy as np
from chessrl.utils.load_config import load_config
from chessrl.utils.endgame_loader import load_positions, get_stats, get_all_endgames_from_dtz

# Import the optimized REINFORCE
from chessrl.algorithms.policy_gradient.reinforce import REINFORCE

2025-08-28 10:18:11,860 - INFO - Loading config file...


In [3]:
positions, dtz_groups = load_positions(csv_path=config['csv_path'])
endgames = [pos['fen'] for pos in positions]
assert endgames, f"no positions loaded from {config['csv_path']!r}"

# Seed NumPy and torch so the sampled training positions and the policy's
# np.random-based action sampling are reproducible. Assumes `config` is a dict
# (loaded from config.json); without a 'seed' entry the RNGs stay unseeded.
SEED = config.get('seed')
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

# Draw n_episodes starting positions with replacement.
train_endgames = np.random.choice(endgames, size=config['n_episodes'], replace=True).tolist()

In [4]:
# Build the REINFORCE agent with its default settings; the log below shows it
# selecting the compute device (MPS) at construction time.
agent = REINFORCE()

2025-08-28 10:18:12,152 - INFO - Using MPS (Apple Silicon GPU)


### Train

In [14]:
n_episodes = len(train_endgames)
n_batches = (n_episodes + agent.batch_size - 1) // agent.batch_size  # ceil division
all_losses = []
all_dtms = []

# Inspect a single batch. Batch k covers
# train_endgames[k * batch_size : (k + 1) * batch_size].
BATCH_IDX = 5
assert 0 <= BATCH_IDX < n_batches, f"BATCH_IDX must be in [0, {n_batches})"
start_idx = BATCH_IDX * agent.batch_size
end_idx = min(start_idx + agent.batch_size, n_episodes)
batch_endgames = train_endgames[start_idx:end_idx]

# One environment per starting FEN, all sharing the agent's defender.
envs = [Env.from_fen(fen, defender=agent.defender) for fen in batch_endgames]

episodes_data = agent.sample_episodes(envs)

loss, dtms = agent.train_batch(episodes_data)
all_losses.append(loss)
all_dtms.extend(dtms)

In [15]:
# Loss value returned by agent.train_batch for the batch above.
loss

-0.0344996340572834

In [16]:
# DTM values returned by agent.train_batch (an empty list in the saved output).
# NOTE(review): if REINFORCE.train_batch detects mate the same way as the last
# cell of this notebook (`'checkmate' in str(reward)`), this is always empty —
# TODO confirm.
dtms

[]

### Sample episodes step by step

In [6]:
# Two-way lookup between 4-character move strings (from-square + to-square)
# and the integer action indices used to index the policy's logits.
move_to_idx, idx_to_move = build_move_mappings()

In [7]:
def get_legal_move_indices(env, color=None):
    """
    Return the action indices of the legal moves in ``env``'s current position.

    Each legal move is converted to UCI and truncated to its first four
    characters (from-square + to-square), then looked up in ``move_to_idx``.
    Moves whose key is missing from the mapping are skipped. Because the
    promotion suffix is dropped, all promotions of one pawn share a single key;
    the result is de-duplicated (first-occurrence order kept) so such a move is
    not over-weighted when sampling uniformly over list entries.

    Args:
        env: environment exposing ``state().legal_moves(color)``.
        color: side to generate moves for; defaults to ``cp.Color.WHITE``.

    Returns:
        list[int]: unique action indices, in move-generation order.
    """
    if color is None:
        color = cp.Color.WHITE
    uci_keys = (cp.Move.to_uci(move)[:4] for move in env.state().legal_moves(color))
    indices = (move_to_idx[key] for key in uci_keys if key in move_to_idx)
    # dict.fromkeys drops duplicates while preserving insertion order.
    return list(dict.fromkeys(indices))

In [8]:
# Rebuild fresh environments for this walkthrough. The Train cell passed its
# `envs` to agent.sample_episodes, which presumably stepped them (the loop below
# steps envs the same way), so reusing them would start from played-out
# positions under Restart & Run All.
envs = [Env.from_fen(fen, defender=agent.defender) for fen in batch_endgames]

# One empty trajectory buffer per environment.
batch_size = len(envs)
episodes_data = [
    {'states': [], 'actions': [], 'rewards': [], 'action_indices': [], 'done': False}
    for _ in range(batch_size)
]

In [9]:
# Per-step buffers: FEN, legal action indices and env index of each active episode.
batch_fens, batch_legal_moves, active_indices = [], [], []

In [10]:
# One-step preview of the collection phase of the sampling loop below:
# gather the FEN and legal action indices of every episode still running,
# and mark an episode done when it has no indexable legal move or the game is over.
for i, env in enumerate(envs):
    if episodes_data[i]['done']:
        continue

    current_fen = env.to_fen()
    legal_moves = get_legal_move_indices(env)

    if legal_moves and not env.state().is_game_over():
        batch_fens.append(current_fen)
        batch_legal_moves.append(legal_moves)
        active_indices.append(i)
    else:
        episodes_data[i]['done'] = True

In [11]:
# Game state of the last env visited by the loop in the previous cell
# (`env` is that loop's variable, still bound at module level).
env.state()

<Game fen="8/1k6/4R3/8/3K4/8/8/8 w - - 0 1">

In [12]:
for step in range(config['max_steps']):
    # Collect states and legal moves for every episode that is still running.
    batch_fens = []
    batch_legal_moves = []
    active_indices = []

    for i, env in enumerate(envs):
        if not episodes_data[i]['done']:
            current_fen = env.to_fen()
            legal_moves = get_legal_move_indices(env)  # indices of legal actions

            if not legal_moves or env.state().is_game_over():
                episodes_data[i]['done'] = True
                continue

            batch_fens.append(current_fen)
            batch_legal_moves.append(legal_moves)
            active_indices.append(i)

    if not active_indices:
        break  # all episodes are done

    # One batched forward pass over all active episodes. batch_fens is non-empty
    # here because it is filled in lockstep with active_indices.
    fen_tensors = torch.stack([
        parse_fen_cached(fen, agent.fen_cache) for fen in batch_fens
    ]).to(agent.device)

    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks run.
        batch_logits = agent.policy(fen_tensors)

        for batch_idx, env_idx in enumerate(active_indices):
            legal_moves = batch_legal_moves[batch_idx]
            logits = batch_logits[batch_idx]

            # Softmax over the legal entries only. This equals masking illegal
            # logits to -inf and taking a full softmax, without hard-coding the
            # size of the action space.
            legal_probs = torch.softmax(logits[legal_moves], dim=-1).cpu().numpy()
            legal_probs = legal_probs + 1e-8  # for numerical stability
            legal_probs = legal_probs / legal_probs.sum()

            # Sampling from the policy itself provides exploration
            # (alternative to the epsilon-greedy method used in other algorithms).
            selected_legal_idx = np.random.choice(len(legal_moves), p=legal_probs)
            action_idx = legal_moves[selected_legal_idx]

            # Convert to a move and take the step.
            move_str = idx_to_move[action_idx]
            move = cp.Move.from_strings(envs[env_idx].state(),
                                        move_str[:2], move_str[2:4])
            step_result = envs[env_idx].step(move)

            # Store the transition.
            episode = episodes_data[env_idx]
            episode['states'].append(batch_fens[batch_idx])
            episode['actions'].append(move_str)
            episode['rewards'].append(step_result.reward)
            episode['action_indices'].append(action_idx)

            if envs[env_idx].state().is_game_over():
                episode['done'] = True
                if envs[env_idx].state().is_checkmate():
                    # n agent moves with mate on the last one -> 2n - 1.
                    dtm = 2 * (len(episode['states']) - 0.5)
                    # NOTE: this writes into the agent's own history; re-running
                    # this cell appends duplicates.
                    agent.dtm_history.append(dtm)


In [17]:
# Inspect one sampled trajectory: its states, actions, rewards, action indices and done flag.
episodes_data[4]

{'states': ['3R4/8/8/8/1K2k3/8/8/8 w - - 0 1',
  '3R4/8/8/2K2k2/8/8/8/8 w - - 0 1',
  '3R4/8/2K5/4k3/8/8/8/8 w - - 0 1',
  '8/8/2K5/3R4/5k2/8/8/8 w - - 0 1',
  '8/2K5/8/3R4/6k1/8/8/8 w - - 0 1',
  '8/2K5/8/5k2/8/8/3R4/8 w - - 0 1',
  '8/2K5/8/3R4/6k1/8/8/8 w - - 0 1',
  '3K4/8/8/3R4/5k2/8/8/8 w - - 0 1',
  '3K4/8/8/5k2/8/8/8/3R4 w - - 0 1',
  '8/2K5/4k3/8/8/8/8/3R4 w - - 0 1',
  '8/2K5/3R4/4k3/8/8/8/8 w - - 0 1',
  '8/3K4/3R4/5k2/8/8/8/8 w - - 0 1',
  '3K4/8/3R4/8/4k3/8/8/8 w - - 0 1',
  '8/3K4/3R4/5k2/8/8/8/8 w - - 0 1',
  '8/8/2KR4/4k3/8/8/8/8 w - - 0 1',
  '8/8/3R4/2K2k2/8/8/8/8 w - - 0 1',
  '8/8/R7/2K1k3/8/8/8/8 w - - 0 1',
  'R7/8/4k3/2K5/8/8/8/8 w - - 0 1',
  '5R2/8/8/2K1k3/8/8/8/8 w - - 0 1',
  '8/8/8/2K5/4k3/5R2/8/8 w - - 0 1'],
 'actions': ['b4c5',
  'c5c6',
  'd8d5',
  'c6c7',
  'd5d2',
  'd2d5',
  'c7d8',
  'd5d1',
  'd8c7',
  'd1d6',
  'c7d7',
  'd7d8',
  'd8d7',
  'd7c6',
  'c6c5',
  'd6a6',
  'a6a8',
  'a8f8',
  'f8f3',
  'c5b5'],
 'rewards': [-2.0,
  -2.0,
  -2.0,
  -2.

In [18]:
# Keep episodes that ended with the mate reward. Assumes the env pays a terminal
# reward of 100 for checkmate (TODO confirm against the Env reward settings).
# Episodes marked done before any move have an empty reward list; they are
# skipped instead of raising IndexError on rewards[-1].
CHECKMATE_REWARD = 100
positive_episodes = [
    episode for episode in episodes_data
    if episode['rewards'] and episode['rewards'][-1] == CHECKMATE_REWARD
]

In [19]:
# Episodes whose final reward is 100 (empty in the saved output).
positive_episodes

[]

### Train batch

In [20]:
# Per-episode reward sequences, turned into per-step returns by the agent.
all_rewards = [episode['rewards'] for episode in episodes_data]
all_returns = agent.calculate_returns(all_rewards)

In [21]:
# Flattened per-step training samples, filled by the next cell.
batch_states, batch_actions, batch_returns, batch_legal_moves = [], [], [], []

In [22]:
# Flatten the episodes into per-step samples (state, action, return, legal moves).
# Legal moves depend only on the position, and sampled trajectories often
# revisit the same FEN, so each position's legal-move list is computed once
# instead of rebuilding an Env for every visited state.
legal_moves_by_fen = {}
for ep_idx, episode in enumerate(episodes_data):
    episode_returns = all_returns[ep_idx]
    for t, fen in enumerate(episode['states']):
        batch_states.append(fen)
        batch_actions.append(episode['actions'][t])
        batch_returns.append(episode_returns[t])

        if fen not in legal_moves_by_fen:
            temp_env = Env.from_fen(fen, defender=agent.defender)
            legal_moves_by_fen[fen] = get_legal_move_indices(temp_env)
        # Shared list objects; downstream cells only read them.
        batch_legal_moves.append(legal_moves_by_fen[fen])

In [23]:
# Encode every visited FEN (via the agent's cache) and move the batch to the agent's device.
encoded_states = [parse_fen_cached(fen, agent.fen_cache) for fen in batch_states]
state_tensors = torch.stack(encoded_states).to(agent.device)

In [24]:
# Standardize returns to zero mean / unit std; skipped for a single sample,
# whose (unbiased) std would be undefined.
returns_tensor = torch.stack(batch_returns)
if returns_tensor.shape[0] > 1:
    centered = returns_tensor - returns_tensor.mean()
    returns_tensor = centered / (returns_tensor.std() + 1e-8)

In [25]:
# Forward pass with gradients enabled (needed for the loss below). Calling the
# module instance rather than .forward() runs any registered hooks.
logits = agent.policy(state_tensors)

In [26]:
# Raw (unmasked) policy logits for the first state of the batch.
logits[0]

tensor([-0.9931, -1.0070,  0.7270,  ..., -0.6535,  0.3199,  0.4888],
       device='mps:0', grad_fn=<SelectBackward0>)

In [27]:
# Log-probability of each taken action under the policy restricted to legal moves.
# One boolean legality mask is built for the whole batch (instead of a fresh
# hard-coded 4096-wide mask per sample); its width comes from the logits.
legal_mask = torch.zeros_like(logits, dtype=torch.bool)
for row, legal_moves in enumerate(batch_legal_moves):
    legal_mask[row, legal_moves] = True

# Illegal moves get -inf, i.e. probability 0 after the softmax.
masked_logits = logits.masked_fill(~legal_mask, float('-inf'))
log_probs = torch.log_softmax(masked_logits, dim=-1)

# Pick out the log-probability of the action actually taken at each step.
action_indices = torch.tensor(
    [move_to_idx[action_str] for action_str in batch_actions],
    device=logits.device,
)
log_probs_list = list(log_probs.gather(1, action_indices.unsqueeze(1)).squeeze(1))

In [28]:
# Per-step log-probabilities of the taken actions (still attached to the autograd graph).
log_probs_list

[tensor(-2.8636, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.8926, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.7002, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.7134, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.8886, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.9557, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.5587, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.6621, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.9352, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-3.0608, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.6568, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.8332, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-1.8782, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.9035, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.1686, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-2.8562, device='mps:0', grad_fn=<SelectBackward0>),
 tensor(-3.1236, device=

In [29]:
# Stack the per-step log-probabilities into one tensor; torch.stack is
# differentiable, so gradients still reach the policy.
log_probs_tensor = torch.stack(log_probs_list)

In [30]:
# REINFORCE objective: maximize E[log pi(a|s) * G], so minimize its negation.
weighted_log_probs = log_probs_tensor * returns_tensor
loss = -weighted_log_probs.mean()

In [31]:
# Loss value before the optimizer step in the next cell.
loss

tensor(-0.0213, device='mps:0', grad_fn=<NegBackward0>)

In [32]:
# Single policy-gradient update: clear stale gradients, backpropagate the loss,
# clip the global gradient norm to 1.0, then apply the optimizer step.
optimizer = agent.optimizer
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(agent.policy.parameters(), max_norm=1.0)
optimizer.step()

In [36]:
# DTM for each episode that ended in checkmate.
# The previous test `'checkmate' in str(r)` could never match because rewards
# are numbers (e.g. -2.0), so `dtms` was always empty. Mate is detected from the
# terminal reward instead; this assumes the env pays 100 for checkmate
# (TODO confirm against the Env reward settings).
CHECKMATE_REWARD = 100
dtms = []
for episode in episodes_data:
    if episode['rewards'] and episode['rewards'][-1] == CHECKMATE_REWARD:
        # n agent moves with mate on the last one -> 2n - 1 (presumably plies).
        dtm = 2 * (len(episode['states']) - 0.5)
        dtms.append(dtm)