# DQN Training Audit - Verify Q-Value Updates

This notebook tests the DQN agent's training mechanism to verify:
1. Q-values increase for winning moves after training
2. Training propagates backward to previous states (credit assignment)
3. The Bellman equation is implemented correctly

## Test Scenario
- Create a win-in-1 position
- Train on the winning move
- Verify the Q-value moves toward the target (reward = 1.0, since the winning move ends the game)
- Train on the previous move that led to win-in-1
- Verify the Q-value moves toward its bootstrapped target, reward + γ · max Q(next state) (the agent should learn to value moves that lead to winning positions)

In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt

# Add project root to path
sys.path.append('..')

from src.environment.connect4 import Connect4Env
from src.agents.dqn_agent import DQNAgent

ImportError: cannot import name 'Connect4Env' from 'src.environment.connect4' (/Users/cmcol/Development/github/connect4-rl/notebooks/../src/environment/connect4.py)

## 1. Setup Environment and Agent

In [None]:
# Environment under test
env = Connect4Env()

# A deliberately small DQN so the audit runs quickly. batch_size=1 together
# with min_buffer_size=1 lets training start from a single stored experience,
# and the raised learning rate makes the effect of one update easy to see.
audit_agent_kwargs = dict(
    name="DQN-Audit",
    player_id=1,
    conv_channels=[16, 32],
    fc_dims=[64],
    learning_rate=1e-3,
    gamma=0.99,
    batch_size=1,
    min_buffer_size=1,
    buffer_size=1000,
    target_update_freq=1000,
    use_double_dqn=True,
)
agent = DQNAgent(**audit_agent_kwargs)

print(f"Agent created: {agent}")
print(f"Device: {agent.device}")

## 2. Create Win-in-1 Test Scenario

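We'll create a position where Player 1 can win by playing column 3. After the setup moves it would be Player -1's turn, so the notebook forces the turn back to Player 1 for this test.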

In [None]:
def create_win_in_1_scenario():
    """
    Reset the shared ``env`` and set up a win-in-1 position for Player 1.

    Board layout (Player 1 = X, Player -1 = O):
    . . . . . . .
    . . . . . . .
    . . . . . . .
    . . . . . . .
    . . . . . . .
    X X X . O O .  <- Player 1 can win at column 3
    0 1 2 3 4 5 6

    The setup moves alternate X, O, X, O, X, so afterwards it would be
    Player -1's turn. The turn is then handed back to Player 1 by assigning
    ``env.current_player`` directly. The resulting position (X to move with
    one more piece than O) cannot arise in normal play; that is intentional
    for this test.

    Returns:
        The module-level ``env``, mutated in place.
    """
    env.reset()

    # Alternating moves starting with Player 1: X0, O4, X1, O5, X2
    for column in (0, 4, 1, 5, 2):
        env.play_move(column)

    # Give the move back to Player 1 so the winning column 3 is playable by it
    env.current_player = 1

    return env

# Build the win-in-1 position and capture the state the agent will act from
env = create_win_in_1_scenario()
state_before_win = env.get_state()

# Show the board so the threat on column 3 is visible
print("Win-in-1 Position (Player 1's turn):")
env.render()
print("\nPlayer 1 can win by playing column 3")
print(f"Legal moves: {env.get_legal_moves()}")

## 3. Test 1: Train on Winning Move

Train the agent on the winning move and verify that its Q-value moves toward the target.

In [None]:
# Get state before winning move
state_before_win = env.get_state()

# Make winning move; this should end the game with a positive reward
winning_action = 3
next_state, reward, done, info = env.step(winning_action)

print("After winning move:")
env.render()
print(f"\nReward: {reward}")
print(f"Done: {done}")
print(f"Winner: {info.get('winner')}")

# Audit one training step on the (state, action, reward, next_state, done)
# transition. The returned dict holds per-action Q-values before/after the
# update, the TD target, TD error and loss.
print("\n" + "="*60)
print("AUDITING TRAINING ON WINNING MOVE")
print("="*60)

audit_results = agent.audit_training_step(
    state=state_before_win,
    action=winning_action,
    reward=reward,
    next_state=next_state,
    done=done
)

# Display results
print(f"\nAction taken: Column {winning_action}")
print(f"Reward: {audit_results['reward']}")
print(f"Done: {audit_results['done']}")
print(f"\nQ-values for all actions BEFORE training:")
for col in range(7):
    marker = " <-- WINNING MOVE" if col == winning_action else ""
    print(f"  Column {col}: {audit_results['q_before'][col]:+.4f}{marker}")

print(f"\nTraining metrics:")
print(f"  Target Q-value: {audit_results['target_q']:+.4f}")
print(f"  Q-value before: {audit_results['q_action_before']:+.4f}")
print(f"  Q-value after:  {audit_results['q_action_after']:+.4f}")
print(f"  Change:         {audit_results['q_change']:+.4f}")
print(f"  TD Error:       {audit_results['td_error']:+.4f}")
print(f"  Loss:           {audit_results['loss']:.6f}")

print(f"\nQ-values for all actions AFTER training:")
for col in range(7):
    marker = " <-- WINNING MOVE" if col == winning_action else ""
    print(f"  Column {col}: {audit_results['q_after'][col]:+.4f}{marker}")

# Verify learning. A gradient step on the TD loss should move Q(s, a) TOWARD
# the target, not necessarily upward: if Q(s, a) already exceeded the target,
# a correct update lowers it. So compare distances to the target instead of
# checking the sign of the change.
q_before_1 = audit_results['q_action_before']
q_after_1 = audit_results['q_action_after']
target_1 = audit_results['target_q']
moved_toward_target_1 = abs(q_after_1 - target_1) < abs(q_before_1 - target_1)

print("\n" + "="*60)
print("VERIFICATION")
print("="*60)

# For a terminal transition the Bellman target reduces to the reward itself
if audit_results['done']:
    if np.isclose(target_1, audit_results['reward'], atol=1e-4):
        print(f"✓ Terminal target equals reward ({audit_results['reward']:.4f})")
    else:
        print(f"✗ Terminal target {target_1:.4f} differs from reward {audit_results['reward']:.4f}")

if moved_toward_target_1:
    print(f"✓ SUCCESS: Q-value moved toward target ({target_1:.4f})")
    print(f"  Q-value moved from {q_before_1:.4f} to {q_after_1:.4f}")
else:
    print("✗ FAILURE: Q-value did not move toward target")
    print("  This indicates a problem with the training implementation")

## 4. Test 2: Credit Assignment - Train on Previous Move

Now test whether value propagates backward: train on the move that led to the win-in-1 position and check that its target reflects the value of the resulting state.

In [None]:
# Create the position BEFORE the win-in-1 (one move earlier)
def create_pre_win_scenario():
    """
    Reset the shared ``env`` to the position one move before the win-in-1.

    Board layout:
    . . . . . . .
    . . . . . . .
    . . . . . . .
    . . . . . . .
    . . . . . . .
    X X . . O O .  <- Player 1 plays column 2 to create win-in-1
    0 1 2 3 4 5 6

    Four alternating moves (X0, O4, X1, O5) are played, so it is naturally
    Player 1's turn afterwards and no turn override is needed.

    Returns:
        The module-level ``env``, mutated in place.
    """
    env.reset()

    for column in (0, 4, 1, 5):
        env.play_move(column)

    # Player 1 is now to move and can play column 2
    return env

# Build the pre-win position and capture the state the setup move is played from
env = create_pre_win_scenario()
state_pre_win = env.get_state()

print("Position BEFORE win-in-1 (Player 1's turn):")
env.render()
print("\nPlayer 1 should play column 2 to create win-in-1 threat")

# Make the move that creates win-in-1. The resulting board is used as
# next_state in the credit-assignment audit.
setup_action = 2
state_after_setup, _, done_setup, _ = env.step(setup_action)

# The credit-assignment audit trains on this transition with reward 0 and
# done=False, so confirm the setup move really did not end the game.
assert not done_setup, "Setup move unexpectedly ended the game"

# NOTE(review): state_after_setup presumably has Player -1 to move, whereas
# state_before_win (Test 1) had the turn forced to Player 1. If get_state()
# encodes side-to-move, these are different states — confirm the encoding.

print("\nAfter playing column 2:")
env.render()
print("\nNow Player 1 has win-in-1 at column 3 (on next turn)")

In [None]:
# Get Q-values for the setup move BEFORE training
q_before_credit = agent.get_q_values(state_pre_win)

print("Q-values BEFORE training on setup move:")
for col in range(7):
    marker = " <-- SETUP MOVE" if col == setup_action else ""
    print(f"  Column {col}: {q_before_credit[col]:+.4f}{marker}")

# Snapshot Q(next_state, ·) BEFORE the update below. A TD update builds its
# target from the pre-update network, so the expected Bellman target must be
# built from these pre-update values; querying after training would compare
# against weights that the gradient step has already changed.
next_state_q_values = agent.get_q_values(state_after_setup)
max_next_q = np.max(next_state_q_values)

# Now audit training on this setup move.
# The reward is 0 (game continues); the target is bootstrapped from next_state.
print("\n" + "="*60)
print("AUDITING TRAINING ON SETUP MOVE (CREDIT ASSIGNMENT)")
print("="*60)

audit_results_2 = agent.audit_training_step(
    state=state_pre_win,
    action=setup_action,
    reward=0.0,  # Game continues
    next_state=state_after_setup,
    done=False
)

print(f"\nAction taken: Column {setup_action}")
print(f"Reward: {audit_results_2['reward']} (game continues)")
print(f"Done: {audit_results_2['done']}")

print(f"\nTraining metrics:")
print(f"  Target Q-value: {audit_results_2['target_q']:+.4f}")
print(f"  Q-value before: {audit_results_2['q_action_before']:+.4f}")
print(f"  Q-value after:  {audit_results_2['q_action_after']:+.4f}")
print(f"  Change:         {audit_results_2['q_change']:+.4f}")
print(f"  TD Error:       {audit_results_2['td_error']:+.4f}")
print(f"  Loss:           {audit_results_2['loss']:.6f}")

print(f"\nQ-values for all actions AFTER training:")
for col in range(7):
    marker = " <-- SETUP MOVE" if col == setup_action else ""
    print(f"  Column {col}: {audit_results_2['q_after'][col]:+.4f}{marker}")

# Verify credit assignment
print("\n" + "="*60)
print("CREDIT ASSIGNMENT VERIFICATION")
print("="*60)

# Expected target: reward + gamma * max_a Q(next_state, a), using the
# pre-update Q-values captured above.
# NOTE(review): the agent is built with use_double_dqn=True and
# target_update_freq=1000, so it may evaluate next_state with a separate target
# network that does not yet reflect Test 1. In that case a mismatch here is
# expected rather than a bug — confirm against audit_training_step.
expected_target_2 = audit_results_2['reward'] + agent.gamma * max_next_q

print(f"\nBellman equation breakdown:")
print(f"  Reward: {audit_results_2['reward']:.4f}")
print(f"  Gamma: {agent.gamma:.4f}")
print(f"  Max Q(next_state): {max_next_q:+.4f}")
print(f"  Expected target: {audit_results_2['reward']} + {agent.gamma} * {max_next_q:.4f} = {expected_target_2:.4f}")
print(f"  Actual target:   {audit_results_2['target_q']:+.4f}")

# A correct TD update moves Q(s, a) toward the bootstrapped target, which may
# lie below the current estimate; compare distances rather than the sign.
q_before_2 = audit_results_2['q_action_before']
q_after_2 = audit_results_2['q_action_after']
target_2 = audit_results_2['target_q']
moved_toward_target_2 = abs(q_after_2 - target_2) < abs(q_before_2 - target_2)

if moved_toward_target_2:
    print("\n✓ SUCCESS: Credit assignment working!")
    print(f"  Q-value for setup move moved from {q_before_2:.4f} to {q_after_2:.4f} (target {target_2:.4f})")
    print("  The setup move's value is tracking the value of the position it leads to")
else:
    print("\n✗ FAILURE: Credit assignment not working")
    print("  Q-value for the setup move did not move toward its target")

## 5. Visualize Q-Value Changes

In [None]:
def _plot_audit(ax, results, action, title, highlight_color):
    """
    Draw grouped before/after Q-value bars for one audited transition.

    Args:
        ax: Matplotlib axes to draw on.
        results: Dict returned by ``agent.audit_training_step``; the
            'q_before', 'q_after' and 'target_q' entries are used.
        action: Column of the audited action, shaded with ``highlight_color``.
        title: Axes title.
        highlight_color: Matplotlib color used to shade the audited column.
    """
    columns = np.arange(len(results['q_before']))
    width = 0.35

    ax.bar(columns - width/2, results['q_before'], width, label='Before Training', alpha=0.8)
    ax.bar(columns + width/2, results['q_after'], width, label='After Training', alpha=0.8)
    ax.axhline(y=results['target_q'], color='r', linestyle='--',
               label=f"Target ({results['target_q']:.3f})")

    ax.set_xlabel('Column')
    ax.set_ylabel('Q-Value')
    ax.set_title(title)
    ax.set_xticks(columns)
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Highlight the audited column (added after the legend, so not listed in it)
    ax.axvspan(action - 0.5, action + 0.5, alpha=0.2, color=highlight_color)


# Create visualization: one panel per audited transition
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

_plot_audit(axes[0], audit_results, winning_action,
            f'Test 1: Training on Winning Move (Column {winning_action})', 'green')
_plot_audit(axes[1], audit_results_2, setup_action,
            f'Test 2: Training on Setup Move (Column {setup_action})', 'blue')

plt.tight_layout()
plt.show()

print("\nVisualization shows:")
print("  - Green highlight: Winning move (should move toward its target, +1.0 for a win)")
print("  - Blue highlight: Setup move (should move toward reward + gamma * max Q(next state))")
print("  - Red dashed line: Target Q-value from Bellman equation")

## 6. Summary and Conclusions

In [None]:
print("="*60)
print("AUDIT SUMMARY")
print("="*60)

# A correct TD update moves Q(s, a) toward its target; whether that is up or
# down depends on where Q(s, a) started, so compare distances to the target
# rather than checking the sign of q_change.
win_moved_toward_target = (
    abs(audit_results['q_action_after'] - audit_results['target_q'])
    < abs(audit_results['q_action_before'] - audit_results['target_q'])
)
setup_moved_toward_target = (
    abs(audit_results_2['q_action_after'] - audit_results_2['target_q'])
    < abs(audit_results_2['q_action_before'] - audit_results_2['target_q'])
)

print("\n1. WINNING MOVE TEST:")
print(f"   - Q-value change: {audit_results['q_change']:+.4f}")
print(f"   - Target Q-value: {audit_results['target_q']:+.4f}")
print(f"   - TD Error: {audit_results['td_error']:+.4f}")
if win_moved_toward_target:
    print("   ✓ Q-value moved toward its target")
else:
    print("   ✗ Q-value did not move toward its target")

print("\n2. CREDIT ASSIGNMENT TEST:")
print(f"   - Q-value change: {audit_results_2['q_change']:+.4f}")
print(f"   - Target Q-value: {audit_results_2['target_q']:+.4f}")
print(f"   - TD Error: {audit_results_2['td_error']:+.4f}")
if setup_moved_toward_target:
    print("   ✓ Q-value for the setup move moved toward its bootstrapped target")
    print("   ✓ Credit assignment is working")
else:
    print("   ✗ Q-value did not move toward its bootstrapped target")
    print("   ✗ Credit assignment may not be working")

print("\n3. BELLMAN EQUATION VERIFICATION:")
# NOTE(review): this comparison is only meaningful if max_next_q is
# max Q(next_state) measured BEFORE the setup-move update. Also, with
# use_double_dqn=True the agent may bootstrap from a separate target network,
# in which case reward + gamma * max Q_online(next_state) is only approximate.
BELLMAN_TOLERANCE = 0.01
expected_target = audit_results_2['reward'] + agent.gamma * max_next_q
target_error = abs(audit_results_2['target_q'] - expected_target)
print(f"   - Expected target: {expected_target:.4f}")
print(f"   - Actual target:   {audit_results_2['target_q']:.4f}")
print(f"   - Error:           {target_error:.6f}")
if target_error < BELLMAN_TOLERANCE:
    print("   ✓ Bellman equation implemented correctly")
else:
    print("   ✗ Bellman equation may have issues")

print("\n" + "="*60)
if win_moved_toward_target and setup_moved_toward_target:
    print("OVERALL: ✓ DQN training is working correctly!")
    print("The agent learns to value winning moves and moves that lead to wins.")
else:
    print("OVERALL: ✗ DQN training has issues that need investigation")
print("="*60)