In [1]:
import os
os.chdir("../../..")

print(os.getcwd())

/Users/eohjelle/Documents/2025-mcts-playground/mcts-playground


In [2]:
from core import ReplayBuffer
from experiments.connect_four.tensor_mapping import ConnectFourTensorMapping

buffer = ReplayBuffer.from_file("checkpoints/connect_four/ResMLP:5-32-128:fixed-rewards/buffer.pt")

TM = ConnectFourTensorMapping()

In [3]:
# states, targets, extra_data = buffer.states, buffer.targets, buffer.extra_data


# print(f"State shape: {states.shape}")
# print(f"Target keys: {targets.keys()}")
# print(f"Extra data keys: {extra_data.keys()}")

# tuples = list(zip(states, targets['policy'], targets['value'], extra_data['legal_actions']))

# start = 17940
# for i, (state, policy, value, legal_actions) in enumerate(tuples[start:17946]):
#     print(f"Index: {i + start}")
#     print(f"State: {state}")
#     print(f"Policy: {policy}")
#     print(f"Value: {value}")
#     print(f"Legal actions: {legal_actions}")
#     print("-" * 100)

In [4]:
# import torch


# policy_nans, value_nans, policy_invalids, legal_actions_nans = [], [], [], []

# # Find and examine first few NaN policies
# for i, (state, policy, value, legal_actions) in enumerate(tuples[:17946]):
#     show = False
#     if torch.isnan(policy).any():
#         print(f"Found NaN policy at index {i}")
#         policy_nans.append(i)
#         show = True
#     if not torch.isclose(policy.sum(), torch.tensor(1.0)):
#         print(f"Found invalid policy at index {i}") 
#         policy_invalids.append(i)
#         show = True
#     if torch.isnan(value):
#         print(f"Found NaN value at index {i}")
#         value_nans.append(i)
#         show = True
#     if torch.isnan(legal_actions).any():
#         print(f"Found NaN legal actions at index {i}")
#         legal_actions_nans.append(i)
#         show = True
#     if show:
#         print(f"Index {i}:")
#         print(f"State shape: {state.shape}")
#         print(f"Policy: {policy}")
#         print(f"Value: {value}")
#         print(f"Legal actions: {legal_actions}")
#         print("=" * 80)

# print(f"Policy nans: {len(policy_nans)}")
# print(f"Value nans: {len(value_nans)}")
# print(f"Policy invalids: {len(policy_invalids)}")
# print(f"Legal actions nans: {len(legal_actions_nans)}")

In [5]:
# import torch


# policy_zeros, value_nans, policy_invalids, legal_actions_nans = [], [], [], []

# # Find and examine first few NaN policies
# for i, (state, policy, value, legal_actions) in enumerate(tuples[:17946]):
#     show = False
#     for j in range(policy.shape[0]):
#         if torch.isclose(policy[j], torch.tensor(0.0)):
#             # print(f"Found zero value at index {i}, column {j}")
#             policy_zeros.append(i)
#             show = True
#             break
#     if torch.isnan(value):
#         # print(f"Found NaN value at index {i}")
#         value_nans.append(i)
#         show = True
#     # if show:
#     #     print(f"Index {i}:")
#     #     print(f"Policy: {policy}")
#     #     print(f"Value: {value}")
#     #     print("=" * 80)

# print(f"Policy zeros: {len(policy_zeros)}")
# print(f"Value nans: {len(value_nans)}")

In [6]:
import torch
import pyspiel
import logging
import numpy as np
from typing import Dict, List

# Import your modules
from core.games.open_spiel_state_wrapper import OpenSpielState
from core.algorithms.AlphaZero import AlphaZero, AlphaZeroConfig
from core.algorithms.MCTS import MCTS, MCTSConfig
from core.model_interface import Model, ModelPredictor
from core.simulation import generate_trajectories
from experiments.connect_four.tensor_mapping import ConnectFourTensorMapping
from experiments.connect_four.models.resmlp import ResMLP, ResMLPInitParams
from core.algorithms.AlphaZero import AlphaZeroTrainingAdapter

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Create initial state function
def create_initial_state():
    game = pyspiel.load_game("connect_four")
    return OpenSpielState(game.new_initial_state(), num_players=2)

# Initialize model (adjust parameters as needed)
model_params: ResMLPInitParams = ResMLPInitParams(
    input_dim=84,  # 2 * 6 * 7 = 84 for Connect Four
    num_residual_blocks=2,
    residual_dim=32,
    hidden_size=128,
    policy_head_dim=7
)

model = Model(
    model_architecture=ResMLP,
    init_params=model_params,
    device=device
)

# Create tensor mapping and training adapter
tensor_mapping = ConnectFourTensorMapping()
training_adapter = AlphaZeroTrainingAdapter()

# Create model predictor
model_predictor = ModelPredictor(model, tensor_mapping)

# Create AlphaZero agent
alphazero_config = AlphaZeroConfig(
    num_simulations=50,  # Small number for testing
)

print("=== STEP 1: Generate Trajectories ===")
try:
    trajectories = generate_trajectories(
        initial_state_creator=create_initial_state,
        trajectory_player_creators=[
            lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config),
            lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config)
        ],
        opponent_creators=[],
        num_games=3,  # Generate a few games
        logger=logging.getLogger(__name__)
    )
    print(f"Generated {len(trajectories)} trajectories")
    
    # Check trajectory outcomes for NaN
    for i, trajectory in enumerate(trajectories):
        final_reward = trajectory[-1].reward
        print(f"Trajectory {i} final reward: {final_reward}")
        if np.isnan(final_reward):
            print(f"❌ NaN found in trajectory {i} final reward!")
        else:
            print(f"✅ Trajectory {i} final reward is valid")
            
except Exception as e:
    print(f"❌ Error generating trajectories: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 2: Extract Examples ===")
examples = []
try:
    for i, trajectory in enumerate(trajectories):
        traj_examples = training_adapter.extract_examples(trajectory)
        examples.extend(traj_examples)
        print(f"Extracted {len(traj_examples)} examples from trajectory {i}")
        
        # Check each example for NaN values
        for j, example in enumerate(traj_examples):
            policy_target, value_target = example.target
            
            # Check policy target
            policy_values = list(policy_target.values())
            if any(np.isnan(val) for val in policy_values):
                print(f"❌ NaN found in policy target at trajectory {i}, example {j}")
                print(f"Policy: {policy_target}")
            
            # Check value target
            if np.isnan(value_target):
                print(f"❌ NaN found in value target at trajectory {i}, example {j}: {value_target}")
            
    print(f"✅ Total examples extracted: {len(examples)}")
    
except Exception as e:
    print(f"❌ Error extracting examples: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 3: Encode Examples with Tensor Mapping ===")
try:
    states, targets, extra_data = tensor_mapping.encode_examples(examples, device)
    
    print(f"Encoded states shape: {states.shape}")
    print(f"Target keys: {targets.keys()}")
    
    # Check for NaN in encoded data
    if torch.isnan(states).any():
        print("❌ NaN found in encoded states!")
        nan_indices = torch.isnan(states).nonzero()
        print(f"NaN locations in states: {nan_indices[:10]}...")  # Show first 10
    else:
        print("✅ Encoded states are valid")
    
    # Check policy targets
    policy_targets = targets['policy']
    if torch.isnan(policy_targets).any():
        print("❌ NaN found in policy targets!")
        nan_indices = torch.isnan(policy_targets).nonzero()
        print(f"NaN locations in policy targets: {nan_indices[:10]}...")
        
        # Print some examples of NaN policies
        nan_rows = torch.isnan(policy_targets).any(dim=1).nonzero().flatten()
        print(f"Rows with NaN policies: {nan_rows[:5]}")
        for row in nan_rows[:3]:
            print(f"Row {row}: {policy_targets[row]}")
    else:
        print("✅ Policy targets are valid")
        
    # Check value targets
    value_targets = targets['value']
    if torch.isnan(value_targets).any():
        print("❌ NaN found in value targets!")
        nan_indices = torch.isnan(value_targets).nonzero()
        print(f"NaN locations in value targets: {nan_indices[:10]}...")
        print(f"NaN values: {value_targets[torch.isnan(value_targets)][:10]}")
    else:
        print("✅ Value targets are valid")
        
    # Check extra data
    legal_actions = extra_data['legal_actions']
    print(f"Legal actions shape: {legal_actions.shape}")
    if torch.isnan(legal_actions.float()).any():
        print("❌ NaN found in legal actions!")
    else:
        print("✅ Legal actions are valid")
        
except Exception as e:
    print(f"❌ Error encoding examples: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 4: Model Forward Pass ===")
try:
    model.model.eval()
    with torch.no_grad():
        # Take a small batch to test
        batch_size = min(8, states.shape[0])
        batch_states = states[:batch_size]
        batch_targets = {k: v[:batch_size] for k, v in targets.items()}
        batch_extra_data = {k: v[:batch_size] for k, v in extra_data.items()}
        
        model_outputs = model.model(batch_states)
        
        print(f"Model output keys: {model_outputs.keys()}")
        
        # Check model outputs for NaN
        for key, tensor in model_outputs.items():
            if torch.isnan(tensor).any():
                print(f"❌ NaN found in model output '{key}'!")
                print(f"Shape: {tensor.shape}")
                print(f"NaN count: {torch.isnan(tensor).sum().item()}")
                nan_indices = torch.isnan(tensor).nonzero()
                print(f"First few NaN locations: {nan_indices[:5]}")
                
                # Show the actual values
                print(f"Sample values: {tensor.flatten()[:10]}")
            else:
                print(f"✅ Model output '{key}' is valid")
                print(f"  Shape: {tensor.shape}, Range: [{tensor.min().item():.4f}, {tensor.max().item():.4f}]")
        
except Exception as e:
    print(f"❌ Error in model forward pass: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 5: Compute Loss ===")
try:
    # Put model in training mode for loss computation
    model.model.train()
    
    # Compute loss using the training adapter
    loss, metrics = training_adapter.compute_loss(model_outputs, batch_targets, batch_extra_data)
    
    print(f"Total loss: {loss.item()}")
    if metrics:
        for metric_name, metric_value in metrics.items():
            print(f"{metric_name}: {metric_value}")
    
    # Check if loss is NaN
    if torch.isnan(loss):
        print("❌ Loss is NaN!")
        
        # Let's manually compute the loss components to find where NaN comes from
        print("\n--- Manual Loss Debugging ---")
        
        policy_logits = model_outputs["policy"]
        value_preds = model_outputs["value"]
        policy_targets = batch_targets['policy']
        value_targets = batch_targets['value']
        legal_actions_mask = batch_extra_data['legal_actions']
        
        print(f"Policy logits stats: min={policy_logits.min()}, max={policy_logits.max()}, nan_count={torch.isnan(policy_logits).sum()}")
        print(f"Value preds stats: min={value_preds.min()}, max={value_preds.max()}, nan_count={torch.isnan(value_preds).sum()}")
        print(f"Policy targets stats: min={policy_targets.min()}, max={policy_targets.max()}, nan_count={torch.isnan(policy_targets).sum()}")
        print(f"Value targets stats: min={value_targets.min()}, max={value_targets.max()}, nan_count={torch.isnan(value_targets).sum()}")
        
        # Policy loss computation
        try:
            import torch.nn.functional as F
            
            # Apply legal actions mask
            masked_policy_logits = policy_logits.clone()
            masked_policy_logits[~legal_actions_mask] = float('-inf')
            print(f"After masking: min={masked_policy_logits.min()}, max={masked_policy_logits.max()}")
            
            # Compute log probabilities
            log_probs = F.log_softmax(masked_policy_logits, dim=-1)
            print(f"Log probs: min={log_probs.min()}, max={log_probs.max()}, nan_count={torch.isnan(log_probs).sum()}")
            
            # Compute policy loss
            policy_loss = -torch.sum(policy_targets * log_probs, dim=-1).mean()
            print(f"Policy loss: {policy_loss.item()}, is_nan: {torch.isnan(policy_loss)}")
            
            # Compute value loss
            value_loss = F.mse_loss(value_preds.squeeze(), value_targets)
            print(f"Value loss: {value_loss.item()}, is_nan: {torch.isnan(value_loss)}")
            
        except Exception as e:
            print(f"Error in manual loss computation: {e}")
            import traceback
            traceback.print_exc()
        
    else:
        print("✅ Loss is valid")
        
except Exception as e:
    print(f"❌ Error computing loss: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 6: Detailed State Analysis ===")
# Let's look at some specific states that might be problematic
try:
    print("Analyzing some example states...")
    
    for i in range(min(3, len(examples))):
        example = examples[i]
        state = example.state
        policy_target, value_target = example.target
        
        print(f"\n--- Example {i} ---")
        print(f"Is terminal: {state.is_terminal}")
        print(f"Current player: {state.current_player}")
        print(f"Legal actions: {state.legal_actions}")
        print(f"Rewards: {state.rewards}")
        print(f"Policy target: {policy_target}")
        print(f"Value target: {value_target}")
        
        # Check for issues
        if np.isnan(value_target):
            print("❌ This example has NaN value target!")
        
        # Check if current_player is problematic
        if state.current_player not in [0, 1]:
            print(f"❌ Unusual current_player: {state.current_player}")
            
        # Check if rewards have issues
        for player, reward in state.rewards.items():
            if np.isnan(reward):
                print(f"❌ NaN reward for player {player}: {reward}")
                
except Exception as e:
    print(f"❌ Error in state analysis: {e}")
    import traceback
    traceback.print_exc()

print("\n=== Summary ===")
print("Check the output above for any ❌ markers indicating NaN values.")
print("Pay special attention to:")
print("1. Final trajectory rewards")
print("2. Value targets in examples")
print("3. Model outputs (especially if they produce extreme values)")
print("4. Loss computation components")

Using device: mps
=== STEP 1: Generate Trajectories ===
Generated 6 trajectories
Trajectory 0 final reward: -1.0
✅ Trajectory 0 final reward is valid
Trajectory 1 final reward: 1.0
✅ Trajectory 1 final reward is valid
Trajectory 2 final reward: -1.0
✅ Trajectory 2 final reward is valid
Trajectory 3 final reward: 1.0
✅ Trajectory 3 final reward is valid
Trajectory 4 final reward: -1.0
✅ Trajectory 4 final reward is valid
Trajectory 5 final reward: 1.0
✅ Trajectory 5 final reward is valid

=== STEP 2: Extract Examples ===
Extracted 8 examples from trajectory 0
Extracted 8 examples from trajectory 1
Extracted 15 examples from trajectory 2
Extracted 15 examples from trajectory 3
Extracted 11 examples from trajectory 4
Extracted 11 examples from trajectory 5
✅ Total examples extracted: 68

=== STEP 3: Encode Examples with Tensor Mapping ===
Encoded states shape: torch.Size([68, 84])
Target keys: dict_keys(['policy', 'value'])
✅ Encoded states are valid
✅ Policy targets are valid
✅ Value tar

Traceback (most recent call last):
  File "/var/folders/22/9ys3dbcs5yb0wn1j_xcbb4600000gn/T/ipykernel_39691/3483481202.py", line 289, in <module>
    for player, reward in state.rewards.items():
                          ^^^^^^^^^^^^^^^^^^^
AttributeError: 'function' object has no attribute 'items'


In [7]:
import torch
import numpy as np

print("=== Testing Policy Confidence vs NaN Values ===")

def analyze_policy_confidence(policy_tensor):
    """Analyze how confident/extreme a policy is"""
    probs = policy_tensor.cpu().numpy()
    
    max_prob = np.max(probs)
    min_prob = np.min(probs)
    
    # Calculate entropy (lower = more confident)
    # Avoid log(0) by adding small epsilon
    epsilon = 1e-8
    entropy = -np.sum(probs * np.log(probs + epsilon))
    
    # Count near-zero probabilities
    near_zero_count = np.sum(probs < 0.01)
    
    # Count near-one probabilities  
    near_one_count = np.sum(probs > 0.99)
    
    return {
        'max_prob': max_prob,
        'min_prob': min_prob,
        'entropy': entropy,
        'near_zero_count': near_zero_count,
        'near_one_count': near_one_count,
        'policy': probs
    }

# Collect data about policies with and without NaN values
nan_policies = []
valid_policies = []
nan_indices = []
valid_indices = []

print("Scanning buffer for NaN values and analyzing corresponding policies...")

# Scan through buffer (adjust range based on your buffer size)
for i in range(min(len(buffer), 20000)):  # Scan up to 20k entries
    try:
        state, targets, extra_data = buffer[i]
        policy = targets['policy']
        value = targets['value']
        
        if torch.isnan(value).any():
            # Found NaN value - analyze the policy
            policy_analysis = analyze_policy_confidence(policy)
            nan_policies.append(policy_analysis)
            nan_indices.append(i)
            
            if len(nan_policies) <= 10:  # Print first 10 for inspection
                print(f"\n❌ NaN Value at index {i}:")
                print(f"  Policy: {policy_analysis['policy']}")
                print(f"  Max prob: {policy_analysis['max_prob']:.6f}")
                print(f"  Min prob: {policy_analysis['min_prob']:.6f}")
                print(f"  Entropy: {policy_analysis['entropy']:.6f}")
                print(f"  Near-zero count: {policy_analysis['near_zero_count']}")
                print(f"  Near-one count: {policy_analysis['near_one_count']}")
        else:
            # Valid value - collect some samples for comparison
            if len(valid_policies) < 1000:  # Collect up to 1000 valid samples
                policy_analysis = analyze_policy_confidence(policy)
                valid_policies.append(policy_analysis)
                valid_indices.append(i)
                
    except Exception as e:
        print(f"Error accessing index {i}: {e}")
        break

print(f"\n=== Analysis Results ===")
print(f"Found {len(nan_policies)} entries with NaN values")
print(f"Collected {len(valid_policies)} valid entries for comparison")

if len(nan_policies) > 0 and len(valid_policies) > 0:
    # Calculate statistics
    nan_max_probs = [p['max_prob'] for p in nan_policies]
    nan_entropies = [p['entropy'] for p in nan_policies]
    nan_near_ones = [p['near_one_count'] for p in nan_policies]
    
    valid_max_probs = [p['max_prob'] for p in valid_policies]
    valid_entropies = [p['entropy'] for p in valid_policies]
    valid_near_ones = [p['near_one_count'] for p in valid_policies]
    
    print(f"\n--- Policy Confidence Comparison ---")
    print(f"NaN Policies:")
    print(f"  Average max probability: {np.mean(nan_max_probs):.6f} ± {np.std(nan_max_probs):.6f}")
    print(f"  Average entropy: {np.mean(nan_entropies):.6f} ± {np.std(nan_entropies):.6f}")
    print(f"  Average near-one count: {np.mean(nan_near_ones):.2f}")
    print(f"  Policies with prob=1.0: {sum(1 for p in nan_max_probs if p >= 0.999)}")
    
    print(f"\nValid Policies:")
    print(f"  Average max probability: {np.mean(valid_max_probs):.6f} ± {np.std(valid_max_probs):.6f}")
    print(f"  Average entropy: {np.mean(valid_entropies):.6f} ± {np.std(valid_entropies):.6f}")
    print(f"  Average near-one count: {np.mean(valid_near_ones):.2f}")
    print(f"  Policies with prob=1.0: {sum(1 for p in valid_max_probs if p >= 0.999)}")
    
    # Test hypothesis: Are NaN policies significantly more confident?
    if np.mean(nan_max_probs) > np.mean(valid_max_probs) + 0.1:
        print(f"\n✅ HYPOTHESIS CONFIRMED: NaN policies are significantly more confident!")
    else:
        print(f"\n❌ Hypothesis not confirmed: No significant difference in policy confidence")
    
    # Look for perfect policies (exactly 1.0 somewhere)
    perfect_nan_policies = [p for p in nan_policies if p['max_prob'] >= 0.999]
    perfect_valid_policies = [p for p in valid_policies if p['max_prob'] >= 0.999]
    
    print(f"\n--- Perfect Policies (prob ≥ 0.999) ---")
    print(f"NaN entries with perfect policies: {len(perfect_nan_policies)}/{len(nan_policies)} ({100*len(perfect_nan_policies)/len(nan_policies):.1f}%)")
    print(f"Valid entries with perfect policies: {len(perfect_valid_policies)}/{len(valid_policies)} ({100*len(perfect_valid_policies)/len(valid_policies):.1f}%)")

# Test the numerical issue hypothesis
print(f"\n=== Testing Numerical Issues with Confident Policies ===")

# Create a very confident policy and see if it causes NaN in loss computation
confident_policy = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device=device).unsqueeze(0)
moderate_policy = torch.tensor([0.7, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05], device=device).unsqueeze(0)

dummy_state = torch.randn(1, 84, device=device)
dummy_extra_data = {'legal_actions': torch.ones(1, 7, dtype=torch.bool, device=device)}

# Test with confident policy
try:
    dummy_targets_confident = {'policy': confident_policy, 'value': torch.tensor([1.0], device=device)}
    
    # Simulate model output
    model_output = {
        'policy': torch.randn(1, 7, device=device),  # Random logits
        'value': torch.tensor([0.5], device=device)
    }
    
    loss_confident, metrics_confident = training_adapter.compute_loss(
        model_output, dummy_targets_confident, dummy_extra_data
    )
    
    print(f"Loss with confident policy target: {loss_confident.item()}")
    if torch.isnan(loss_confident):
        print("❌ Confident policy target causes NaN loss!")
    else:
        print("✅ Confident policy target works fine")
        
except Exception as e:
    print(f"❌ Error with confident policy: {e}")

# Test with moderate policy
try:
    dummy_targets_moderate = {'policy': moderate_policy, 'value': torch.tensor([1.0], device=device)}
    
    loss_moderate, metrics_moderate = training_adapter.compute_loss(
        model_output, dummy_targets_moderate, dummy_extra_data
    )
    
    print(f"Loss with moderate policy target: {loss_moderate.item()}")
    
except Exception as e:
    print(f"❌ Error with moderate policy: {e}")

# Test what happens when model outputs extreme logits
print(f"\n--- Testing Extreme Model Outputs ---")
try:
    # Very confident model output
    extreme_model_output = {
        'policy': torch.tensor([[10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]], device=device),
        'value': torch.tensor([0.5], device=device)
    }
    
    loss_extreme, metrics_extreme = training_adapter.compute_loss(
        extreme_model_output, dummy_targets_moderate, dummy_extra_data
    )
    
    print(f"Loss with extreme model output: {loss_extreme.item()}")
    if torch.isnan(loss_extreme):
        print("❌ Extreme model output causes NaN loss!")
    else:
        print("✅ Extreme model output handled correctly")
        
except Exception as e:
    print(f"❌ Error with extreme model output: {e}")

=== Testing Policy Confidence vs NaN Values ===
Scanning buffer for NaN values and analyzing corresponding policies...
Error accessing index 0: 'ReplayBuffer' object is not subscriptable

=== Analysis Results ===
Found 0 entries with NaN values
Collected 0 valid entries for comparison

=== Testing Numerical Issues with Confident Policies ===
Loss with confident policy target: 1.5551072359085083
✅ Confident policy target works fine
Loss with moderate policy target: 0.7357421517372131

--- Testing Extreme Model Outputs ---
Loss with extreme model output: 5.101607799530029
✅ Extreme model output handled correctly


In [8]:
def setup_logging(log_level: str = "INFO"):
    """Setup logging configuration."""
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(__name__)

logger = setup_logging()

In [9]:
import torch
import pyspiel
import logging
import numpy as np
from typing import Dict, List
import itertools

# Import your modules
from core.games.open_spiel_state_wrapper import OpenSpielState
from core.algorithms.AlphaZero import AlphaZero, AlphaZeroConfig, AlphaZeroTrainingAdapter
from core.model_interface import Model, ModelPredictor
from core.simulation import generate_trajectories
from core.data_structures import ReplayBuffer
from experiments.connect_four.tensor_mapping import ConnectFourTensorMapping
from experiments.connect_four.models.resmlp import ResMLP, ResMLPInitParams

print("=== COMPREHENSIVE TRAINING SCRIPT SIMULATION ===")

# Setup - exactly matching train.py
device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def state_factory():
    game = pyspiel.load_game("connect_four")
    return OpenSpielState(game.new_initial_state(), num_players=2)

# EXACT model parameters from train.py
model_params = ResMLPInitParams(
    input_dim=2 * 6 * 7, 
    num_residual_blocks=10, 
    residual_dim=64, 
    hidden_size=256, 
    policy_head_dim=7
)

print(f"Model parameters: {model_params}")

# Initialize model exactly as in train.py
model = Model(
    model_architecture=ResMLP,
    init_params=model_params,
    device=device
)

# EXACT training components from train.py
tensor_mapping = ConnectFourTensorMapping()
training_adapter = AlphaZeroTrainingAdapter()
alphazero_config = AlphaZeroConfig()  # Default config as in train.py
model_predictor = ModelPredictor(model, tensor_mapping)

# EXACT optimizer from train.py
optimizer = torch.optim.Adam(
    model.model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4,
    amsgrad=False
)

# EXACT learning rate scheduler from train.py
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.9,
    patience=1_000,
    cooldown=1_000,
    min_lr=1e-5
)

# EXACT buffer parameters from train.py – NOTE: changed from 256 to 32
learning_batch_size = 32
buffer_max_size = 10 * 32  # 25,600
learning_min_buffer_size = 32  # 7,680
learning_min_new_examples_per_step = 128

# Initialize replay buffer
replay_buffer = ReplayBuffer(buffer_max_size, device=device)

print(f"Training parameters:")
print(f"  Batch size: {learning_batch_size}")
print(f"  Buffer max size: {buffer_max_size}")
print(f"  Buffer min size: {learning_min_buffer_size}")
print(f"  Min new examples per step: {learning_min_new_examples_per_step}")

print("\n=== STEP 1: Generate Initial Training Data ===")
# Generate enough data to reach minimum buffer size
total_examples_needed = learning_min_buffer_size
games_to_generate = total_examples_needed // 20  # Estimate ~20 examples per game

print(f"Generating {games_to_generate} games to reach minimum buffer size...")

all_examples = []
for game_batch in range(games_to_generate):  # Process in batches to avoid memory issues
    try:
        trajectories = generate_trajectories(
            initial_state_creator=state_factory,
            trajectory_player_creators=[
                lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config),
                lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config)
            ],
            opponent_creators=[],
            num_games=2,
            logger=logger
        )
        
        # Extract examples
        batch_examples = []
        for trajectory in trajectories:
            batch_examples.extend(training_adapter.extract_examples(trajectory))
        
        all_examples.extend(batch_examples)
        print(f"Generated batch {game_batch + 1}, total examples: {len(all_examples)}")
        
        # Check for NaN in trajectory rewards
        for i, trajectory in enumerate(trajectories):
            if np.isnan(trajectory[-1].reward):
                print(f"❌ NaN in trajectory {game_batch + i} final reward!")
                
    except Exception as e:
        print(f"❌ Error generating game batch {game_batch//10 + 1}: {e}")
        import traceback
        traceback.print_exc()

print(f"Generated {len(all_examples)} total examples")

# Check examples for NaN before encoding
print("\n=== STEP 2: Check Examples Before Encoding ===")
nan_examples = 0
for i, example in enumerate(all_examples):
    policy_target, value_target = example.target
    if np.isnan(value_target) or any(np.isnan(v) for v in policy_target.values()):
        print(f"❌ NaN in example {i}: value={value_target}, policy_has_nan={any(np.isnan(v) for v in policy_target.values())}")
        nan_examples += 1
        
print(f"Found {nan_examples} examples with NaN out of {len(all_examples)}")

print("\n=== STEP 3: Encode and Fill Buffer ===")
try:
    # Encode examples in batches to match training script behavior
    batch_size = 1000
    for i in range(0, len(all_examples), batch_size):
        batch_examples = all_examples[i:i+batch_size]
        states, targets, extra_data = tensor_mapping.encode_examples(batch_examples, device)
        replay_buffer.add(states, targets, extra_data)
        
        # Check for NaN in encoded data
        if torch.isnan(targets['value']).any():
            nan_indices = torch.isnan(targets['value']).nonzero().flatten()
            print(f"❌ NaN values found in batch {i//batch_size + 1} at indices: {nan_indices[:5]}...")
        
        print(f"Added batch {i//batch_size + 1}, buffer size: {len(replay_buffer)}")
        
        if len(replay_buffer) >= learning_min_buffer_size:
            break
            
    print(f"✅ Buffer filled to {len(replay_buffer)} examples")
    
except Exception as e:
    print(f"❌ Error filling buffer: {e}")
    import traceback
    traceback.print_exc()

print("\n=== STEP 4: Simulate Training Loop ===")
# Simulate the exact training loop from training.py
training_steps = 20  # Enough to see if NaN develops
losses = []

print("Starting training simulation...")
for step in range(training_steps):
    try:
        print(f"\n--- Training Step {step + 1} ---")
        
        # Sample batch exactly as in training script
        states, targets, extra_data = replay_buffer.sample(learning_batch_size)
        
        # Check sampled data for NaN
        if torch.isnan(targets['value']).any():
            nan_count = torch.isnan(targets['value']).sum().item()
            print(f"❌ Sampled batch contains {nan_count} NaN values!")
            
        if torch.isnan(targets['policy']).any():
            nan_count = torch.isnan(targets['policy']).sum().item()
            print(f"❌ Sampled batch contains {nan_count} NaN policies!")
        
        # EXACT training step from training.py
        model.model.train()
        
        # Use autocast exactly as in training script
        with torch.autocast(device_type=device.type):
            model_outputs = model.model(states)
            
            # Check model outputs
            if torch.isnan(model_outputs['policy']).any():
                print(f"❌ Model output policy contains NaN!")
                print(f"Policy stats: min={model_outputs['policy'].min()}, max={model_outputs['policy'].max()}")
                
            if torch.isnan(model_outputs['value']).any():
                print(f"❌ Model output value contains NaN!")
                print(f"Value stats: min={model_outputs['value'].min()}, max={model_outputs['value'].max()}")
            
            loss, metrics = training_adapter.compute_loss(model_outputs, targets, extra_data)
        
        # Check loss
        if torch.isnan(loss):
            print(f"❌ Loss is NaN at step {step + 1}!")
            print("Debugging loss components...")
            
            # Manual loss debugging
            policy_logits = model_outputs["policy"]
            value_preds = model_outputs["value"]
            policy_targets = targets['policy']
            value_targets = targets['value']
            legal_actions_mask = extra_data['legal_actions']
            
            print(f"  Policy logits: min={policy_logits.min()}, max={policy_logits.max()}, nan_count={torch.isnan(policy_logits).sum()}")
            print(f"  Value preds: min={value_preds.min()}, max={value_preds.max()}, nan_count={torch.isnan(value_preds).sum()}")
            print(f"  Policy targets: min={policy_targets.min()}, max={policy_targets.max()}, nan_count={torch.isnan(policy_targets).sum()}")
            print(f"  Value targets: min={value_targets.min()}, max={value_targets.max()}, nan_count={torch.isnan(value_targets).sum()}")
            
            break
        
        losses.append(loss.item())
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Check model parameters for NaN after update
        nan_params = []
        for name, param in model.model.named_parameters():
            if torch.isnan(param).any():
                nan_params.append(name)
                
        if nan_params:
            print(f"❌ NaN in model parameters: {nan_params}")
            break
        
        # Update learning rate (every step to match training script)
        if step > 0:
            lr_scheduler.step(loss.item())
            current_lr = optimizer.param_groups[0]['lr']
            
        print(f"Step {step + 1}: Loss = {loss.item():.6f}")
        if metrics:
            for metric_name, metric_value in metrics.items():
                print(f"  {metric_name}: {metric_value:.6f}")
        
        # Simulate adding new examples (as would happen in multi-actor setup)
        if step % 5 == 0 and step > 0:  # Every 5 steps, add some new examples
            print("Simulating new examples from actors...")
            try:
                new_trajectories = generate_trajectories(
                    initial_state_creator=state_factory,
                    trajectory_player_creators=[
                        lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config),
                        lambda state: training_adapter.create_tree_search(state.clone(), model_predictor, alphazero_config)
                    ],
                    opponent_creators=[],
                    num_games=2,
                    logger=logger
                )
                
                new_examples = []
                for trajectory in new_trajectories:
                    new_examples.extend(training_adapter.extract_examples(trajectory))
                
                # Check new examples for NaN
                new_nan_count = 0
                for example in new_examples:
                    _, value_target = example.target
                    if np.isnan(value_target):
                        new_nan_count += 1
                
                if new_nan_count > 0:
                    print(f"❌ {new_nan_count} new examples contain NaN values!")
                
                # Add to buffer
                states_new, targets_new, extra_data_new = tensor_mapping.encode_examples(new_examples, device)
                replay_buffer.add(states_new, targets_new, extra_data_new)
                
                print(f"Added {len(new_examples)} new examples, buffer size: {len(replay_buffer)}")

                all_examples.extend(new_examples)
                
            except Exception as e:
                print(f"❌ Error generating new examples: {e}")
        
    except Exception as e:
        print(f"❌ Error in training step {step + 1}: {e}")
        import traceback
        traceback.print_exc()
        break

print(f"\n=== STEP 5: Final Buffer Analysis ===")
# Check final buffer state
print("Analyzing final buffer state...")

# Sample from different parts of the buffer to check for corruption
buffer_check_indices = [0, len(replay_buffer)//4, len(replay_buffer)//2, 3*len(replay_buffer)//4, len(replay_buffer)-1]

for idx in buffer_check_indices:
    try:
        # Sample around this index
        sample_states, sample_targets, sample_extra_data = replay_buffer.sample(min(10, len(replay_buffer)))
        
        nan_values = torch.isnan(sample_targets['value']).sum().item()
        nan_policies = torch.isnan(sample_targets['policy']).sum().item()
        
        print(f"Buffer region around index {idx}: {nan_values} NaN values, {nan_policies} NaN policies")
        
    except Exception as e:
        print(f"Error checking buffer region {idx}: {e}")

print(f"\n=== Training Summary ===")
print(f"Completed {len(losses)} training steps")
print(f"Loss progression: {losses[:5]} ... {losses[-5:] if len(losses) > 5 else losses}")
print(f"Final buffer size: {len(replay_buffer)}")

if losses:
    print(f"Loss range: {min(losses):.6f} to {max(losses):.6f}")
    
    # Check for loss explosion
    if any(loss > 100 for loss in losses):
        print("❌ Loss explosion detected!")
    
    if any(np.isnan(loss) for loss in losses):
        print("❌ NaN losses detected!")
else:
    print("❌ No successful training steps completed!")

print("\n=== Device-Specific Checks ===")
if device.type == 'mps':
    print("Running on MPS - checking for known MPS issues...")
    
    # Test autocast specifically on MPS
    try:
        test_tensor = torch.randn(10, 84, device=device)
        with torch.autocast(device_type='mps'):
            output = model.model(test_tensor)
            if torch.isnan(output['policy']).any() or torch.isnan(output['value']).any():
                print("❌ MPS autocast producing NaN outputs!")
            else:
                print("✅ MPS autocast working correctly")
    except Exception as e:
        print(f"❌ MPS autocast error: {e}")

print("\n=== CONCLUSION ===")
print("Check the output above for:")
print("1. ❌ NaN in trajectory rewards (data generation issue)")
print("2. ❌ NaN in model outputs (model instability)")
print("3. ❌ NaN in loss computation (numerical instability)")
print("4. ❌ NaN in model parameters (training instability)")
print("5. ❌ Loss explosion or MPS-specific issues")

=== COMPREHENSIVE TRAINING SCRIPT SIMULATION ===
Using device: mps
Model parameters: {'input_dim': 84, 'num_residual_blocks': 10, 'residual_dim': 64, 'hidden_size': 256, 'policy_head_dim': 7}
Training parameters:
  Batch size: 32
  Buffer max size: 320
  Buffer min size: 32
  Min new examples per step: 128

=== STEP 1: Generate Initial Training Data ===
Generating 1 games to reach minimum buffer size...
Generated batch 1, total examples: 66
Generated 66 total examples

=== STEP 2: Check Examples Before Encoding ===
Found 0 examples with NaN out of 66

=== STEP 3: Encode and Fill Buffer ===
Added batch 1, buffer size: 66
✅ Buffer filled to 66 examples

=== STEP 4: Simulate Training Loop ===
Starting training simulation...

--- Training Step 1 ---
Step 1: Loss = 0.823926
  policy_loss: 0.482449
  value_loss: 0.341477

--- Training Step 2 ---
Step 2: Loss = 1.175985
  policy_loss: 0.440660
  value_loss: 0.735325

--- Training Step 3 ---
Step 3: Loss = 1.230045
  policy_loss: 0.576644
  va

In [10]:
import torch
import numpy as np

print("=== DETAILED INSPECTION OF CORRUPTED EXAMPLES ===")

def inspect_corrupted_examples(all_examples):
    """Inspect examples and show detailed data for corrupted ones."""
    
    corrupted_examples = []
    valid_examples = []
    
    for i, example in enumerate(all_examples):
        policy_dict, value = example.target
        
        # Check for NaN in value
        value_is_nan = np.isnan(value) if isinstance(value, (int, float)) else torch.isnan(value).any()
        
        # Check for NaN in policy
        policy_has_nan = False
        if isinstance(policy_dict, dict):
            policy_has_nan = any(np.isnan(v) if isinstance(v, (int, float)) else torch.isnan(v).any() 
                               for v in policy_dict.values())
        
        if value_is_nan or policy_has_nan:
            corrupted_examples.append((i, example))
        else:
            valid_examples.append((i, example))
    
    print(f"Total examples: {len(all_examples)}")
    print(f"Valid examples: {len(valid_examples)}")
    print(f"Corrupted examples: {len(corrupted_examples)}")
    
    if corrupted_examples:
        print(f"\n=== DETAILED ANALYSIS OF CORRUPTED EXAMPLES ===")
        
        for idx, (example_idx, example) in enumerate(corrupted_examples[:5]):  # Show first 5 corrupted
            print(f"\n--- Corrupted Example #{idx+1} (Index {example_idx}) ---")
            
            # Extract target components
            policy_dict, value = example.target
            
            print(f"State type: {type(example.state)}")
            print(f"State:\n{example.state}")
            
            # Check state properties if it's an OpenSpielState
            if hasattr(example.state, 'current_player'):
                print(f"Current player: {example.state.current_player}")
            if hasattr(example.state, 'is_terminal'):
                print(f"Is terminal: {example.state.is_terminal}")
            if hasattr(example.state, 'rewards'):
                print(f"State rewards: {example.state.rewards}")
            if hasattr(example.state, 'legal_actions'):
                print(f"Legal actions: {example.state.legal_actions}")
            
            # Analyze target
            print(f"\nTarget Analysis:")
            print(f"Policy type: {type(policy_dict)}")
            print(f"Policy: {policy_dict}")
            
            if isinstance(policy_dict, dict):
                policy_sum = sum(policy_dict.values())
                print(f"Policy sum: {policy_sum}")
                nan_actions = [action for action, prob in policy_dict.items() if np.isnan(prob)]
                if nan_actions:
                    print(f"❌ Actions with NaN probabilities: {nan_actions}")
                    
            print(f"Value type: {type(value)}")
            print(f"Value: {value}")
            if np.isnan(value):
                print(f"❌ Value is NaN!")
                
            # Analyze extra_data
            print(f"\nExtra Data:")
            for key, val in example.extra_data.items():
                print(f"  {key}: {val}")
            
            print("-" * 80)
    
    # Show some valid examples for comparison
    if valid_examples and corrupted_examples:
        print(f"\n=== COMPARISON: VALID EXAMPLES ===")
        
        for idx, (example_idx, example) in enumerate(valid_examples[:2]):  # Show first 2 valid
            print(f"\n--- Valid Example #{idx+1} (Index {example_idx}) ---")
            
            policy_dict, value = example.target
            
            print(f"State:\n{example.state}")
            if hasattr(example.state, 'current_player'):
                print(f"Current player: {example.state.current_player}")
            if hasattr(example.state, 'is_terminal'):
                print(f"Is terminal: {example.state.is_terminal}")
            if hasattr(example.state, 'rewards'):
                print(f"State rewards: {example.state.rewards}")
                
            print(f"Policy: {policy_dict}")
            print(f"Policy sum: {sum(policy_dict.values()) if isinstance(policy_dict, dict) else 'N/A'}")
            print(f"Value: {value}")
            print("-" * 40)
    
    return corrupted_examples, valid_examples

# Inspect your examples
corrupted, valid = inspect_corrupted_examples(all_examples)

# Additional analysis
if corrupted:
    print(f"\n=== CORRUPTION PATTERNS ===")
    
    # Check if corruption happens at specific indices
    corrupted_indices = [idx for idx, _ in corrupted]
    print(f"Corrupted example indices: {corrupted_indices[:10]}...")  # Show first 10
    
    # Check if corruption correlates with specific states
    terminal_states = 0
    non_terminal_states = 0
    unknown_states = 0
    
    for _, example in corrupted:
        if hasattr(example.state, 'is_terminal'):
            if example.state.is_terminal:
                terminal_states += 1
            else:
                non_terminal_states += 1
        else:
            unknown_states += 1
    
    print(f"Corrupted examples by state type:")
    print(f"  Terminal states: {terminal_states}")
    print(f"  Non-terminal states: {non_terminal_states}")
    print(f"  Unknown state type: {unknown_states}")
    
    # Check corruption types
    value_nan_count = 0
    policy_nan_count = 0
    both_nan_count = 0
    
    for _, example in corrupted:
        policy_dict, value = example.target
        
        value_is_nan = np.isnan(value)
        policy_has_nan = any(np.isnan(v) for v in policy_dict.values()) if isinstance(policy_dict, dict) else False
        
        if value_is_nan and policy_has_nan:
            both_nan_count += 1
        elif value_is_nan:
            value_nan_count += 1
        elif policy_has_nan:
            policy_nan_count += 1
    
    print(f"Corruption breakdown:")
    print(f"  Value NaN only: {value_nan_count}")
    print(f"  Policy NaN only: {policy_nan_count}")
    print(f"  Both NaN: {both_nan_count}")

=== DETAILED INSPECTION OF CORRUPTED EXAMPLES ===
Total examples: 226
Valid examples: 226
Corrupted examples: 0
