In [1]:
from kaggle_environments import evaluate, make, utils

termcolor not installed, skipping dependency


In [22]:
env = make("connectx", debug=True)

In [23]:
env.render(mode="ipython", width=500, height=450)

In [25]:
trainer = env.train([None, "random"])

In [26]:
obs, reward, done, info = trainer.step(0)

In [28]:
env.render(mode="ipython", width=500, height=450)

In [34]:
import random

In [58]:
env.play([None,'negamax'])

In [43]:
config = {'rows': 6, 'columns': 7, 'inarow': 4}


In [51]:
outcomes = evaluate("connectx", [agent_random, 'random'], config, [], 10)

In [74]:
# rule based
def my_agent(observation, configuration):
    import numpy as np
    
    # Convert the board to a 2D numpy array
    def get_board(observation, configuration):
        board = np.array(observation.board).reshape(configuration.rows, configuration.columns)
        return board
    
    # Check if a position is within the board
    def is_valid_position(row, col, rows, cols):
        return 0 <= row < rows and 0 <= col < cols
    
    # Check if placing a piece in a column would result in a win
    def check_winning_move(board, col, player, configuration):
        rows, cols = configuration.rows, configuration.columns
        inarow = configuration.inarow
        
        # Find the row where the piece would land
        for row in range(rows-1, -1, -1):
            if board[row][col] == 0:
                # Temporarily place the piece
                board[row][col] = player
                
                # Check horizontal
                for c in range(max(0, col-inarow+1), min(cols-inarow+1, col+1)):
                    if all(board[row][c+i] == player for i in range(inarow)):
                        board[row][col] = 0  # Reset the position
                        return True
                
                # Check vertical
                if row <= rows-inarow:
                    if all(board[row+i][col] == player for i in range(inarow)):
                        board[row][col] = 0
                        return True
                
                # Check diagonal up-right
                for r, c in zip(range(row-inarow+1, row+1), range(col-inarow+1, col+1)):
                    if (is_valid_position(r, c, rows, cols) and 
                        is_valid_position(r+inarow-1, c+inarow-1, rows, cols)):
                        if all(board[r+i][c+i] == player for i in range(inarow)):
                            board[row][col] = 0
                            return True
                
                # Check diagonal up-left
                for r, c in zip(range(row-inarow+1, row+1), range(col+inarow-1, col-1, -1)):
                    if (is_valid_position(r, c, rows, cols) and 
                        is_valid_position(r+inarow-1, c-inarow+1, rows, cols)):
                        if all(board[r+i][c-i] == player for i in range(inarow)):
                            board[row][col] = 0
                            return True
                
                # Reset the position
                board[row][col] = 0
                break
        return False
    
    # Check if a column is playable
    def is_valid_move(board, col):
        return board[0][col] == 0
    
    # Get current board state
    board = get_board(observation, configuration)
    
    # Get valid moves
    valid_moves = [col for col in range(configuration.columns) if is_valid_move(board, col)]
    if not valid_moves:
        return 3  # Return center column if no valid moves (shouldn't happen)
    
    # Player numbers
    player = observation.mark
    opponent = 1 if player == 2 else 2
    
    # 1. Check for winning move
    for col in valid_moves:
        if check_winning_move(board.copy(), col, player, configuration):
            return col
    
    # 2. Block opponent's winning move
    for col in valid_moves:
        if check_winning_move(board.copy(), col, opponent, configuration):
            return col
    
    # 3. Prefer center and nearby columns
    center_col = configuration.columns // 2
    preferred_cols = sorted(valid_moves, key=lambda x: abs(x - center_col))
    
    return preferred_cols[0]

In [78]:
# optimized negamax
def my_agent(observation, configuration):
    import numpy as np
    from time import time
    
    EMPTY = 0
    MAX_DEPTH = 4  # Adjust based on performance needs
    TIMEOUT = 0.5  # Maximum thinking time in seconds
    
    def get_board(observation, configuration):
        return np.array(observation.board).reshape(configuration.rows, configuration.columns)
    
    def get_valid_moves(board):
        return [col for col in range(len(board[0])) if board[0][col] == EMPTY]
    
    def drop_piece(board, col, player):
        row = len(board) - 1
        while row >= 0 and board[row][col] != EMPTY:
            row -= 1
        board[row][col] = player
        return row
    
    def undo_move(board, col, row):
        board[row][col] = EMPTY
    
    def check_window(window, player, inarow):
        """Score a window of pieces"""
        opponent = 1 if player == 2 else 2
        score = 0
        
        player_count = np.count_nonzero(window == player)
        empty_count = np.count_nonzero(window == EMPTY)
        opponent_count = np.count_nonzero(window == opponent)
        
        # Winning window
        if player_count == inarow:
            score += 100000
        # Near winning windows
        elif player_count == inarow - 1 and empty_count == 1:
            score += 1000
        elif player_count == inarow - 2 and empty_count == 2:
            score += 100
        # Blocking opponent's potential wins
        if opponent_count == inarow - 1 and empty_count == 1:
            score -= 800
            
        return score
    
    def evaluate_position(board, player, configuration):
        """Evaluate the strength of the current position"""
        score = 0
        rows, cols = configuration.rows, configuration.columns
        inarow = configuration.inarow
        
        # Horizontal windows
        for row in range(rows):
            for col in range(cols - inarow + 1):
                window = board[row, col:col + inarow]
                score += check_window(window, player, inarow)
        
        # Vertical windows
        for row in range(rows - inarow + 1):
            for col in range(cols):
                window = board[row:row + inarow, col]
                score += check_window(window, player, inarow)
        
        # Diagonal windows (positive slope)
        for row in range(rows - inarow + 1):
            for col in range(cols - inarow + 1):
                window = [board[row + i][col + i] for i in range(inarow)]
                score += check_window(window, player, inarow)
        
        # Diagonal windows (negative slope)
        for row in range(inarow - 1, rows):
            for col in range(cols - inarow + 1):
                window = [board[row - i][col + i] for i in range(inarow)]
                score += check_window(window, player, inarow)
        
        # Prefer center columns
        center_array = board[:, cols//2]
        center_count = np.count_nonzero(center_array == player)
        score += center_count * 100
        
        return score

    def is_terminal(board, last_move, last_row, configuration):
        """Check if the game is over"""
        if last_move is None:
            return False
            
        rows, cols = configuration.rows, configuration.columns
        inarow = configuration.inarow
        player = board[last_row][last_move]
        
        # Horizontal
        count = 0
        for c in range(max(0, last_move - inarow + 1), min(cols, last_move + inarow)):
            if board[last_row][c] == player:
                count += 1
                if count == inarow:
                    return True
            else:
                count = 0
        
        # Vertical
        count = 0
        for r in range(max(0, last_row - inarow + 1), min(rows, last_row + inarow)):
            if board[r][last_move] == player:
                count += 1
                if count == inarow:
                    return True
            else:
                count = 0
        
        # Diagonal (positive slope)
        count = 0
        for i in range(-inarow + 1, inarow):
            r = last_row - i
            c = last_move + i
            if 0 <= r < rows and 0 <= c < cols:
                if board[r][c] == player:
                    count += 1
                    if count == inarow:
                        return True
                else:
                    count = 0
        
        # Diagonal (negative slope)
        count = 0
        for i in range(-inarow + 1, inarow):
            r = last_row + i
            c = last_move + i
            if 0 <= r < rows and 0 <= c < cols:
                if board[r][c] == player:
                    count += 1
                    if count == inarow:
                        return True
                else:
                    count = 0
        
        return False
    
    def negamax(board, depth, alpha, beta, player, last_move, last_row, configuration, start_time):
        """Negamax algorithm with alpha-beta pruning"""
        # Check timeout
        if time() - start_time > TIMEOUT:
            return None, -1000000
            
        # Check if terminal
        if last_move is not None and is_terminal(board, last_move, last_row, configuration):
            return None, -1000000
            
        # Check if maximum depth reached or no valid moves
        valid_moves = get_valid_moves(board)
        if depth == 0 or not valid_moves:
            return None, evaluate_position(board, player, configuration)
        
        # Order moves (center first)
        center = configuration.columns // 2
        valid_moves.sort(key=lambda x: -abs(x - center))
        
        max_score = -float('inf')
        best_move = valid_moves[0]
        
        # Try each valid move
        for move in valid_moves:
            row = drop_piece(board, move, player)
            next_player = 1 if player == 2 else 2
            
            # Recursively evaluate position
            _, score = negamax(
                board, depth-1, -beta, -alpha, 
                next_player, move, row, configuration, start_time
            )
            score = -score
            
            undo_move(board, move, row)
            
            if score > max_score:
                max_score = score
                best_move = move
            
            alpha = max(alpha, score)
            if alpha >= beta:
                break
                
        return best_move, max_score
    
    # Main agent logic
    board = get_board(observation, configuration)
    valid_moves = get_valid_moves(board)
    
    # If only one move is valid, make it
    if len(valid_moves) == 1:
        return valid_moves[0]
    
    # Use negamax to find best move
    start_time = time()
    best_move, _ = negamax(
        board, MAX_DEPTH, -float('inf'), float('inf'),
        observation.mark, None, None, configuration, start_time
    )
    
    return best_move if best_move is not None else valid_moves[0]

In [102]:
class ConnectFourNet(nn.Module):
    def __init__(self, rows=6, cols=7):
        super(ConnectFourNet, self).__init__()
        self.rows = rows
        self.cols = cols
        
        # Shared layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Calculate the flattened size
        with torch.no_grad():
            sample_input = torch.zeros(1, 3, rows, cols)
            x = F.relu(self.bn1(self.conv1(sample_input)))
            x = F.relu(self.bn2(self.conv2(x)))
            x = F.relu(self.bn3(self.conv3(x)))
            self.flattened_size = x.numel() // x.size(0)
        
        # Value head
        self.value_conv = nn.Conv2d(128, 32, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(32)
        self.value_fc1 = nn.Linear(32 * rows * cols, 256)
        self.value_fc2 = nn.Linear(256, cols)
        
        # Policy head
        self.policy_conv = nn.Conv2d(128, 32, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(32)
        self.policy_fc = nn.Linear(32 * rows * cols, cols)
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)
    
    def forward(self, x):
        # Shared layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Value head
        value = F.relu(self.value_bn(self.value_conv(x)))
        value = value.view(value.size(0), -1)  # Flatten properly
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)
        
        # Policy head
        policy = F.relu(self.policy_bn(self.policy_conv(x)))
        policy = policy.view(policy.size(0), -1)  # Flatten properly
        policy = self.policy_fc(policy)
        policy = F.softmax(policy, dim=1)
        
        return value, policy


In [103]:
# improved DQN
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from collections import deque


class ImprovedDQNAgent:
    def __init__(self, rows=6, cols=7, learning_rate=0.001, gamma=0.99,
                 epsilon=1.0, epsilon_min=0.1, epsilon_decay=0.995,
                 memory_size=50000, batch_size=64, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.rows = rows
        self.cols = cols
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.device = device
        
        self.network = ConnectFourNet(rows, cols).to(device)
        self.target_network = ConnectFourNet(rows, cols).to(device)
        self.target_network.load_state_dict(self.network.state_dict())
        
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        self.memory = deque(maxlen=memory_size)
        
    def evaluate_window(self, window, player):
        """Score a window of pieces"""
        opponent = 1 if player == 2 else 2
        score = 0
        
        player_count = np.count_nonzero(window == player)
        empty_count = np.count_nonzero(window == 0)
        opponent_count = np.count_nonzero(window == opponent)
        
        if player_count == 4:
            score += 1000
        elif player_count == 3 and empty_count == 1:
            score += 5
        elif player_count == 2 and empty_count == 2:
            score += 2
            
        if opponent_count == 3 and empty_count == 1:
            score -= 4
            
        return score
    
    def evaluate_position(self, board, player):
        """Calculate intermediate reward based on board position"""
        score = 0
        board_2d = board.reshape(self.rows, self.cols)
        
        # Horizontal windows
        for row in range(self.rows):
            for col in range(self.cols - 3):
                window = board_2d[row, col:col + 4]
                score += self.evaluate_window(window, player)
        
        # Vertical windows
        for row in range(self.rows - 3):
            for col in range(self.cols):
                window = board_2d[row:row + 4, col]
                score += self.evaluate_window(window, player)
        
        # Diagonal windows (positive slope)
        for row in range(self.rows - 3):
            for col in range(self.cols - 3):
                window = [board_2d[row + i][col + i] for i in range(4)]
                score += self.evaluate_window(window, player)
        
        # Diagonal windows (negative slope)
        for row in range(3, self.rows):
            for col in range(self.cols - 3):
                window = [board_2d[row - i][col + i] for i in range(4)]
                score += self.evaluate_window(window, player)
        
        # Center control preference
        center_array = board_2d[:, self.cols//2]
        center_count = np.count_nonzero(center_array == player)
        score += center_count * 3
        
        return score / 100.0  # Normalize score
    
    def get_state(self, board):
        state = np.zeros((3, self.rows, self.cols), dtype=np.float32)
        board_2d = board.reshape(self.rows, self.cols)
        
        state[0] = (board_2d == 1)
        state[1] = (board_2d == 2)
        state[2] = (board_2d == 0)
        
        return torch.FloatTensor(state).unsqueeze(0).to(self.device)
    
    def get_valid_moves(self, board):
        return [col for col in range(self.cols) if board[col] == 0]
    
    def select_action(self, state, valid_moves, training=True):
        if training and random.random() < self.epsilon:
            return random.choice(valid_moves)
        
        with torch.no_grad():
            q_values, policy = self.network(state)
            
            # Combine Q-values and policy probabilities
            combined_values = 0.7 * q_values + 0.3 * policy
            
            # Mask invalid moves
            mask = torch.ones(self.cols) * float('-inf')
            mask = mask.to(self.device)
            for move in valid_moves:
                mask[move] = 0
            
            combined_values = combined_values + mask
            return torch.argmax(combined_values).item()
    
    def calculate_reward(self, board, action, done, game_reward, player):
        """Calculate immediate reward based on action and board state"""
        if done:
            return game_reward * 10  # Amplify terminal rewards
            
        # Get intermediate reward based on position evaluation
        intermediate_reward = self.evaluate_position(board, player)
        
        # Penalize choosing invalid moves
        if action not in self.get_valid_moves(board):
            intermediate_reward -= 5
            
        return intermediate_reward
    
    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))
    
    def replay(self):
        if len(self.memory) < self.batch_size:
            return
        
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        states = torch.cat(states).to(self.device)
        next_states = torch.cat(next_states).to(self.device)
        actions = torch.tensor(actions, dtype=torch.long).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float).to(self.device)
        
        # Get current Q values and policy
        current_q, current_policy = self.network(states)
        current_q_values = current_q.gather(1, actions.unsqueeze(1))
        
        # Get next Q values using target network
        with torch.no_grad():
            next_q, next_policy = self.target_network(next_states)
            next_q_values = next_q.max(1)[0]
        
        # Calculate target Q values
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        
        # Calculate losses
        value_loss = F.mse_loss(current_q_values.squeeze(), target_q_values)
        policy_loss = F.cross_entropy(current_policy, F.softmax(current_q.detach(), dim=1))
        
        # Combined loss
        total_loss = value_loss + 0.1 * policy_loss
        
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)  # Add gradient clipping
        self.optimizer.step()
        
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        
        return total_loss.item()

def train_improved_agent():
    from kaggle_environments import make
    
    env = make("connectx", debug=True)
    configuration = env.configuration
    agent = ImprovedDQNAgent(rows=configuration.rows, cols=configuration.columns)
    
    n_episodes = 2000
    target_update_frequency = 50
    rewards_history = []
    
    for episode in range(n_episodes):
        trainer = env.train([None, "random"])
        observation = trainer.reset()
        
        done = False
        total_reward = 0
        step_count = 0
        
        while not done:
            board = np.array(observation.board)
            state = agent.get_state(board)
            valid_moves = agent.get_valid_moves(board)
            
            action = agent.select_action(state, valid_moves)
            observation_next, reward, done, info = trainer.step(action)
            
            # Calculate enhanced reward
            enhanced_reward = agent.calculate_reward(
                board, action, done, reward, observation.mark
            )
            
            next_state = agent.get_state(np.array(observation_next.board))
            agent.remember(state, action, enhanced_reward, next_state, done)
            
            loss = agent.replay()
            
            total_reward += enhanced_reward
            observation = observation_next
            step_count += 1
            
        # Update target network
        if episode % target_update_frequency == 0:
            agent.target_network.load_state_dict(agent.network.state_dict())
        
        rewards_history.append(total_reward)
        
        # Print progress
        if episode % 100 == 0:
            avg_reward = np.mean(rewards_history[-100:])
            print(f"Episode {episode}/{n_episodes}, "
                  f"Average Reward: {avg_reward:.3f}, "
                  f"Epsilon: {agent.epsilon:.3f}, "
                  f"Steps: {step_count}")
    
    return agent

In [121]:
# alphazero
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from collections import defaultdict
import random


def create_new_env(board, mark):
    from kaggle_environments import make
    env = make("connectx", debug=True)
    env.reset()
    env.state[0].board = board.copy()
    env.state[0].mark = mark
    return env.train([None, "random"])


class ConnectFourNet(nn.Module):
    def __init__(self, rows=6, cols=7):
        super(ConnectFourNet, self).__init__()
        self.rows = rows
        self.cols = cols
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        
        # Policy head
        self.policy_conv = nn.Conv2d(64, 32, 1)
        self.policy_bn = nn.BatchNorm2d(32)
        self.policy_fc = nn.Linear(32 * rows * cols, cols)
        
        # Value head
        self.value_conv = nn.Conv2d(64, 32, 1)
        self.value_bn = nn.BatchNorm2d(32)
        self.value_fc1 = nn.Linear(32 * rows * cols, 64)
        self.value_fc2 = nn.Linear(64, 1)
        
    def forward(self, x):
        # Shared layers
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        
        # Policy head
        policy = F.relu(self.policy_bn(self.policy_conv(x)))
        policy = policy.view(policy.size(0), -1)
        policy = self.policy_fc(policy)
        
        # Value head
        value = F.relu(self.value_bn(self.value_conv(x)))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = torch.tanh(self.value_fc2(value))
        
        return policy, value

class Node:
    def __init__(self, prior=0):
        self.visit_count = 0
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.state = None
        self.to_play = None
    
    def expanded(self):
        return len(self.children) > 0
    
    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

class MCTS:
    def __init__(self, network, num_simulations=50, c_puct=1.0):
        self.network = network
        self.num_simulations = num_simulations
        self.c_puct = c_puct
        self.cols = 7
        
    def get_state_tensor(self, state, rows=6, cols=7):
        tensor = np.zeros((3, rows, cols), dtype=np.float32)
        board = np.array(state).reshape(rows, cols)
        
        tensor[0] = (board == 1)
        tensor[1] = (board == 2)
        tensor[2] = (board == 0)
        
        return torch.FloatTensor(tensor).unsqueeze(0)
    
    def select_action(self, node, temperature=1):
        visit_counts = np.array([child.visit_count for child in node.children.values()])
        actions = [int(action) for action in node.children.keys()]  # Convert to int
        
        if temperature == 0:
            action = actions[np.argmax(visit_counts)]
        else:
            visit_count_distribution = visit_counts ** (1 / temperature)
            visit_count_distribution = visit_count_distribution / sum(visit_count_distribution)
            action = int(np.random.choice(actions, p=visit_count_distribution))  # Convert to int
            
        return action
    
    def ucb_score(self, parent, child):
        prior_score = self.c_puct * child.prior * math.sqrt(parent.visit_count) / (child.visit_count + 1)
        value_score = -child.value()
        return value_score + prior_score
    
    def select_child(self, node):
        _, action, child = max((self.ucb_score(node, child), int(action), child)  # Convert to int
                             for action, child in node.children.items())
        return action, child
    
    def backpropagate(self, search_path, value):
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1
            value = -value
            
    def run(self, state, to_play, valid_moves, env):
        root = Node()
        root.state = state.copy()
        root.to_play = to_play
        
        # Evaluate the root state
        state_tensor = self.get_state_tensor(state).to(next(self.network.parameters()).device)
        policy_logits, value = self.network(state_tensor)
        policy = F.softmax(policy_logits, dim=1).squeeze(0).detach().cpu().numpy()
        value = value.item()
        
        # Initialize children for root (with integer actions)
        for action in valid_moves:
            if action < len(policy):
                root.children[int(action)] = Node(prior=policy[action])  # Convert to int
        
        # Run MCTS simulations
        for _ in range(self.num_simulations):
            node = root
            search_path = [node]
            current_state = root.state.copy()
            current_to_play = root.to_play
            
            # Select
            while node.expanded():
                action, node = self.select_child(node)
                
                # Create new environment and simulate action
                test_env = create_new_env(current_state, current_to_play)
                obs, reward, done, _ = test_env.step(int(action))  # Convert to int
                
                current_state = obs.board
                current_to_play = obs.mark
                search_path.append(node)
                
                if done:
                    break
            
            # If not done, expand the node
            if not done:
                state_tensor = self.get_state_tensor(current_state).to(next(self.network.parameters()).device)
                policy_logits, value = self.network(state_tensor)
                policy = F.softmax(policy_logits, dim=1).squeeze(0).detach().cpu().numpy()
                valid_moves = [c for c in range(self.cols) if current_state[c] == 0]
                
                # Only add children for valid moves within policy range (with integer actions)
                for move in valid_moves:
                    if move < len(policy):
                        node.children[int(move)] = Node(prior=policy[move])  # Convert to int
                value = value.item()
            else:
                value = reward
            
            self.backpropagate(search_path, value)
        
        return root

class AlphaZeroTrainer:
    def __init__(self, network, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.network = network.to(device)
        self.optimizer = torch.optim.Adam(network.parameters(), lr=0.001, weight_decay=1e-4)
        self.device = device
        self.mcts = MCTS(network)
    
    def self_play(self, num_games=100):
        from kaggle_environments import make
        game_memories = []

        for game in range(num_games):
            # Initialize environment with opponent as "random" to allow for self-play
            env = make("connectx", debug=True)
            trainer = env.train([None, "random"])  # Use random as opponent

            memories = []
            observation = trainer.reset()

            while not env.done:
                state = observation.board
                to_play = observation.mark
                valid_moves = [c for c in range(env.configuration.columns) if state[c] == 0]

                if not valid_moves:  # Check for draw
                    break

                # MCTS
                temperature = 1.0 if len(memories) < 10 else 0.1
                root = self.mcts.run(state, to_play, valid_moves, trainer)

                # Store search statistics for training
                policy = np.zeros(env.configuration.columns)
                for action, child in root.children.items():
                    policy[action] = child.visit_count
                policy = policy / (np.sum(policy) + 1e-8)  # Add small epsilon to avoid division by zero

                # Select action
                action = self.mcts.select_action(root, temperature)

                # Store state
                memories.append({
                    'state': state,
                    'policy': policy,
                    'to_play': to_play
                })

                # Make move
                observation, reward, done, _ = trainer.step(action)

            # Process game result
            final_reward = reward if reward is not None else 0

            # Store game result with proper reward assignment
            for memory in memories:
                player_reward = final_reward if memory['to_play'] == 1 else -final_reward
                game_memories.append((
                    memory['state'],
                    memory['policy'],
                    player_reward,
                    memory['to_play']
                ))

            if game % 10 == 0:
                print(f"Completed self-play game {game + 1}/{num_games}")

        return game_memories
    
    def train(self, memories, batch_size=32, epochs=1):
        states, policies, values = [], [], []
        
        # Prepare training data
        for state, policy, reward, to_play in memories:
            states.append(self.mcts.get_state_tensor(state))
            policies.append(torch.FloatTensor(policy))
            values.append(torch.FloatTensor([reward]))
        
        dataset = torch.utils.data.TensorDataset(
            torch.cat(states, dim=0),
            torch.stack(policies),
            torch.cat(values)
        )
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        self.network.train()
        total_loss = 0
        
        for _ in range(epochs):
            for state, policy_target, value_target in loader:
                state = state.to(self.device)
                policy_target = policy_target.to(self.device)
                value_target = value_target.to(self.device)
                
                # Forward pass
                policy_logits, value = self.network(state)
                
                # Calculate loss
                policy_loss = F.cross_entropy(policy_logits, policy_target)
                value_loss = F.mse_loss(value.squeeze(), value_target)
                loss = policy_loss + value_loss
                
                # Backward pass
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                total_loss += loss.item()
        
        return total_loss / len(loader)

def train_alphazero():
    # Initialize network and trainer
    network = ConnectFourNet()
    trainer = AlphaZeroTrainer(network)
    
    num_iterations = 50
    games_per_iteration = 20
    
    for iteration in range(num_iterations):
        print(f"\nStarting iteration {iteration + 1}/{num_iterations}")
        
        # Self-play
        memories = trainer.self_play(num_games=games_per_iteration)
        
        # Train
        loss = trainer.train(memories)
        print(f"Iteration {iteration + 1} completed. Loss: {loss:.4f}")
        
        # Save model periodically
        torch.save(network.state_dict(), f'alphazero_connect4_iter_{iteration+1}.pth')
    
    return trainer



In [None]:
trained_model = train_alphazero()
torch.save(trained_model.network.state_dict(), 'alphazero_connect4.pth')


Starting iteration 1/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 1 completed. Loss: 2.5716

Starting iteration 2/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 2 completed. Loss: 3.0706

Starting iteration 3/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 3 completed. Loss: 2.2894

Starting iteration 4/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 4 completed. Loss: 1.9650

Starting iteration 5/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 5 completed. Loss: 2.3211

Starting iteration 6/50


  value_loss = F.mse_loss(value.squeeze(), value_target)


Completed self-play game 1/20
Completed self-play game 11/20
Iteration 6 completed. Loss: 2.7470

Starting iteration 7/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 7 completed. Loss: 2.2952

Starting iteration 8/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 8 completed. Loss: 1.8858

Starting iteration 9/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 9 completed. Loss: 1.8558

Starting iteration 10/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 10 completed. Loss: 1.6801

Starting iteration 11/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 11 completed. Loss: 1.4667

Starting iteration 12/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 12 completed. Loss: 1.5736

Starting iteration 13/50
Completed self-play game 1/20
Completed self-play game 11/20
Iteration 13 completed. Loss: 1.5343

Starting iteration 14/50
Completed self-

In [125]:
def my_agent(observation, configuration):
    global network, mcts
    
    if 'network' not in globals():
        network = ConnectFourNet(rows=configuration.rows, cols=configuration.columns)
        # Load trained model
        network.load_state_dict(torch.load('alphazero_connect4_iter_35.pth'))
        network.eval()
        mcts = MCTS(network, num_simulations=20)  # Fewer simulations for competition
    
    state = observation.board
    valid_moves = [c for c in range(configuration.columns) if state[c] == 0]
    
    if len(valid_moves) == 1:
        return valid_moves[0]
    
    # Create a dummy environment for MCTS
    from kaggle_environments import make
    env = make("connectx", debug=True)
    env.reset()
    env.state[0].board = observation.board
    env.state[0].mark = observation.mark
    
    root = mcts.run(state, observation.mark, valid_moves, env)
    return mcts.select_action(root, temperature=0)  # Use temperature=0 for deterministic play


In [91]:
# Define the agent function for submission
def my_agent(observation, configuration):
    # Load the trained model (assuming it's been trained and saved)
    global dqn_agent
    
    if 'dqn_agent' not in globals():
        dqn_agent = DQNAgent(rows=configuration.rows, cols=configuration.columns)
        #Uncomment to load pre-trained model
        dqn_agent.load_model('connect_four_dqn.pth')
    
    board = np.array(observation.board)
    state = dqn_agent.get_state(board)
    valid_moves = dqn_agent.get_valid_moves(board)
    
    return dqn_agent.select_action(state, valid_moves, training=False)

# To train the agent, uncomment these lines:
# if __name__ == "__main__":
#     trained_agent = train_dqn_agent()
#     trained_agent.save_model('connect_four_dqn.pth')

In [126]:
env.reset()
# Play as the first agent against default "random" agent.
env.run([my_agent, "random"])
env.render(mode="ipython", width=500, height=450)

In [None]:
def mean_reward(rewards):
    return sum(r[0] for r in rewards) / float(len(rewards))

# Run multiple episodes to estimate its performance.
print("My Agent vs Random Agent:", mean_reward(evaluate("connectx", [my_agent, "random"], num_episodes=10)))
print("My Agent vs Negamax Agent:", mean_reward(evaluate("connectx", [my_agent, "negamax"], num_episodes=10)))
# print("Random Agent vs My Agent:", mean_reward(evaluate("connectx", ["random", my_agent], num_episodes=10)))
# print("Negamax Agent vs My Agent:", mean_reward(evaluate("connectx", ["negamax", my_agent], num_episodes=10)))

My Agent vs Random Agent: 0.8
