# Connect Four AI via Minimax Bellman Equation

This notebook demonstrates adversarial game solving through the **minimax Bellman equation**:

$$V(s) = \begin{cases}
\text{utility}(s) & \text{if } s \text{ is terminal} \\
\max_a V(\text{apply}(s, a)) & \text{if Max's turn} \\
\min_a V(\text{apply}(s, a)) & \text{if Min's turn}
\end{cases}$$

We use **alpha-beta pruning** to efficiently search the game tree for Connect Four,
a game played on a 7×6 board (7 columns, 6 rows) where the first player to line up four of their own pieces horizontally, vertically, or diagonally wins.

In [None]:
import bellmaneq
from bellmaneq.viz import plot_connect_four, plot_tictactoe
import matplotlib.pyplot as plt
import time

# Render matplotlib figures inline in the notebook output
# (IPython magic command, not plain Python).
%matplotlib inline
# Render figures at 120 dpi.
plt.rcParams['figure.dpi'] = 120

## Warm-up: Tic-Tac-Toe is a Solved Game

With perfect play from both sides, Tic-Tac-Toe always ends in a draw.
Our minimax solver confirms this with exhaustive search (depth=9).

In [None]:
ttt = bellmaneq.TicTacToe()
board = [0] * 9

# Search the empty board with depth 9, enough to cover every remaining move.
value = ttt.minimax(board, 1, 9)
print(f'Tic-Tac-Toe game-theoretic value: {value}')
print('(0.0 = draw with perfect play from both sides)')

# Score each opening square for X (player 1): place X there, then let
# player 2 search the 8 remaining plies. The result is negated to express
# it from X's side.
# NOTE(review): the negation assumes minimax() scores a position from the
# perspective of the player passed in (negamax style) -- confirm against
# bellmaneq; under a fixed Max/Min convention the sign would be flipped.
move_values = [
    -ttt.minimax(board[:square] + [1] + board[square + 1:], 2, 8)
    for square in range(len(board))
]

fig = plot_tictactoe(
    board,
    values=move_values,
    title="Opening Move Values (X's perspective)",
)
plt.show()

## Connect Four: Search Depth vs Strength

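Unlike Tic-Tac-Toe, Connect Four has a vastly larger state space (~4.5 trillion legal positions).
The game has been solved (the first player wins with perfect play), but only with far more computation than a real-time search can afford, so here we search to a limited depth; deeper search generally yields stronger play.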

In [None]:
game = bellmaneq.ConnectFour()
board = bellmaneq.ConnectFour.empty_board()

# Deepest search to benchmark; each extra ply multiplies the work.
MAX_DEPTH = 8

# Measure search time at different depths. time.perf_counter() is a
# monotonic, high-resolution clock meant for interval timing; time.time()
# follows the adjustable wall clock and can jump during a measurement.
for depth in range(1, MAX_DEPTH + 1):
    start = time.perf_counter()
    move = game.best_move(board, 1, depth)
    elapsed = time.perf_counter() - start
    # The position's value is computed by a separate call outside the timed
    # region, so the printed time covers best_move() only.
    value = game.minimax(board, 1, depth)
    print(f'Depth {depth}: best_move={move}, value={value:+.1f}, time={elapsed:.3f}s')

## AI vs AI Match

Let's pit two AIs against each other at different search depths
and visualize the game as it unfolds.

In [None]:
def play_game(depth_p1=6, depth_p2=4):
    """Play a full Connect Four game between two search-based AIs.

    Player 1 moves first. On each turn the side to move picks a column with
    ``ConnectFour.best_move`` searched to that player's own depth.

    Args:
        depth_p1: Search depth used by player 1.
        depth_p2: Search depth used by player 2.

    Returns:
        A ``(history, winner)`` tuple. ``history`` is a list of
        ``(board_after_move, player, column)`` entries in play order.
        ``winner`` is the last value returned by ``check_winner`` (0 means
        no winner was reported), or 0 if no move was played at all.
    """
    game = bellmaneq.ConnectFour()
    board = bellmaneq.ConnectFour.empty_board()
    player = 1
    history = []
    # Initialised up front: if best_move() returns None on the very first
    # turn, the loop exits before check_winner() runs, and returning an
    # unassigned name would raise UnboundLocalError.
    winner = 0

    for _ in range(42):  # 7 x 6 board: at most 42 moves
        depth = depth_p1 if player == 1 else depth_p2
        col = game.best_move(board, player, depth)
        if col is None:  # engine returned no move; stop the game
            break

        board = game.apply_move(board, player, col)
        history.append((board, player, col))

        winner = game.check_winner(board)
        if winner != 0:
            break

        player = 2 if player == 1 else 1

    return history, winner

# Human-readable labels for the winner codes produced by play_game().
result_str = {
    0: 'ongoing',
    1: 'Player 1 (Red) wins',
    2: 'Player 2 (Yellow) wins',
    3: 'Draw',
}

# Depth-6 searcher moves first against a depth-4 searcher.
history, winner = play_game(depth_p1=6, depth_p2=4)
print(f'Game result: {result_str[winner]} after {len(history)} moves')

In [None]:
# Show selected positions from the game: the first move, the quartile
# points, and the final move.
N_ROWS = 6  # Connect Four board height

n_moves = len(history)
if n_moves == 0:
    # Nothing to draw; indexing history[-1] below would otherwise raise.
    print('No moves were played; nothing to replay.')
else:
    key_moves = sorted({0, n_moves // 4, n_moves // 2, 3 * n_moves // 4, n_moves - 1})

    # squeeze=False always returns a 2-D array of Axes, so a single panel
    # needs no special case.
    fig, axes = plt.subplots(
        1, len(key_moves), figsize=(5 * len(key_moves), 5), squeeze=False
    )

    for ax, idx in zip(axes[0], key_moves):
        board_state, player, col = history[idx]

        # Locate the piece just dropped into column `col`. Preferred method:
        # the cell in that column that changed since the previous position,
        # which does not depend on whether row 0 is the top or the bottom.
        row = None
        if idx > 0:
            prev_state = history[idx - 1][0]
            for r in range(N_ROWS):
                if board_state[r][col] != prev_state[r][col]:
                    row = r
                    break
        if row is None:
            # Fallback (first move, or history boards that share storage):
            # first occupied cell scanning from row 0. On the first move the
            # column holds a single piece, so this is unambiguous.
            # NOTE(review): for later moves this picks the newest piece only
            # if row 0 is the top row -- confirm against bellmaneq's layout.
            for r in range(N_ROWS):
                if board_state[r][col] != 0:
                    row = r
                    break

        last = (row, col) if row is not None else None
        plot_connect_four(
            board_state,
            title=f'Move {idx+1} (P{player} col {col})',
            last_move=last,
            ax=ax,
        )

    # Title previously contained mojibake ('â€”'); it is meant to be an em dash.
    plt.suptitle(f'Game Replay: Depth 6 vs Depth 4 — {result_str[winner]}', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

## Depth Advantage Analysis

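How much does additional search depth matter? Let's run a round-robin tournament in which each depth in {2, 4, 6} plays every depth (including itself), both as Player 1 and as Player 2.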

In [None]:
# Round-robin between search depths: d1 plays as P1 (moves first), d2 as P2.
results = {}
depths = [2, 4, 6]

for d1 in depths:
    for d2 in depths:
        _, w = play_game(d1, d2)
        results[(d1, d2)] = w
        # .get() keeps the tournament running if check_winner() returns a
        # code that result_str does not know about.
        status = result_str.get(w, f'unknown result {w!r}')
        print(f'Depth {d1} vs Depth {d2}: {status}')

# Compact cell symbols per winner code; any other code is shown as '?'.
symbols = {1: 'P1', 2: 'P2', 3: 'D'}
labels = [f'd={d}' for d in depths]
# Size every column from the widest label so the grid stays aligned even
# for multi-digit depths.
width = max(len(label) for label in labels + ['P1'])

print('\nTournament matrix (winner; rows = P1 depth, columns = P2 depth):')
print(' ' * (width + 2) + '  '.join(f'{label:^{width}}' for label in labels))
for d1, row_label in zip(depths, labels):
    cells = '  '.join(
        f'{symbols.get(results[(d1, d2)], "?"):^{width}}' for d2 in depths
    )
    print(f'{row_label:<{width}}  {cells}')