

# Install dependencies


In [None]:
!pip install torch numpy

import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


# Create the ***Four in a Row*** (Connect Four) Environment

In [2]:
# Number of rows and columns on the Connect Four board.
ROWS = 6
COLS = 7


class ConnectFourEnv:
    """Connect Four ("Four in a Row") environment with RL-style rewards.

    ``board`` is a ``(ROWS, COLS)`` int array where 0 is empty, 1 is the AI
    and 2 is the human. Row 0 is the top row: a dropped piece lands in the
    lowest empty cell of its column, i.e. the highest free row index.
    """

    def __init__(self):
        self.board = np.zeros((ROWS, COLS), dtype=int)  # empty 6x7 grid
        self.done = False   # True once the game has been won or drawn
        self.winner = None  # winning player (AI: 1, Human: 2), None otherwise

    def reset(self):
        """Clear the board for a new game and return the (new) board array."""
        self.board = np.zeros((ROWS, COLS), dtype=int)
        self.done = False
        self.winner = None
        return self.board

    def step(self, col, player):
        """Drop ``player``'s piece into column ``col``.

        Returns ``(board, reward, done)``, where ``board`` is ``self.board``
        itself (not a copy). Rewards:

        * ``1`` / ``-1`` when the move wins the game for player 1 / player 2;
        * ``0`` for a draw (board full) or an ordinary non-terminal move;
        * ``-10`` with ``done=True`` for an invalid move: the game is already
          over, ``col`` is outside ``0..COLS-1``, or the column is full.
          An invalid move does not modify the board, ``self.done`` or
          ``self.winner``.
        """
        # Reject out-of-range columns explicitly: a negative index would
        # otherwise wrap around to a column on the right edge, and
        # col >= COLS would raise IndexError instead of being penalised.
        if self.done or not 0 <= col < COLS or self.board[0][col] != 0:
            return self.board, -10, True  # invalid move = penalty

        # Gravity: the piece lands in the lowest empty cell of the column.
        for row in range(ROWS - 1, -1, -1):
            if self.board[row][col] == 0:
                self.board[row][col] = player
                break

        if self.check_win(player):
            self.done = True
            self.winner = player
            return self.board, 1 if player == 1 else -1, True

        if not (self.board == 0).any():  # no empty cell left -> draw
            self.done = True
            return self.board, 0, True

        return self.board, 0, False  # game continues

    def check_win(self, player):
        """Return True if ``player`` has four pieces in a line anywhere."""
        # Horizontal: every start column that leaves room for four cells.
        for r in range(ROWS):
            for c in range(COLS - 3):
                if all(self.board[r, c + i] == player for i in range(4)):
                    return True

        # Vertical: every start row that leaves room for four cells.
        for r in range(ROWS - 3):
            for c in range(COLS):
                if all(self.board[r + i, c] == player for i in range(4)):
                    return True

        # Both diagonals of every 4x4 window: (r+i, c+i) runs down-right,
        # (r+3-i, c+i) runs up-right.
        for r in range(ROWS - 3):
            for c in range(COLS - 3):
                if all(self.board[r + i, c + i] == player for i in range(4)) or \
                   all(self.board[r + 3 - i, c + i] == player for i in range(4)):
                    return True
        return False


# CNN-Based Deep Q-Network (DQN)

In [3]:
class CNN_DQN(nn.Module):
    """Convolutional Q-network: one Q-value per board column.

    Expects a float tensor of shape ``(batch, 1, ROWS, COLS)`` (one input
    channel, and the 3x3 / padding=1 convolutions keep the spatial size, which
    is what ``fc1``'s input size assumes) and returns ``(batch, COLS)``.
    Layer attribute names are part of the saved ``state_dict`` keys.
    """

    def __init__(self):
        super().__init__()
        # padding=1 with 3x3 kernels and stride 1 preserves ROWS x COLS.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Derived from the board constants instead of a hard-coded 6 * 7.
        self.fc1 = nn.Linear(64 * ROWS * COLS, 128)
        self.fc2 = nn.Linear(128, COLS)

    def forward(self, x):
        """Return the Q-values for every column given a batch of boards."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.flatten(x, start_dim=1)  # (batch, 64 * ROWS * COLS)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# Experience Replay Memory

In [4]:
class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions for experience replay."""

    def __init__(self, capacity=10000):
        # Holds at most `capacity` experiences (10,000 by default). A deque
        # with maxlen evicts the oldest entry in O(1) once full, whereas the
        # previous list.pop(0) was O(n) and let the buffer reach capacity + 1.
        self.memory = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, experience):
        """Store ``experience``, dropping the oldest one if the buffer is full."""
        self.memory.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen uniformly.

        Raises ValueError (from ``random.sample``) if fewer than
        ``batch_size`` experiences are stored.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


# Train the CNN-DQN Model

In [None]:
def _board_to_tensor(board):
    """Convert a ``(ROWS, COLS)`` board array into a ``(1, 1, ROWS, COLS)`` float32 tensor."""
    return torch.tensor(board, dtype=torch.float32).unsqueeze(0).unsqueeze(0)


def _dqn_update(dqn, optimizer, memory, batch_size, gamma):
    """Run one Q-learning gradient step on a random minibatch from ``memory``."""
    batch = memory.sample(batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    states = torch.cat(states)
    actions = torch.tensor(actions).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.cat(next_states)
    dones = torch.tensor(dones, dtype=torch.float32)

    # Q(s, a) of the actions actually taken; squeeze(1) keeps shape (batch,)
    # even for a batch of one.
    q_values = dqn(states).gather(1, actions).squeeze(1)
    # Bootstrapped target: no graph is needed, so skip building one.
    with torch.no_grad():
        next_q_values = dqn(next_states).max(1)[0]
    target_q_values = rewards + gamma * next_q_values * (1 - dones)

    loss = F.mse_loss(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_cnn_dqn(episodes=5000, gamma=0.99, epsilon=1.0, epsilon_decay=0.9995,
                  epsilon_min=0.1, batch_size=64, lr=0.001,
                  memory_capacity=10000, model_path="cnn_dqn.pth"):
    """Train a CNN_DQN agent (player 1) against a uniformly random player 2.

    Uses epsilon-greedy exploration (epsilon decays once per episode, floored
    at ``epsilon_min``) and experience replay. Each stored transition spans
    the agent's move plus the opponent's reply, so ``next_state`` is always a
    position where the agent is to move again. Rewards come from
    ``ConnectFourEnv.step``: +1 agent win, -1 opponent win, 0 draw/ongoing,
    -10 invalid agent move (which ends the episode).

    The trained weights are saved to ``model_path``.
    """
    env = ConnectFourEnv()
    dqn = CNN_DQN()
    optimizer = optim.Adam(dqn.parameters(), lr=lr)
    memory = ReplayMemory(capacity=memory_capacity)

    def open_columns():
        # Reads env.board at call time (reset() replaces the array).
        return [c for c in range(COLS) if env.board[0][c] == 0]

    for episode in range(episodes):
        state = _board_to_tensor(env.reset())
        done = False

        while not done:
            if random.random() < epsilon:
                action = random.choice(open_columns())  # Explore
            else:
                with torch.no_grad():
                    action = torch.argmax(dqn(state)).item()  # Exploit

            _, reward, done = env.step(action, 1)
            if not done:
                # Player 2 must actually move; otherwise the agent only ever
                # plays on a board with no opposing pieces and learns nothing
                # useful for a real game. A non-terminal board always has at
                # least one open column.
                _, opponent_reward, done = env.step(random.choice(open_columns()), 2)
                reward += opponent_reward

            next_state = _board_to_tensor(env.board)
            memory.push((state, action, reward, next_state, done))
            state = next_state

            if len(memory) > batch_size:
                _dqn_update(dqn, optimizer, memory, batch_size, gamma)

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if episode % 500 == 0:
            print(f"Episode {episode}, Epsilon: {epsilon:.2f}")

    torch.save(dqn.state_dict(), model_path)
    print("Training complete, model saved.")


train_cnn_dqn()


In [None]:
_PIECE_SYMBOLS = {0: ".", 1: "X", 2: "O"}


def _print_board(board):
    """Print column indices, then the board top row first (X = AI, O = human)."""
    # Row 0 is the top of the board (pieces fall towards the last row), so it
    # is printed as-is; flipping it would show the board upside down.
    print(" ".join(str(c) for c in range(COLS)))
    for row in board:
        print(" ".join(_PIECE_SYMBOLS[int(cell)] for cell in row))


def _ask_human_move(board):
    """Prompt until the user enters a column that is in range and not full."""
    while True:
        raw = input("Input your piece to column (0-6): ")
        try:
            move = int(raw)
        except ValueError:
            print("Please enter a column number.")
            continue
        # Checked here so a bad entry re-prompts instead of wrapping around
        # (negative index), crashing, or silently forfeiting the turn.
        if 0 <= move < COLS and board[0][move] == 0:
            return move
        print(f"Column {move} is full or out of range, try again.")


def _pick_ai_move(dqn, board):
    """Return the highest-Q column among the columns that are not full."""
    state = torch.tensor(board, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        q_values = dqn(state).squeeze(0)
    # Mask full columns so the AI never picks a move the env would reject.
    full_columns = torch.tensor(board[0] != 0)
    q_values = q_values.masked_fill(full_columns, float("-inf"))
    return int(torch.argmax(q_values).item())


def play_cnn_dqn(model_path="cnn_dqn.pth"):
    """Play one interactive game in the console against the trained agent.

    The human is player 2 (O) and moves first; the AI is player 1 (X).
    """
    env = ConnectFourEnv()
    dqn = CNN_DQN()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    dqn.load_state_dict(torch.load(model_path))
    dqn.eval()

    env.reset()
    print("You are Player 2 (O), AI is Player 1 (X)")

    while not env.done:
        _print_board(env.board)
        env.step(_ask_human_move(env.board), 2)

        if not env.done:
            ai_move = _pick_ai_move(dqn, env.board)
            env.step(ai_move, 1)
            print(f"AI played column {ai_move}")

    _print_board(env.board)
    if env.winner == 1:
        print("AI wins!")
    elif env.winner == 2:
        print("You win!")
    else:
        print("It's a draw!")
    print("Game Over!")


play_cnn_dqn()
