Creating Connect-4 AIs using Reinforcement Learning

In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

In [2]:
class Connect4Env:
    """Two-player Connect-4 environment used for self-play training.

    The board is a ``(rows, columns)`` integer array: 0 = empty, 1 = player 1,
    -1 = player -1. Row 0 is the top; a dropped piece settles in the lowest
    empty row of its column. ``current_player`` is the side to move (1 starts).
    """

    def __init__(self):
        self.rows = 6
        self.columns = 7
        self.board = np.zeros((self.rows, self.columns), dtype=int)
        self.current_player = 1
        # (row, col) of the most recent piece placed by step(), or None.
        self.last_move = None

    def reset(self):
        """Clear the board, give the first move to player 1 and return the board."""
        self.board = np.zeros((self.rows, self.columns), dtype=int)
        self.current_player = 1
        self.last_move = None
        return self.board

    def _landing_row(self, col):
        """Return the row a piece dropped into ``col`` would occupy, or None if full."""
        for row in range(self.rows - 1, -1, -1):
            if self.board[row, col] == 0:
                return row
        return None

    def step(self, action):
        """Drop the current player's piece into column ``action``.

        Rewards are from the mover's perspective:
          * 10 and done=True if the move completes four in a row;
          * 0 and done=True if the board is now full (draw);
          * -5 if, after the move, the opponent has an immediate winning drop
            (checked before the blocking bonus so that blocking one threat
            while leaving another is still penalised);
          * +2 if the move occupied a cell where the opponent would have won;
          * 0 otherwise.
        The turn then passes to the other player.

        Returns:
            (board, reward, done) — ``board`` is the live board array, not a copy.
        Raises:
            ValueError: if ``action`` is outside ``[0, columns)`` or the column is full.
        """
        # Without this check a negative index would silently wrap to another column.
        if not 0 <= action < self.columns:
            raise ValueError("Invalid action: Column is out of range.")
        row = self._landing_row(action)
        if row is None:
            raise ValueError("Invalid action: Column is full.")

        self.board[row, action] = self.current_player
        self.last_move = (row, action)
        reward = 0
        done = False

        if self.check_win(self.current_player):
            reward = 10  # Winning reward
            done = True
        elif np.all(self.board != 0):
            reward = 0  # Draw
            done = True
        elif self.allows_opponent_win():
            reward = -5  # Penalty for allowing opponent to win
        elif self.check_blocked_opponent(row, action):
            reward = 2  # Reward for blocking opponent

        # Switch player
        self.current_player *= -1
        return self.board, reward, done

    def get_valid_actions(self):
        """Return the columns whose top cell is still empty."""
        return [col for col in range(self.columns) if self.board[0, col] == 0]

    def check_win(self, player_id):
        """Return True if ``player_id`` has four in a row anywhere on the board."""
        # Horizontal
        for row in range(self.rows):
            for col in range(self.columns - 3):
                if all(self.board[row, col + i] == player_id for i in range(4)):
                    return True
        # Vertical
        for col in range(self.columns):
            for row in range(self.rows - 3):
                if all(self.board[row + i, col] == player_id for i in range(4)):
                    return True
        # Diagonal (top-left to bottom-right)
        for row in range(self.rows - 3):
            for col in range(self.columns - 3):
                if all(self.board[row + i, col + i] == player_id for i in range(4)):
                    return True
        # Diagonal (top-right to bottom-left)
        for row in range(self.rows - 3):
            for col in range(3, self.columns):
                if all(self.board[row + i, col - i] == player_id for i in range(4)):
                    return True
        return False

    def check_blocked_opponent(self, row=None, col=None):
        """Return True if the piece at (row, col) took a cell the opponent needed to win.

        Defaults to the most recent move (``last_move``). The piece is
        temporarily replaced by an opponent piece, the opponent is tested for
        four in a row, and the original piece is restored.
        """
        if row is None or col is None:
            if self.last_move is None:
                return False
            row, col = self.last_move
        mover = self.board[row, col]
        if mover == 0:
            return False
        opponent_id = -mover
        self.board[row, col] = opponent_id
        try:
            return self.check_win(opponent_id)
        finally:
            self.board[row, col] = mover  # Revert the simulated move

    def allows_opponent_win(self):
        """Return True if the opponent of ``current_player`` has a winning drop right now."""
        opponent_id = -self.current_player
        for col in range(self.columns):
            row = self._landing_row(col)
            if row is None:
                continue
            # Temporarily simulate the opponent's move
            self.board[row, col] = opponent_id
            try:
                if self.check_win(opponent_id):
                    return True
            finally:
                self.board[row, col] = 0  # Revert the simulated move
        return False


In [3]:
class Connect4Model(nn.Module):
    """MLP Q-network: flattened 6x7 board (42 values) -> one Q-value per column.

    Two tanh hidden layers of width 128 feed a linear output layer (raw
    Q-values, no activation). The attribute names fc1/fc2/fc3 determine the
    state_dict keys of saved checkpoints, so keep them stable.
    """

    def __init__(self):
        super().__init__()
        board_cells, hidden, n_columns = 6 * 7, 128, 7
        self.fc1 = nn.Linear(board_cells, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, n_columns)

    def forward(self, x):
        h = x
        for hidden_layer in (self.fc1, self.fc2):
            h = torch.tanh(hidden_layer(h))
        return self.fc3(h)

In [4]:
class DQNAgent:
    """Deep Q-learning agent with experience replay and a target network.

    Transitions are stored from the mover's perspective. Because one agent
    plays both sides, the state after a move is the *opponent's* turn. With
    ``negamax=True`` the bootstrap term is therefore subtracted,
    ``r - gamma * max_a' Q(s', a')``, which is the usual target for
    alternating zero-sum self-play. ``negamax=False`` restores the plain
    ``r + gamma * max Q`` target.
    """

    def __init__(self, lr=0.0001, gamma=0.99, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995,
                 negamax=True):
        self.model = Connect4Model()
        self.target_model = Connect4Model()
        self.target_model.load_state_dict(self.model.state_dict())  # Initialize target model
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.negamax = negamax
        self.memory = deque(maxlen=10000)

    @staticmethod
    def _valid_actions(state):
        """Columns whose top cell is empty; ``state`` is a 6x7 board, flat (row-major) or 2-D."""
        top_row = np.asarray(state).reshape(6, 7)[0]
        return [int(col) for col in np.flatnonzero(top_row == 0)]

    def act(self, state, valid_actions=None):
        """Choose a column with an epsilon-greedy policy restricted to legal moves.

        Args:
            state: the current board, flattened (42 values) or shaped (6, 7).
            valid_actions: optional list of legal columns; when omitted it is
                derived from the top row of ``state``.
        Returns:
            The chosen column index.
        Raises:
            ValueError: if no column is playable.
        """
        if valid_actions is None:
            valid_actions = self._valid_actions(state)
        if not valid_actions:
            raise ValueError("No valid actions: the board is full.")
        if random.random() < self.epsilon:
            return random.choice(valid_actions)  # Random valid action
        with torch.no_grad():
            state_t = torch.as_tensor(np.asarray(state, dtype=np.float32).reshape(1, -1))
            q_values = self.model(state_t)[0]
        # Select the best action among valid actions
        return max(valid_actions, key=lambda a: q_values[a].item())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer (oldest entries are evicted)."""
        self.memory.append((state, action, reward, next_state, done))

    def _compute_targets(self, rewards, next_states, dones):
        """Bootstrapped TD targets from the target network, computed without autograd."""
        with torch.no_grad():
            next_q = self.target_model(next_states)
            n_cols = next_q.shape[1]
            # Assumes next_states are row-major flattened boards, so the first
            # n_cols entries are the top row; a non-zero top cell = full column.
            illegal = next_states[:, :n_cols] != 0
            best_next = next_q.masked_fill(illegal, float("-inf")).max(dim=1).values
            # Terminal states contribute no future value; a fully blocked board
            # (all -inf) is treated the same way to avoid inf/nan targets.
            no_future = dones.bool() | torch.isinf(best_next)
            best_next = torch.where(no_future, torch.zeros_like(best_next), best_next)
            sign = -1.0 if self.negamax else 1.0
            return rewards + sign * self.gamma * best_next

    def replay(self, batch_size):
        """Run one gradient step on a random minibatch once the buffer holds ``batch_size`` items."""
        if batch_size <= 0 or len(self.memory) < batch_size:
            return

        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Stack into one ndarray first; building a tensor from a list of arrays is slow.
        states = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.long)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        dones = torch.as_tensor(dones, dtype=torch.float32)

        # Q(s, a) for the actions actually taken; squeeze(1) keeps a batch of 1 one-dimensional.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        targets = self._compute_targets(rewards, next_states, dones)

        loss = F.mse_loss(q_values, targets)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)


In [None]:
# Training the RL agent via self-play
# Seed every RNG the run touches (python `random` drives exploration and replay
# sampling, torch drives network initialisation) so results are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = Connect4Env()
agent = DQNAgent()
episodes = 10000
batch_size = 64

for episode in range(episodes):
    state = env.reset().flatten()
    # NOTE: moves alternate between both players, so this sums the rewards of
    # *both* sides; it is not a measure of either player's performance.
    total_reward = 0
    done = False

    while not done:
        action = agent.act(state)
        next_state, reward, done = env.step(action)
        next_state = next_state.flatten()  # flatten() copies; env.step returns the live board
        agent.remember(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

        # Train the agent
        agent.replay(batch_size)

    agent.decay_epsilon()
    agent.update_target_model()

    if (episode + 1) % 100 == 0:
        print(f"Episode {episode + 1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")


Episode 100/10000, Total Reward: 10, Epsilon: 0.61
Episode 200/10000, Total Reward: 10, Epsilon: 0.37
Episode 300/10000, Total Reward: 10, Epsilon: 0.22
Episode 400/10000, Total Reward: 10, Epsilon: 0.13
Episode 500/10000, Total Reward: 10, Epsilon: 0.08
Episode 600/10000, Total Reward: 10, Epsilon: 0.05
Episode 700/10000, Total Reward: 10, Epsilon: 0.03
Episode 800/10000, Total Reward: 10, Epsilon: 0.02
Episode 900/10000, Total Reward: 10, Epsilon: 0.01
Episode 1000/10000, Total Reward: 10, Epsilon: 0.01
Episode 1100/10000, Total Reward: 10, Epsilon: 0.01
Episode 1200/10000, Total Reward: 0, Epsilon: 0.01
Episode 1300/10000, Total Reward: 10, Epsilon: 0.01
Episode 1400/10000, Total Reward: 10, Epsilon: 0.01
Episode 1500/10000, Total Reward: 10, Epsilon: 0.01
Episode 1600/10000, Total Reward: 10, Epsilon: 0.01
Episode 1700/10000, Total Reward: 10, Epsilon: 0.01
Episode 1800/10000, Total Reward: 10, Epsilon: 0.01
Episode 1900/10000, Total Reward: 10, Epsilon: 0.01
Episode 2000/10000, To

In [None]:
# Save only the online network's weights (a bare state_dict) to 'RL.pth'.
# The target network, optimizer state and epsilon are not included in this file.
torch.save(agent.model.state_dict(), 'RL.pth')

In [None]:
# Save a full training checkpoint: online and target network weights, the
# current exploration rate and the optimizer state, so training can resume.
# Note the model weights are nested under 'model_state_dict' in this file.
data = dict(
    model_state_dict=agent.model.state_dict(),
    target_model_state_dict=agent.target_model.state_dict(),
    epsilon=agent.epsilon,
    optimizer_state_dict=agent.optimizer.state_dict(),
)
torch.save(data, 'RL_model.pth')

In [5]:
# Load the model
def load_model(model_class, filename):
    """Instantiate ``model_class``, load weights from ``filename`` and set eval mode.

    Accepts either a bare state_dict (as saved to 'RL.pth') or a training
    checkpoint dict that stores the weights under 'model_state_dict' (as saved
    to 'RL_model.pth'). Loads onto the CPU with ``weights_only=True`` so no
    arbitrary pickled objects are executed.
    """
    model = model_class()  # Initialize the model
    checkpoint = torch.load(filename, map_location="cpu", weights_only=True)
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        checkpoint = checkpoint["model_state_dict"]
    model.load_state_dict(checkpoint)
    model.eval()  # Set the model to evaluation mode
    print(f"Model loaded from {filename}")
    return model


def _prompt_human_action(env):
    """Ask the human for a column until a playable one is entered."""
    valid_actions = env.get_valid_actions()
    while True:
        try:
            action = int(input(f"Choose a column (0-{env.columns - 1}): "))
        except ValueError:
            print(f"Invalid input. Please enter a number between 0 and {env.columns - 1}.")
            continue
        if action in valid_actions:
            return action
        print(f"Column {action} is not available. Valid columns: {valid_actions}")


def _choose_bot_action(model, env):
    """Greedy bot move: the legal column with the highest predicted Q-value."""
    state = torch.tensor(env.board.flatten(), dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = model(state)
    return max(env.get_valid_actions(), key=lambda a: q_values[0, a].item())


# Function to play against the bot
def play_against_bot(model):
    """Interactive game in which the human is player 1 and ``model`` plays player -1."""
    env = Connect4Env()
    env.reset()

    print("You are Player 1 (1), and the bot is Player -1 (-1).")
    while True:
        print("\nCurrent Board:")
        print(env.board)

        human_turn = env.current_player == 1
        if human_turn:
            action = _prompt_human_action(env)
        else:
            action = _choose_bot_action(model, env)
            print(f"Bot chooses column: {action}")

        _, reward, done = env.step(action)
        if done:
            print("\nFinal Board:")
            print(env.board)
            if reward == 10:
                print("You win!" if human_turn else "The bot wins!")
            elif reward == 0:
                print("It's a draw!")
            break

In [8]:
# Load the trained weights and start an interactive game (moves are read with input()).
# NOTE(review): 'RL_model.pth' is written by the "Save the agent" cell as a checkpoint dict
# (weights under 'model_state_dict'), whereas 'RL.pth' holds the bare state_dict —
# confirm load_model handles the format of the file passed here.
model = load_model(Connect4Model, 'RL_model.pth')  # Load the trained model
play_against_bot(model)

  model.load_state_dict(torch.load(filename))


Model loaded from RL_model.pth
You are Player 1 (1), and the bot is Player -1 (-1).

Current Board:
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]]
Choose a column (0-6): 3

Current Board:
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]]
Bot chooses column: 5

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0 -1  0]]
Choose a column (0-6): 3

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0  0  0]
 [ 0  0  0  1  0 -1  0]]
Bot chooses column: 5

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  1  0 -1  0]
 [ 0  0  0  1  0 -1  0]]
Choose a column (0-6): 3

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  