In [1]:
import numpy as np
import torch

class ConnectFourEnv():
    def __init__(self) -> None:
        self.rows = 6
        self.cols = 7
        # 1D board of size 42
        self.board = np.zeros(shape=(self.rows * self.cols,), dtype=int)
        self.computer = 1
        self.opponent = -1
        self.reset()

    def reset(self):
        """Resets the board and handles who moves first."""
        self.board[:] = 0
        self.done = False
        self.winner = None
        
        # Randomly decide who goes first
        self.mover = np.random.choice([self.computer, self.opponent])
        
        # If opponent starts, they make a random move immediately
        if self.mover == self.opponent:
            action = self.random_action()
            self.apply_action(action, self.opponent)
            
        return self.board.copy()

    def available_actions_idx(self):
        """Returns a list of column indices (0-6) that are not full."""
        # Reshape to 2D to easily check the top row (row 0)
        board_2d = self.board.reshape(self.rows, self.cols)
        # If the top row (0) at column c is 0, the column is valid
        return [c for c in range(self.cols) if board_2d[0, c] == 0]

    def random_action(self):
        """Returns a random valid column."""
        possible_cols = self.available_actions_idx()
        if not possible_cols:
            return None # Draw/Full
        return np.random.choice(possible_cols)

    def apply_action(self, col_idx, player):
        """
        Simulates gravity: places the player's piece in the 
        lowest available row in the given column.
        """
        board_2d = self.board.reshape(self.rows, self.cols)
        
        # Find the lowest empty row in this column
        # We scan from bottom (row 5) to top (row 0)
        for r in range(self.rows - 1, -1, -1):
            if board_2d[r, col_idx] == 0:
                board_2d[r, col_idx] = player
                break
        
        # Flatten back to 1D to update self.board
        self.board = board_2d.flatten()

    def check_win(self, player):
        """Checks horizontal, vertical, and diagonal lines for 4 connected."""
        board_2d = self.board.reshape(self.rows, self.cols)

        # 1. Horizontal
        for r in range(self.rows):
            for c in range(self.cols - 3):
                if np.all(board_2d[r, c:c+4] == player):
                    return True

        # 2. Vertical
        for r in range(self.rows - 3):
            for c in range(self.cols):
                if np.all(board_2d[r:r+4, c] == player):
                    return True

        # 3. Diagonal (\)
        for r in range(self.rows - 3):
            for c in range(self.cols - 3):
                if np.all([board_2d[r+i, c+i] == player for i in range(4)]):
                    return True

        # 4. Anti-Diagonal (/)
        for r in range(3, self.rows):
            for c in range(self.cols - 3):
                if np.all([board_2d[r-i, c+i] == player for i in range(4)]):
                    return True

        return False
    
    def step(self, action, opponent_model=None): # <--- Check this argument
        # 1. Check Agent Valid Move
        if action not in self.available_actions_idx():
             return self.board.copy(), -10, True, {"result": "Error"}
        
        # 2. Agent Move
        self.apply_action(action, self.computer)
        if self.check_win(self.computer):
            return self.board.copy(), 1, True, {"result": "Win"}
        if len(self.available_actions_idx()) == 0:
            return self.board.copy(), 0, True, {"result": "Draw"}

        # 3. Opponent Move
        if opponent_model is None:
            # Default: Random
            opp_action = self.random_action()
        else:
            # Advanced: The Clone
            opp_action = self.get_opponent_action(opponent_model) # <--- Make sure this is called
            
        self.apply_action(opp_action, self.opponent)

        if self.check_win(self.opponent):
            return self.board.copy(), -1, True, {"result": "Loss"}
        if len(self.available_actions_idx()) == 0:
            return self.board.copy(), 0, True, {"result": "Draw"}

        return self.board.copy(), 0, False, {}

    def get_opponent_action(self, model):
        board_for_opp = self.board * -1 
        state_t = torch.tensor(board_for_opp, dtype=torch.float32).unsqueeze(0).view(1, 1, 6, 7) # Ensure shape
        
        with torch.no_grad():
            q_vals = model(state_t)
            valid_moves = self.available_actions_idx()
            mask = torch.full_like(q_vals, -float('inf'))
            mask[0, valid_moves] = q_vals[0, valid_moves]
            action = mask.max(1)[1].item()
        return action

    def render(self):
        """Visualizes the board."""
        board_2d = self.board.reshape(self.rows, self.cols)
        symbols = {0: '.', 1: 'X', -1: 'O'}
        print("\nBoard State:")
        for row in board_2d:
            print(" ".join([symbols[x] for x in row]))
        print("-" * 13)
        print("0 1 2 3 4 5 6\n")

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class QNConnectFour(nn.Module):
    def __init__(self, output_dim=7):
        super(QNConnectFour, self).__init__()
        
        # --- Convolutional Block ---
        # We treat the board as an image: 1 channel (the values -1, 0, 1), 6 rows, 7 cols
        
        # Conv1: Expands features. looks for small local patterns
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64) # Normalization helps faster convergence
        
        # Conv2: Goes deeper
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        # Conv3: Refines features
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # --- Fully Connected Block ---
        # Flatten: 128 channels * 6 rows * 7 cols = 5376
        self.fc1 = nn.Linear(128 * 6 * 7, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim) # Output is 7 (one Q-value per column)

    def forward(self, x):
        # 1. Reshape Input
        # The environment gives us a flat vector (Batch, 42).
        # We must reshape it to (Batch, 1, 6, 7) for the CNN.
        x = x.view(-1, 1, 6, 7) 
        
        # 2. Convolutions + Activations
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # 3. Flatten
        x = x.view(x.size(0), -1)
        
        # 4. Dense Layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        # 5. Output (No activation here, raw Q-values)
        actions = self.fc3(x)
        
        return actions

In [3]:
import random
from collections import deque, namedtuple

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 'done'))

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [4]:
import torch.optim as optim
import math

class DQNAgent:
    def __init__(self, input_dim, output_dim):
        self.output_dim = output_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # 1. Initialize Networks
        # Policy Net: The one we train
        self.policy_net = QNConnectFour(output_dim).to(self.device)
        # Target Net: A stable copy to calculate future rewards (stabilizes training)
        self.target_net = QNConnectFour(output_dim).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval() # Set to evaluation mode

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.0001)
        self.memory = ReplayMemory(10000)

        # Hyperparameters
        self.BATCH_SIZE = 64
        self.GAMMA = 0.99  # Discount factor (cares about long term)
        self.EPS_START = 1.0
        self.EPS_END = 0.05
        self.EPS_DECAY = 1000 # How fast exploration decays
        self.steps_done = 0

    def select_action(self, state, valid_moves):
        """
        Epsilon-Greedy strategy with invalid move masking.
        """
        sample = random.random()
        eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * \
            math.exp(-1. * self.steps_done / self.EPS_DECAY)
        self.steps_done += 1

        # EXPLORATION: Pick random valid move
        if sample < eps_threshold:
            return torch.tensor([[random.choice(valid_moves)]], device=self.device, dtype=torch.long)
        
        # EXPLOITATION: Pick best move from Network
        with torch.no_grad():
            # Get Q-values from network
            q_values = self.policy_net(state.to(self.device))
            
            # Mask invalid moves: Set their Q-value to negative infinity so they aren't picked
            # Create a mask of -inf
            mask = torch.full_like(q_values, -float('inf'))
            # Set valid indices to the actual q_values
            mask[0, valid_moves] = q_values[0, valid_moves]
            
            # Return index of max value
            return mask.max(1)[1].view(1, 1)

    def optimize_model(self):
        if len(self.memory) < self.BATCH_SIZE:
            return

        transitions = self.memory.sample(self.BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        # Convert batch data to tensors
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        done_batch = torch.cat(batch.done).to(self.device)

        # 1. Compute Q(s_t, a) - The Q-values we estimated
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # 2. Compute V(s_{t+1}) for all next states using Target Net
        next_state_values = self.target_net(next_state_batch).max(1)[0].detach()
        
        # 3. Compute the expected Q values (Bellman Equation)
        # If done, expected_q is just reward. If not, reward + gamma * best_future_q
        expected_state_action_values = reward_batch + (self.GAMMA * next_state_values * (1 - done_batch))

        # 4. Compute Huber Loss (Smooth L1)
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # 5. Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        
        # Clip gradients to prevent exploding gradients (common in RL)
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

        return loss

In [7]:
import copy
import math
import torch

def training(env, agent, num_episodes=1000, target_update_freq=10):

    # Initialize
    opponent_net = None 
    win_history = []
    loss_history = [] # <--- NEW: To store average loss per episode
    win_rate_threshold = 0.85 

    print("Starting Training...")

    for i_episode in range(num_episodes):
        state_np = env.reset()
        state = torch.tensor(state_np, dtype=torch.float32).unsqueeze(0)
        
        total_reward = 0
        done = False
        
        # <--- NEW: Variables to track loss within this specific episode
        episode_loss_sum = 0
        episode_opt_count = 0 
        
        while not done:
            # 1. Select Action
            valid_moves = env.available_actions_idx()
            action_tensor = agent.select_action(state, valid_moves)
            action = action_tensor.item() 
            
            # 2. Step Environment
            next_state_np, reward, done, info = env.step(action, opponent_model=opponent_net)
            
            # 3. Process Reward & Next State
            reward_tensor = torch.tensor([reward], dtype=torch.float32)
            next_state = torch.tensor(next_state_np, dtype=torch.float32).unsqueeze(0)
            done_tensor = torch.tensor([float(done)], dtype=torch.float32)

            # 4. Store in Memory
            agent.memory.push(state, action_tensor, reward_tensor, next_state, done_tensor)

            # 5. Move to next state
            state = next_state
            total_reward += reward

            # 6. Perform one step of optimization
            loss = agent.optimize_model()
            
            # <--- NEW: Accumulate loss
            if loss is not None:
                episode_loss_sum += loss.item() # .item() is crucial to save memory!
                episode_opt_count += 1

        # --- TRACKING LOSS --- # <--- NEW
        if episode_opt_count > 0:
            avg_ep_loss = episode_loss_sum / episode_opt_count
            loss_history.append(avg_ep_loss)
        else:
            loss_history.append(0)

        # --- TRACKING WINS ---
        if info['result'] == 'Win':
            win_history.append(1)
        else:
            win_history.append(0)
            
        # Keep only last 100 games
        if len(win_history) > 100: win_history.pop(0)
        if len(loss_history) > 100: loss_history.pop(0) # Keep loss history same size
            
        # --- THE UPDATE CHECK ---
        if i_episode % 50 == 0 and len(win_history) == 100:
            win_rate = sum(win_history) / 100
            # <--- NEW: Calculate average loss over the last 100 episodes
            avg_loss_stat = sum(loss_history) / len(loss_history)
            
            # Calculate epsilon for display
            curr_eps = agent.EPS_END + (agent.EPS_START - agent.EPS_END) * math.exp(-1. * agent.steps_done / agent.EPS_DECAY)
            
            print(f"Episode {i_episode} | Win Rate: {win_rate:.2f} | Avg Loss: {avg_loss_stat:.6f} | Epsilon: {curr_eps:.4f}")
            
            if win_rate > win_rate_threshold:
                print(f"ðŸš€ PROMOTION! Agent (Win Rate {win_rate:.2f}) is now the Opponent.")
                
                opponent_net = copy.deepcopy(agent.policy_net)
                opponent_net.eval()
                
                win_history = [] 
                loss_history = [] # Optional: Reset loss history too if you want fresh stats
                
                agent.steps_done = int(agent.EPS_DECAY * 2)
    
    return agent

In [8]:
# Initialize Environment and Agent
environment = ConnectFourEnv()
agent = DQNAgent(input_dim=42, output_dim=7)

episodes = 10000
update_freq = 10

model = training(env=environment, agent=agent, num_episodes=episodes, target_update_freq=update_freq)

# Save the policy network (the one that plays the game)
torch.save(model.policy_net.state_dict(), "connect4_dqn.pth")
print("Model saved!")

Starting Training...
Episode 100 | Win Rate: 0.57 | Avg Loss: 0.003333 | Epsilon: 0.4123
Episode 150 | Win Rate: 0.62 | Avg Loss: 0.002755 | Epsilon: 0.2845
Episode 200 | Win Rate: 0.70 | Avg Loss: 0.002514 | Epsilon: 0.2009
Episode 250 | Win Rate: 0.76 | Avg Loss: 0.002468 | Epsilon: 0.1484
Episode 300 | Win Rate: 0.76 | Avg Loss: 0.002277 | Epsilon: 0.1170
Episode 350 | Win Rate: 0.77 | Avg Loss: 0.001937 | Epsilon: 0.0949
Episode 400 | Win Rate: 0.81 | Avg Loss: 0.001880 | Epsilon: 0.0799
Episode 450 | Win Rate: 0.78 | Avg Loss: 0.001874 | Epsilon: 0.0704
Episode 500 | Win Rate: 0.78 | Avg Loss: 0.001919 | Epsilon: 0.0643
Episode 550 | Win Rate: 0.82 | Avg Loss: 0.001963 | Epsilon: 0.0599
Episode 600 | Win Rate: 0.83 | Avg Loss: 0.001720 | Epsilon: 0.0567
Episode 650 | Win Rate: 0.80 | Avg Loss: 0.001698 | Epsilon: 0.0547
Episode 700 | Win Rate: 0.77 | Avg Loss: 0.001849 | Epsilon: 0.0532
Episode 750 | Win Rate: 0.76 | Avg Loss: 0.001733 | Epsilon: 0.0522
Episode 800 | Win Rate: 0.7