# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import tkinter as tk
from tkinter import messagebox

# PyTorch device: use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The board is BOARD_SIZE x BOARD_SIZE cells (12 x 12)
BOARD_SIZE = 12

# Side length of one board cell in the GUI, in pixels
CELL_SIZE = 40

# Colours used by the GUI
COLORS = {
    "bg": "#F0D9B5",      # Canvas background
    "grid": "#000000",    # Grid lines
    "valid": "#FFFFFF",   # Playable (cross-zone) cells / empty discs
    "p1": "#2C5F2D",      # Player 1 pieces (green)
    "p2": "#B1624E",      # Player 2 pieces (red)
}

# Playable regions. Together they form a plus-shaped cross; the four 4x4
# corner squares are not playable. Each entry is
# (row_start, row_end, col_start, col_end) with half-open ranges, which is how
# State.is_valid_position reads them (rows first, then columns).
CROSS_ZONES = [
    (4, 8, 4, 8),   # Centre
    (0, 4, 4, 8),   # Top arm    (rows 0-3, cols 4-7)
    (8, 12, 4, 8),  # Bottom arm (rows 8-11, cols 4-7)
    (4, 8, 0, 4),   # Left arm   (rows 4-7, cols 0-3)
    (4, 8, 8, 12),  # Right arm  (rows 4-7, cols 8-11)
]

class State:
    """Game state for Super Tic-Tac-Toe: the board, move validation and win logic.

    ``self.data`` is a BOARD_SIZE x BOARD_SIZE int array: 0 = empty,
    1 = player 1, -1 = player 2. Only cells inside ``CROSS_ZONES`` are
    playable; ``update_state`` never places a piece in the corner squares
    outside the cross, so those cells stay 0.
    """

    def __init__(self):
        """Create a fresh, empty game."""
        self.reset()

    def reset(self):
        """Reset the board to empty and clear the winner / game-end flags."""
        self.data = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=int)
        self.winner = None  # 1 or -1 for a win, 0 for a draw, None while in progress
        self.end = False    # True once the game is decided

    def is_valid_position(self, row, col):
        """Return True if (row, col) lies inside one of the CROSS_ZONES.

        Occupancy is NOT checked here; callers test ``self.data[row, col] == 0``
        separately.
        """
        return any(r1 <= row < r2 and c1 <= col < c2
                   for r1, r2, c1, c2 in CROSS_ZONES)

    def _is_open(self, row, col):
        """Return True if (row, col) is on the board, playable and currently empty."""
        return (0 <= row < BOARD_SIZE and 0 <= col < BOARD_SIZE
                and self.is_valid_position(row, col)
                and self.data[row, col] == 0)

    def check_consecutive(self, player, count, directions=None, row=None, col=None):
        """Look for `count` consecutive pieces of `player`.

        Args:
            player: 1 or -1.
            count: run length to look for.
            directions: iterable of (d_row, d_col) steps; defaults to
                horizontal, vertical and both diagonals.
            row, col: if both are given, only runs whose start cell lies within
                ``count - 1`` cells of (row, col) are examined. This window
                covers every run passing through (row, col).

        Returns:
            (has_consecutive, can_extend): whether such a run exists in the
            searched region, and whether at least one found run can still grow.
            Runs of length 3, and diagonal runs of length 4, are extendable
            only if the cell just beyond one of their ends is playable (inside
            the cross) and empty. All other runs count as always extendable.
        """
        if directions is None:
            directions = [(0, 1), (1, 0), (1, 1), (1, -1)]

        if row is not None and col is not None:
            i_range = range(max(0, row - count + 1), min(BOARD_SIZE, row + count))
            j_range = range(max(0, col - count + 1), min(BOARD_SIZE, col + count))
        else:
            i_range = j_range = range(BOARD_SIZE)

        has_consecutive = False
        for i in i_range:
            for j in j_range:
                for dx, dy in directions:
                    end_i, end_j = i + dx * (count - 1), j + dy * (count - 1)
                    if not (0 <= end_i < BOARD_SIZE and 0 <= end_j < BOARD_SIZE):
                        continue  # run would leave the board
                    if any(self.data[i + dx * k, j + dy * k] != player for k in range(count)):
                        continue
                    has_consecutive = True
                    needs_open_end = count == 3 or (count == 4 and (dx, dy) in ((1, 1), (1, -1)))
                    # An end cell outside the cross can never be filled, so it
                    # does not make the run extendable even though it holds 0.
                    if (not needs_open_end
                            or self._is_open(i - dx, j - dy)
                            or self._is_open(i + dx * count, j + dy * count)):
                        return True, True
        return has_consecutive, False

    def update_state(self, row, col, player):
        """Place `player`'s piece at (row, col) and return (placed, reward).

        Returns (False, 0.0) and leaves the board untouched if the cell is
        outside the cross or already occupied. Otherwise the piece is placed,
        ``check_winner`` runs, and the reward (from `player`'s perspective) is:
          * 1.0 / -1.0 / 0.0 if the move ends the game as a win / loss / draw;
          * otherwise a shaping bonus (scaled by 0.5) for runs that did not
            exist near the move before it and are still extendable afterwards.
            At most one bonus is paid from the horizontal/vertical ladder
            (3: 0.15, else 2: 0.05). At most one is paid from the diagonal
            ladder (4: 0.12, else 3: 0.08, else 2: 0.03).

        No penalty is computed for opponent runs. Placing `player`'s piece can
        never create a new run of `-player` pieces, so a "new opponent run"
        check could never fire.
        """
        if not (self.is_valid_position(row, col) and self.data[row, col] == 0):
            return False, 0.0

        # (directions, ((run length, bonus), ...)) ordered from best to worst.
        reward_ladders = (
            ([(0, 1), (1, 0)], ((3, 0.15), (2, 0.05))),
            ([(1, 1), (1, -1)], ((4, 0.12), (3, 0.08), (2, 0.03))),
        )
        # Which runs already existed around the move, before placing the piece.
        existed_before = [
            [self.check_consecutive(player, count, dirs, row, col)[0] for count, _ in levels]
            for dirs, levels in reward_ladders
        ]

        self.data[row, col] = player

        self.check_winner()
        if self.end:
            if self.winner == player:
                return True, 1.0
            if self.winner == -player:
                return True, -1.0
            return True, 0.0

        reward_scale = 0.5  # balances shaping against the +/-1 terminal reward
        reward = 0.0
        for (dirs, levels), before_flags in zip(reward_ladders, existed_before):
            for (count, bonus), existed in zip(levels, before_flags):
                has_now, can_extend = self.check_consecutive(player, count, dirs, row, col)
                if has_now and can_extend and not existed:
                    reward += bonus * reward_scale
                    break  # only the best new run in each ladder is rewarded
        return True, reward

    def check_winner(self):
        """Set ``winner``/``end`` if someone has won or no playable cell is left.

        A win is 4 in a row horizontally or vertically, or 5 in a row on either
        diagonal. Cells hold -1/0/1, so ``|sum| == length`` means every cell in
        the window belongs to the same player. A draw (winner 0) is declared
        when every cell inside the cross zones is occupied. The corner cells
        outside the cross are never played, so they must not be required to
        fill; otherwise a draw could never be detected.
        """
        for i in range(BOARD_SIZE):
            for j in range(BOARD_SIZE - 3):
                for window in (self.data[i, j:j + 4], self.data[j:j + 4, i]):
                    if abs(window.sum()) == 4:
                        self.winner = int(window[0])
                        self.end = True
                        return

        for i in range(BOARD_SIZE - 4):
            for j in range(BOARD_SIZE - 4):
                diag = [self.data[i + k, j + k] for k in range(5)]
                anti_diag = [self.data[i + k, j + 4 - k] for k in range(5)]
                for window in (diag, anti_diag):
                    if abs(sum(window)) == 5:
                        self.winner = int(window[0])
                        self.end = True
                        return

        if all(np.all(self.data[r1:r2, c1:c2] != 0) for r1, r2, c1, c2 in CROSS_ZONES):
            self.end = True
            self.winner = 0

class EnhancedDQN(nn.Module):
    """Convolutional Q-network that scores every board cell as an action.

    forward() takes a batch of boards of shape (batch, BOARD_SIZE, BOARD_SIZE),
    adds a single channel axis, and returns (batch, BOARD_SIZE**2) Q-values,
    one per flattened cell index ``row * BOARD_SIZE + col``.
    """

    def __init__(self, hidden_size=512):
        """Build three Conv-BatchNorm-ReLU stages followed by a two-layer head.

        Args:
            hidden_size: width of the hidden fully connected layer.
        """
        super().__init__()
        # Conv(1->64), Conv(64->128), Conv(128->256), each 3x3 with padding 1 so
        # the spatial size is preserved. Module order (and thus state_dict keys)
        # is Conv, BatchNorm, ReLU per stage.
        stages = []
        in_channels = 1
        for out_channels in (64, 128, 256):
            stages.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
            in_channels = out_channels
        self.conv = nn.Sequential(*stages)

        n_cells = BOARD_SIZE * BOARD_SIZE
        self.fc = nn.Sequential(
            nn.Linear(256 * n_cells, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),                   # regularisation in the head
            nn.Linear(hidden_size, n_cells),   # one Q-value per cell
        )

    def forward(self, x):
        """Return Q-values for a batch of boards."""
        feature_maps = self.conv(x.unsqueeze(1))
        return self.fc(torch.flatten(feature_maps, start_dim=1))

class RLAgent:
    """Double-DQN agent with experience replay and a periodically synced target network.

    ``train_step`` assumes each stored transition is seen from the player who
    moved, and that ``next_state`` is a position where the *opponent* moves
    next (alternating two-player, zero-sum play). With ``zero_sum_bootstrap``
    set (the default), the bootstrapped value of ``next_state`` is therefore
    subtracted: the opponent's best outcome is this player's worst.
    """

    def __init__(self):
        """Create the online/target networks, optimizer, replay memory and hyper-parameters."""
        self.model = EnhancedDQN().to(device)         # online network (trained)
        self.target_model = EnhancedDQN().to(device)  # target network (periodically synced)
        self.target_model.load_state_dict(self.model.state_dict())
        # The target network is never optimized, so keep it in eval mode:
        # BatchNorm uses running statistics and Dropout is disabled.
        self.target_model.eval()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=1e-5)
        self.memory = deque(maxlen=50000)  # (state, action, reward, next_state, done) tuples
        self.batch_size = 256
        self.gamma = 0.99            # discount factor
        self.epsilon = 1.0           # current exploration rate
        self.epsilon_min = 0.05      # exploration floor
        self.epsilon_decay = 0.999   # multiplicative decay applied once per train_step
        self.update_freq = 100       # sync the target network every this many train steps
        self.steps = 0               # completed train steps
        # True: bootstrap with -gamma * Q(next_state) (negamax), because
        # next_state is the opponent's turn. Set False for single-agent data.
        self.zero_sum_bootstrap = True

    @staticmethod
    def _predict(model, x):
        """Run `model` on `x` in eval mode without gradients, then restore its previous mode."""
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                return model(x)
        finally:
            model.train(was_training)

    def get_action(self, state, training=True):
        """Pick a flat action index (row * BOARD_SIZE + col) for `state`.

        With ``training`` and probability ``epsilon`` a random valid action is
        chosen; otherwise the valid action with the highest Q-value.
        Returns None if no valid action exists.
        """
        valid_actions = self._get_valid_actions(state)
        if not valid_actions:
            return None

        if training and np.random.rand() < self.epsilon:
            return random.choice(valid_actions)

        board = torch.as_tensor(state.data, dtype=torch.float32, device=device)
        # Eval mode so Dropout does not randomize the greedy choice and a
        # single-board batch does not disturb BatchNorm running statistics.
        q_values = self._predict(self.model, board.unsqueeze(0)).squeeze(0)
        valid_idx = torch.tensor(valid_actions, dtype=torch.long, device=device)
        masked = torch.full((BOARD_SIZE ** 2,), float("-inf"), device=device)
        masked[valid_idx] = q_values[valid_idx]
        return int(torch.argmax(masked).item())

    def _get_valid_actions(self, state):
        """Return flat indices of cells that are inside the cross zones and empty."""
        return [i * BOARD_SIZE + j for i in range(BOARD_SIZE)
                for j in range(BOARD_SIZE)
                if state.is_valid_position(i, j) and state.data[i][j] == 0]

    def store_experience(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory.

        The boards are copied into new float tensors, so later moves on the
        live State objects do not alter stored transitions. `next_state` may be
        None, which ``train_step`` treats as terminal.
        """
        self.memory.append((
            torch.tensor(state.data, dtype=torch.float32, device=device),
            action,
            reward,
            None if next_state is None
            else torch.tensor(next_state.data, dtype=torch.float32, device=device),
            done,
        ))

    def train_step(self):
        """Do one Double-DQN update on a random mini-batch; return the loss (0.0 if skipped)."""
        if len(self.memory) < self.batch_size:
            return 0.0

        batch = random.sample(self.memory, self.batch_size)
        states = torch.stack([t[0] for t in batch]).to(device)
        actions = torch.tensor([t[1] for t in batch], dtype=torch.long, device=device)
        rewards = torch.tensor([t[2] for t in batch], dtype=torch.float32, device=device)
        dones = torch.tensor([t[4] for t in batch], dtype=torch.bool, device=device)
        # Keep one next-state row per transition so every bootstrap value stays
        # aligned with its own transition; a missing next state gets a
        # placeholder board and is masked out below.
        has_next = torch.tensor([t[3] is not None for t in batch], dtype=torch.bool, device=device)
        next_states = torch.stack(
            [t[3] if t[3] is not None else torch.zeros_like(t[0]) for t in batch]
        ).to(device)

        current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            # Double DQN: the online net picks the next action, the target net evaluates it.
            next_actions = self._predict(self.model, next_states).argmax(dim=1, keepdim=True)
            next_q = self._predict(self.target_model, next_states).gather(1, next_actions).squeeze(1)
            bootstrap_mask = ((~dones) & has_next).float()
            sign = -1.0 if self.zero_sum_bootstrap else 1.0
            target_q = rewards + sign * self.gamma * next_q * bootstrap_mask

        loss = nn.SmoothL1Loss()(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        self.steps += 1

        if self.steps % self.update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())

        return loss.item()

class GameGUI:
    """Tkinter front end for a human (player 1) vs. AI (player -1) game.

    Every move, human or AI, is perturbed. With probability 0.5 the piece
    lands on the chosen cell; otherwise it lands on a random one of its eight
    neighbours. A perturbed move that hits an invalid cell is not played.
    """

    def __init__(self, master, ai_path=None):
        """Create the game, optionally load AI weights from `ai_path`, and build the UI."""
        self.master = master
        self.state = State()
        self.ai = RLAgent()
        if ai_path:
            try:
                # map_location lets a checkpoint saved on a GPU load on a CPU-only
                # machine. NOTE: torch.load unpickles the file; only load trusted files.
                self.ai.model.load_state_dict(torch.load(ai_path, map_location=device))
            except Exception as exc:  # best effort: fall back to an untrained model, but say so
                print(f"Could not load AI model from {ai_path!r}: {exc}")
        self.ai.model.eval()     # inference only: BatchNorm running stats, no Dropout
        self.ai.epsilon = 0.0    # no exploration during play
        self.ai_pending = False  # True while the AI's reply is scheduled; human clicks are ignored

        self.setup_ui()

    def setup_ui(self):
        """Create the canvas, draw the board and bind mouse clicks."""
        self.master.title("Super Tic-Tac-Toe")
        self.canvas = tk.Canvas(self.master,
                                width=CELL_SIZE * BOARD_SIZE,
                                height=CELL_SIZE * BOARD_SIZE,
                                bg=COLORS["bg"])
        self.canvas.pack()
        self.draw_board()
        self.canvas.bind("<Button-1>", self.on_click)

    def _empty_color(self, row, col):
        """Colour of an empty disc: white inside the cross, background colour in the corners."""
        return COLORS["valid"] if self.state.is_valid_position(row, col) else COLORS["bg"]

    def draw_board(self):
        """Draw the playable zones, grid lines, and one (initially empty) disc per cell."""
        # CROSS_ZONES entries are (row_start, row_end, col_start, col_end), as
        # read by State.is_valid_position: columns map to canvas x, rows to canvas y.
        for r1, r2, c1, c2 in CROSS_ZONES:
            self.canvas.create_rectangle(
                c1 * CELL_SIZE, r1 * CELL_SIZE,
                c2 * CELL_SIZE, r2 * CELL_SIZE,
                fill=COLORS["valid"], outline=COLORS["grid"])

        extent = BOARD_SIZE * CELL_SIZE
        for i in range(BOARD_SIZE + 1):
            offset = i * CELL_SIZE
            self.canvas.create_line(0, offset, extent, offset, fill=COLORS["grid"])  # horizontal
            self.canvas.create_line(offset, 0, offset, extent, fill=COLORS["grid"])  # vertical

        self.pieces = {}
        for i in range(BOARD_SIZE):
            for j in range(BOARD_SIZE):
                x = j * CELL_SIZE + CELL_SIZE // 2
                y = i * CELL_SIZE + CELL_SIZE // 2
                color = self._empty_color(i, j)
                self.pieces[(i, j)] = self.canvas.create_oval(
                    x - 15, y - 15, x + 15, y + 15,
                    fill=color, outline=color)

    def update_display(self):
        """Recolour every disc to match ``self.state.data``."""
        for i in range(BOARD_SIZE):
            for j in range(BOARD_SIZE):
                cell = self.state.data[i][j]
                if cell == 1:
                    color = COLORS["p1"]
                elif cell == -1:
                    color = COLORS["p2"]
                else:
                    color = self._empty_color(i, j)
                self.canvas.itemconfig(self.pieces[(i, j)], fill=color, outline=color)

    @staticmethod
    def _perturb(row, col):
        """Return (row, col) with probability 0.5, otherwise a random one of its 8 neighbours."""
        if np.random.rand() < 0.5:
            return row, col
        neighbours = [(row + dr, col + dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1)
                      if not (dr == 0 and dc == 0)]
        return random.choice(neighbours)

    def _try_move(self, row, col, player):
        """Try to place `player`'s piece at (row, col); return True if it was placed."""
        if not (0 <= row < BOARD_SIZE and 0 <= col < BOARD_SIZE):
            return False
        placed, _ = self.state.update_state(row, col, player)
        return placed

    def on_click(self, event):
        """Handle a human click: play a (perturbed) move, then schedule the AI's reply."""
        # Ignore clicks after the game ends or while the AI's reply is pending,
        # so the human cannot move twice in a row.
        if self.state.end or self.ai_pending:
            return

        row, col = event.y // CELL_SIZE, event.x // CELL_SIZE
        if not self._try_move(*self._perturb(row, col), 1):
            return  # nothing placed: the human simply clicks again

        self.update_display()
        if self.state.end:
            self.game_over()
            return
        self.ai_pending = True
        self.master.after(500, self.ai_move)

    def ai_move(self):
        """Play the AI's move, redrawing the perturbation until a piece lands."""
        action = self.ai.get_action(self.state, training=False)
        if action is None:
            # No playable empty cell is left: declare a draw so the game cannot stall.
            self.state.end = True
            self.state.winner = 0
        else:
            row, col = divmod(action, BOARD_SIZE)
            # The chosen cell is valid and empty and is kept with probability 0.5
            # per attempt, so this loop ends. Retrying mirrors train(), where the
            # same player moves again after an invalid perturbed move.
            while not self._try_move(*self._perturb(row, col), -1):
                pass
            self.update_display()
        self.ai_pending = False

        if self.state.end:
            self.game_over()

    def game_over(self):
        """Show the result and close the window."""
        if self.state.winner == 1:
            message = "Human wins!"
        elif self.state.winner == -1:
            message = "AI wins!"
        else:
            message = "It's a draw!"
        messagebox.showinfo("Game Over", message)
        self.master.destroy()

def train(episodes=1001, save_path="best_model.pth"):
    """Train an RLAgent against a uniformly random opponent and save its weights.

    The agent plays as player 1 (moving first); the random opponent plays as
    player -1. Every placed move of *both* sides is stored in the replay
    memory with the reward ``State.update_state`` returned for the mover, and
    each stored move is followed by one ``train_step``.

    Each chosen move is perturbed: with probability 0.5 the piece lands on a
    random neighbouring cell instead. A move that lands on an invalid cell is
    retried by the same player.

    The logged "Reward" is the net reward from the agent's side: its own
    rewards minus the opponent's, so an opponent win counts as -1.

    Args:
        episodes: number of games to play.
        save_path: file the model's state_dict is written to. It is written
            when training finishes or is interrupted.

    Returns:
        The trained RLAgent.
    """
    agent = RLAgent()
    try:
        for episode in range(episodes):
            state = State()
            agent_net_reward = 0.0
            current_player = 1  # the agent (player 1) moves first
            while not state.end:
                if current_player == 1:
                    action = agent.get_action(state)  # epsilon-greedy agent move
                else:
                    valid_actions = agent._get_valid_actions(state)
                    action = random.choice(valid_actions) if valid_actions else None
                if action is None:
                    break  # no playable empty cell left

                before = State()
                before.data = np.copy(state.data)

                row, col = divmod(action, BOARD_SIZE)
                if np.random.rand() < 0.5:
                    final_row, final_col = row, col
                else:
                    candidates = [(row + dr, col + dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1)
                                  if not (dr == 0 and dc == 0)]
                    final_row, final_col = random.choice(candidates)

                placed, reward = False, 0.0
                if 0 <= final_row < BOARD_SIZE and 0 <= final_col < BOARD_SIZE:
                    placed, reward = state.update_state(final_row, final_col, current_player)
                if not placed:
                    continue  # same player tries again

                after = State()
                after.data = np.copy(state.data)
                # The stored action is the intended one, not the perturbed cell.
                agent.store_experience(before, action, reward, after, state.end)
                agent.train_step()

                agent_net_reward += reward if current_player == 1 else -reward
                current_player = -current_player

            if episode % 100 == 0:
                print(f"Episode {episode}, Reward: {agent_net_reward:.2f}, Epsilon: {agent.epsilon:.3f}")
    finally:
        torch.save(agent.model.state_dict(), save_path)
        print(f"Model saved to {save_path}")
    return agent

if __name__ == "__main__":
    # Train first (train() saves its weights to best_model.pth by default),
    # then open the GUI so a human can play against the trained model.
    train()
    window = tk.Tk()
    app = GameGUI(window, "best_model.pth")  # keep a reference for the GUI's lifetime
    window.mainloop()