# Connect Four DQN Training Visualization

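This notebook trains a Deep Q-Network (DQN) agent to play Connect Four through self-play and visualizes the learning process. It then evaluates the agent against a random opponent, inspects its Q-values on a sample position, and saves the trained agent.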

In [None]:
import random
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch
from IPython.display import clear_output

# Make the project sources under ../src importable.
sys.path.insert(0, '../src')

from game import ConnectFour, RandomAgent
from agent import DQNAgent

%matplotlib inline
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Initialize the Agent

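Create a DQN agent with the hyperparameters set in the cell below: learning rate, discount factor (`gamma`), epsilon-greedy exploration schedule, replay-buffer size, batch size, and target-network update frequency.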

In [None]:
# Seed every RNG the training stack may draw from (Python `random`, NumPy,
# PyTorch) *before* the agent is built. The aim is that network
# initialisation, epsilon-greedy exploration and the random opponent are
# repeatable across Restart & Run All. Which of these generators the agent /
# RandomAgent actually use is not visible from this notebook. Bit-exact GPU
# runs would additionally need cuDNN determinism settings, which are not
# configured here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

agent = DQNAgent(
    learning_rate=0.001,
    gamma=0.99,
    epsilon_start=1.0,
    epsilon_end=0.01,
    epsilon_decay=0.9998,
    buffer_size=100000,
    batch_size=64,
    target_update_freq=1000
)

print("Agent initialized")
print(f"Device: {agent.device}")
print(f"Initial epsilon: {agent.epsilon}")

## 2. Training Loop with Live Visualization

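Train the agent through self-play: a single network chooses the moves for both players. Every `print_every` episodes, the training loop records the Player 1 / Player 2 win rates, the draw rate, the current epsilon, the average game length, and the average training loss over that window.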

In [None]:
def train_with_stats(agent, num_episodes=10000, print_every=500):
    """Train ``agent`` by self-play and collect periodic statistics.

    The same agent chooses moves for both players. After each game, every
    transition is stored via ``agent.store_transition`` with the game's final
    outcome from the mover's perspective (+1 win, -1 loss, 0 draw), and one
    ``agent.train_step()`` is run per stored transition.

    Args:
        agent: DQN agent exposing ``choose_action``, ``store_transition``,
            ``train_step``, ``epsilon`` and ``episodes_trained``.
        num_episodes: Number of self-play games to play.
        print_every: Window size (in episodes) over which outcome rates,
            average game length and average loss are aggregated, recorded
            and printed. A final, shorter window is also recorded when
            ``num_episodes`` is not a multiple of ``print_every``.

    Returns:
        dict of parallel lists keyed by ``'episode'``, ``'p1_win_rate'``,
        ``'p2_win_rate'``, ``'draw_rate'``, ``'epsilon'``,
        ``'avg_game_length'`` and ``'avg_loss'``. Rates are percentages of
        the games played in each window.
    """
    stats = {
        'episode': [],
        'p1_win_rate': [],
        'p2_win_rate': [],
        'draw_rate': [],
        'epsilon': [],
        'avg_game_length': [],
        'avg_loss': []
    }

    # Per-window counters; reset every time a stats row is recorded.
    p1_wins = p2_wins = draws = total_moves = 0
    losses = []

    for episode in range(1, num_episodes + 1):
        game = ConnectFour()
        episode_transitions = []

        # Self-play: the agent picks the move for whichever player is to move.
        while not game.done:
            state = game.get_state()
            legal_actions = game.get_legal_actions()
            current_player = game.current_player

            action = agent.choose_action(state, legal_actions, training=True)
            # The environment's per-move reward is not used; outcome rewards
            # are assigned below once the final result is known.
            next_state, _, done = game.make_move(action)

            episode_transitions.append({
                'state': state,
                'action': action,
                'player': current_player,
                'next_state': next_state,
                'done': done
            })

        total_moves += game.move_count

        # Tally the outcome exactly once per game. Counting inside the
        # per-transition loop would add one "win" per move the winner made,
        # inflating win rates and shrinking avg_game_length.
        if game.winner is None:
            draws += 1
        elif game.winner == ConnectFour.PLAYER_1:
            p1_wins += 1
        else:
            p2_wins += 1

        # Assign the final outcome to every transition from the mover's
        # perspective, then train.
        # NOTE(review): non-terminal transitions receive the terminal reward
        # while done=False; if the agent bootstraps with r + gamma*max Q(s'),
        # the outcome is counted repeatedly. Confirm this matches the
        # agent's target computation.
        for t in episode_transitions:
            if game.winner is None:
                reward = 0.0
            elif game.winner == t['player']:
                reward = 1.0
            else:
                reward = -1.0

            agent.store_transition(t['state'], t['action'], reward, t['next_state'], t['done'])
            loss = agent.train_step()
            if loss is not None:
                losses.append(loss)

        agent.episodes_trained += 1

        # Record stats at the end of each window (and for the final window).
        if episode % print_every == 0 or episode == num_episodes:
            total_games = p1_wins + p2_wins + draws
            if total_games > 0:
                stats['episode'].append(episode)
                stats['p1_win_rate'].append(p1_wins / total_games * 100)
                stats['p2_win_rate'].append(p2_wins / total_games * 100)
                stats['draw_rate'].append(draws / total_games * 100)
                stats['epsilon'].append(agent.epsilon)
                stats['avg_game_length'].append(total_moves / total_games)
                stats['avg_loss'].append(np.mean(losses) if losses else 0)

                print(f"Episode {episode:6d} | "
                      f"P1: {stats['p1_win_rate'][-1]:5.1f}% | "
                      f"P2: {stats['p2_win_rate'][-1]:5.1f}% | "
                      f"Draw: {stats['draw_rate'][-1]:5.1f}% | "
                      f"Eps: {agent.epsilon:.3f} | "
                      f"Loss: {stats['avg_loss'][-1]:.4f}")

            # Reset counters for the next window
            p1_wins = p2_wins = draws = total_moves = 0
            losses = []

    return stats

In [None]:
# Self-play training run; a progress line is printed every PRINT_EVERY episodes.
NUM_EPISODES, PRINT_EVERY = 20_000, 500

print(f"Training for {NUM_EPISODES} episodes...\n")
stats = train_with_stats(agent, NUM_EPISODES, PRINT_EVERY)
print("\nTraining complete!")

## 3. Training Progress Visualization

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Win/Loss/Draw rates (percent of self-play games in each logging window)
ax1 = axes[0, 0]
ax1.plot(stats['episode'], stats['p1_win_rate'], label='Player 1 (Red)', color='red', alpha=0.8)
ax1.plot(stats['episode'], stats['p2_win_rate'], label='Player 2 (Yellow)', color='gold', alpha=0.8)
ax1.plot(stats['episode'], stats['draw_rate'], label='Draws', color='gray', alpha=0.8)
ax1.set_xlabel('Episode')
ax1.set_ylabel('Rate (%)')
ax1.set_title('Self-Play Outcomes Over Training')
ax1.legend()
ax1.set_ylim(0, 100)

# Epsilon decay
ax2 = axes[0, 1]
ax2.plot(stats['episode'], stats['epsilon'], color='blue', linewidth=2)
ax2.set_xlabel('Episode')
ax2.set_ylabel('Epsilon')
ax2.set_title('Exploration Rate (Epsilon) Decay')
ax2.set_ylim(0, 1)

# Average game length (total moves by both players per game)
ax3 = axes[1, 0]
ax3.plot(stats['episode'], stats['avg_game_length'], color='green', linewidth=2)
ax3.set_xlabel('Episode')
ax3.set_ylabel('Moves')
ax3.set_title('Average Game Length')
# The shortest possible Connect Four game is 7 moves (4 by the first player,
# 3 by the second); the longest fills all 6 x 7 = 42 cells.
ax3.axhline(y=7, color='red', linestyle='--', alpha=0.5, label='Min possible (7 moves)')
ax3.axhline(y=42, color='black', linestyle=':', alpha=0.5, label='Max possible (42 moves)')
ax3.legend()

# Training loss
ax4 = axes[1, 1]
ax4.plot(stats['episode'], stats['avg_loss'], color='purple', linewidth=2)
ax4.set_xlabel('Episode')
ax4.set_ylabel('Loss')
ax4.set_title('Average Training Loss')

plt.tight_layout()
plt.savefig('../training_progress.png', dpi=150)
plt.show()

## 4. Evaluate Against Random Opponent

In [None]:
def evaluate_vs_random(agent, num_games=1000):
    """Pit ``agent`` (with ``training=False``) against a ``RandomAgent``.

    Seats alternate: in even-numbered games the agent is Player 1, in
    odd-numbered games it is Player 2.

    Returns:
        ``{'as_p1': {...}, 'as_p2': {...}}``, where each inner dict maps
        ``'wins'`` / ``'losses'`` / ``'draws'`` to counts from the agent's
        point of view for that seat.
    """
    opponent = RandomAgent()
    results = {
        seat: {'wins': 0, 'losses': 0, 'draws': 0}
        for seat in ('as_p1', 'as_p2')
    }

    for game_idx in range(num_games):
        game = ConnectFour()
        agent_first = game_idx % 2 == 0
        tally = results['as_p1' if agent_first else 'as_p2']

        while not game.done:
            board = game.get_state()
            moves = game.get_legal_actions()
            agents_turn = (game.current_player == ConnectFour.PLAYER_1) == agent_first
            if agents_turn:
                choice = agent.choose_action(board, moves, training=False)
            else:
                choice = opponent.choose_action(board, moves)
            game.make_move(choice)

        if game.winner is None:
            outcome = 'draws'
        elif (game.winner == ConnectFour.PLAYER_1) == agent_first:
            outcome = 'wins'
        else:
            outcome = 'losses'
        tally[outcome] += 1

    return results

# Number of evaluation games; seats alternate, so each seat gets about half.
NUM_EVAL_GAMES = 1000

print(f"Evaluating against random opponent ({NUM_EVAL_GAMES} games)...")
eval_results = evaluate_vs_random(agent, NUM_EVAL_GAMES)

print("\nResults:")
for role in ['as_p1', 'as_p2']:
    r = eval_results[role]
    total = r['wins'] + r['losses'] + r['draws']
    role_name = "As Player 1 (Red)" if role == 'as_p1' else "As Player 2 (Yellow)"
    print(f"\n{role_name}:")
    # Guard against division by zero when a seat played no games
    # (e.g. NUM_EVAL_GAMES < 2).
    if total == 0:
        print("  No games played in this seat.")
        continue
    print(f"  Wins:   {r['wins']/total*100:5.1f}% ({r['wins']}/{total})")
    print(f"  Losses: {r['losses']/total*100:5.1f}% ({r['losses']}/{total})")
    print(f"  Draws:  {r['draws']/total*100:5.1f}% ({r['draws']}/{total})")

## 5. Visualize Agent's Decision Making

In [None]:
def visualize_q_values(agent, game):
    """Plot the current board next to the agent's Q-value for each column.

    Args:
        agent: DQN agent; its ``policy_net`` is queried directly on
            ``agent._preprocess_state(state, game.current_player)``.
        game: ``ConnectFour`` instance whose current position is shown.

    Board cells equal to 1 are drawn red and cells equal to 2 yellow; other
    cells are drawn light grey. (This assumes the board stores player ids
    1/2, matching the "Player 1 (Red)" / "Player 2 (Yellow)" labels used
    elsewhere; TODO confirm against ``ConnectFour``.) Bars for legal columns
    are green and illegal ones red. The best legal column (highest Q-value)
    is marked with a dashed line when at least one legal move exists.
    """
    state = game.get_state()
    legal_actions = game.get_legal_actions()

    # Forward pass only; gradients are not needed for inspection.
    with torch.no_grad():
        state_tensor = agent._preprocess_state(state, game.current_player)
        q_values = agent.policy_net(state_tensor).cpu().numpy()[0]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Board visualization: one RGB pixel per cell, built with boolean masks.
    board = np.asarray(state)
    n_rows, n_cols = board.shape
    board_display = np.full((n_rows, n_cols, 3), 0.9)  # Empty
    board_display[board == 1] = [1, 0, 0]  # Red
    board_display[board == 2] = [1, 1, 0]  # Yellow

    ax1.imshow(board_display)
    ax1.set_xticks(range(n_cols))
    ax1.set_yticks(range(n_rows))
    ax1.set_title(f'Current Board (Player {game.current_player} to move)')
    # Major-tick grid lines would run through cell centres; draw the grid on
    # half-integer minor ticks so the lines sit on the cell borders instead.
    ax1.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax1.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    ax1.grid(False)
    ax1.grid(which='minor', color='black', linewidth=1)
    ax1.tick_params(which='minor', length=0)

    # Q-values bar chart
    n_actions = len(q_values)
    colors = ['green' if i in legal_actions else 'red' for i in range(n_actions)]
    ax2.bar(range(n_actions), q_values, color=colors, alpha=0.7)
    ax2.set_xlabel('Column')
    ax2.set_ylabel('Q-Value')
    ax2.set_title('Q-Values by Column (green=legal, red=illegal)')
    ax2.set_xticks(range(n_actions))

    # Highlight the best legal action. Skip it on a finished/full board,
    # where max() over an empty sequence would raise ValueError.
    if len(legal_actions) > 0:
        best_action = max(legal_actions, key=lambda a: q_values[a])
        ax2.axvline(x=best_action, color='blue', linestyle='--', linewidth=2, label=f'Best: col {best_action}')
        ax2.legend()

    plt.tight_layout()
    plt.show()

# Build a sample position: play the fixed opening below, skipping any move
# that targets a full column, then inspect the agent's Q-values on it.
OPENING_MOVES = [3, 3, 4, 4, 2]

game = ConnectFour()
for col in OPENING_MOVES:
    if col in game.get_legal_actions():
        game.make_move(col)

print("Current board:")
print(game)
print()
visualize_q_values(agent, game)

## 6. Save Trained Agent

In [None]:
# Save the trained agent. The path is defined once so the save call and the
# log message cannot drift apart.
AGENT_PATH = '../trained_agent.pth'
agent.save(AGENT_PATH)
print(f"Agent saved to {AGENT_PATH}")

# Print final stats
final_stats = agent.get_stats()
print("\nFinal Statistics:")
print(f"  Episodes trained: {final_stats['episodes']}")
print(f"  Training steps: {final_stats['training_steps']}")
print(f"  Final epsilon: {final_stats['epsilon']:.4f}")
print(f"  Replay buffer size: {final_stats['buffer_size']}")