# Gomoku DQN Training on Kaggle

This notebook sets up the environment and runs the DQN training with GPU acceleration.

## Step 1: Install Dependencies

In [None]:
!pip install numpy gymnasium stable-baselines3[extra] shimmy

## Step 2: Create Project Structure
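The training script imports from a `backend` package that does not exist on Kaggle, so the cells below recreate the required source files with `%%writefile`.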

In [None]:
import os
# Create every directory that the %%writefile cells and the training run write into.
for directory in ("backend/models", "backend/algorithms", "models", "scripts"):
    os.makedirs(directory, exist_ok=True)

### Write `backend/models/board.py`

In [None]:
%%writefile backend/models/board.py
from typing import List, Optional, Tuple

class Board:
    """Square Gomoku board with forbidden-move rules for black.

    Cells hold 0 (empty), 1 (black) or 2 (white) and are addressed as
    ``board[x][y]``. Black is subject to three restrictions, checked in this
    order after an exact five-in-a-row (which always wins and is never
    forbidden): overline (six or more), double-four and double-three.
    """

    # The four line directions: horizontal, vertical, main and anti diagonal.
    _DIRS = [(1, 0), (0, 1), (1, 1), (1, -1)]

    # Recursion budget used when deciding whether each "three" of a
    # double-three can really be extended (see _is_vital_point_legal).
    _THREE_CHECK_DEPTH = 2

    def __init__(self, size: int = 15) -> None:
        """Create an empty ``size`` x ``size`` board."""
        self.size: int = size
        self.board: List[List[int]] = [[0] * size for _ in range(size)]
        self.move_count: int = 0  # stones placed through place_stone

    def reset(self) -> None:
        """Clear every cell in place and reset the move counter."""
        for row in self.board:
            for y in range(self.size):
                row[y] = 0
        self.move_count = 0

    def is_inside(self, x: int, y: int) -> bool:
        """Return True when (x, y) lies on the board."""
        return 0 <= x < self.size and 0 <= y < self.size

    def is_empty(self, x: int, y: int) -> bool:
        """Return True when the cell is empty; the caller checks bounds."""
        return self.board[x][y] == 0

    # -------- forbidden-move detection --------

    def _count_in_direction(
        self, x: int, y: int, dx: int, dy: int, player: int
    ) -> int:
        """Count consecutive ``player`` stones starting one step past (x, y)."""
        count = 0
        cx, cy = x + dx, y + dy
        while self.is_inside(cx, cy) and self.board[cx][cy] == player:
            count += 1
            cx += dx
            cy += dy
        return count

    def _line_coords(
        self, x: int, y: int, dx: int, dy: int
    ) -> List[Tuple[int, int]]:
        """Return all cells of the line through (x, y), negative end first."""
        coords: List[Tuple[int, int]] = []
        cx, cy = x, y
        while self.is_inside(cx, cy):
            coords.append((cx, cy))
            cx -= dx
            cy -= dy
        coords.reverse()
        cx, cy = x + dx, y + dy
        while self.is_inside(cx, cy):
            coords.append((cx, cy))
            cx += dx
            cy += dy
        return coords

    def _max_run_len(self, x: int, y: int, dx: int, dy: int, player: int) -> int:
        """Length of the unbroken ``player`` run through (x, y) along (dx, dy).

        Returns 0 when (x, y) is off the board or not occupied by ``player``.
        """
        if not self.is_inside(x, y) or self.board[x][y] != player:
            return 0
        return (
            1
            + self._count_in_direction(x, y, dx, dy, player)
            + self._count_in_direction(x, y, -dx, -dy, player)
        )

    def _has_five_anywhere(self, player: int) -> bool:
        """Return True if ``player`` has a run of exactly five anywhere."""
        for x in range(self.size):
            for y in range(self.size):
                if self.board[x][y] != player:
                    continue
                if any(
                    self._max_run_len(x, y, dx, dy, player) == 5
                    for dx, dy in self._DIRS
                ):
                    return True
        return False

    def _has_straight_four_including(self, x: int, y: int, player: int) -> bool:
        """Return True if (x, y) is part of a "straight four" in any direction.

        A straight four is a run of exactly four stones whose two neighbouring
        cells along the line are both on the board and empty.
        """
        if self.board[x][y] != player:
            return False
        for dx, dy in self._DIRS:
            forward = self._count_in_direction(x, y, dx, dy, player)
            backward = self._count_in_direction(x, y, -dx, -dy, player)
            if 1 + forward + backward != 4:
                continue
            # First cell beyond the run on each side.
            fx, fy = x + (forward + 1) * dx, y + (forward + 1) * dy
            bx, by = x - (backward + 1) * dx, y - (backward + 1) * dy
            if not (self.is_inside(fx, fy) and self.is_inside(bx, by)):
                continue
            if self.board[fx][fy] == 0 and self.board[bx][by] == 0:
                return True
        return False

    def _count_fours_from_move(self, x: int, y: int, player: int) -> int:
        """Count the directions in which the stone at (x, y) makes a "four".

        The stone must already be on the board. A direction counts when some
        empty cell on that line, once filled by ``player``, gives a run of
        exactly five through (x, y).
        """
        count = 0
        for dx, dy in self._DIRS:
            for qx, qy in self._line_coords(x, y, dx, dy):
                if self.board[qx][qy] != 0:
                    continue
                self.board[qx][qy] = player  # hypothetical follow-up stone
                run_len = self._max_run_len(x, y, dx, dy, player)
                self.board[qx][qy] = 0  # undo
                if run_len == 5:
                    count += 1
                    break  # one four per direction is enough
        return count

    def _find_threes_from_move(
        self, x: int, y: int, player: int
    ) -> List[List[Tuple[int, int]]]:
        """Return the "threes" created by the stone at (x, y).

        Each element corresponds to one direction and lists that three's vital
        points: empty cells on the line which, once filled, leave no exact
        five on the board and put (x, y) inside a straight four. Note that the
        straight-four test looks at every direction through (x, y), not only
        the line being scanned.
        """
        threes: List[List[Tuple[int, int]]] = []
        for dx, dy in self._DIRS:
            vital_points: List[Tuple[int, int]] = []
            for qx, qy in self._line_coords(x, y, dx, dy):
                if self.board[qx][qy] != 0:
                    continue
                self.board[qx][qy] = player  # hypothetical follow-up stone
                # A follow-up that makes five is a win, not a three.
                is_vital = (
                    not self._has_five_anywhere(player)
                    and self._has_straight_four_including(x, y, player)
                )
                self.board[qx][qy] = 0  # undo
                if is_vital:
                    vital_points.append((qx, qy))
            if vital_points:
                threes.append(vital_points)
        return threes

    def _three_is_real(self, vital_points: List[Tuple[int, int]], depth: int) -> bool:
        """Return True if at least one vital point is empty and legal for black."""
        for qx, qy in vital_points:
            if self.board[qx][qy] != 0:
                continue
            if self._is_vital_point_legal(qx, qy, depth):
                return True
        return False

    def _black_stone_is_legal(self, x: int, y: int, depth: int) -> bool:
        """Judge a black stone that is already (hypothetically) on the board.

        Order of checks: exact five (legal), overline (illegal), double-four
        (illegal), then double-three, which is illegal only when at least two
        of the threes are "real". ``depth`` bounds the recursion; when it is
        exhausted a double-three is conservatively treated as forbidden.
        """
        has_five = False
        has_overline = False
        for dx, dy in self._DIRS:
            run_len = self._max_run_len(x, y, dx, dy, 1)
            if run_len == 5:
                has_five = True
            elif run_len >= 6:
                has_overline = True

        if has_five:
            return True  # winning takes priority over every restriction
        if has_overline:
            return False
        if self._count_fours_from_move(x, y, 1) >= 2:
            return False

        threes = self._find_threes_from_move(x, y, 1)
        if len(threes) < 2:
            return True
        if depth <= 0:
            return False

        real_three_count = 0
        for vital_points in threes:
            if self._three_is_real(vital_points, depth - 1):
                real_three_count += 1
                if real_three_count >= 2:
                    return False
        # Fewer than two real threes: only an apparent double-three.
        return True

    def _is_vital_point_legal(self, px: int, py: int, depth: int) -> bool:
        """Return True if black may play the empty cell (px, py).

        The stone is placed temporarily and always removed again, even if the
        evaluation raises.
        """
        assert self.board[px][py] == 0
        self.board[px][py] = 1
        try:
            return self._black_stone_is_legal(px, py, depth)
        finally:
            self.board[px][py] = 0

    def _is_forbidden_move_black(self, x: int, y: int) -> bool:
        """Return True if a black stone on the empty cell (x, y) is forbidden."""
        # Starting one level above the budget makes the per-three checks run
        # with depth == _THREE_CHECK_DEPTH.
        return not self._is_vital_point_legal(x, y, self._THREE_CHECK_DEPTH + 1)

    def is_valid_move(self, x: int, y: int, player: Optional[int] = None) -> bool:
        """Return True if ``player`` may place a stone on (x, y).

        Args:
            x, y: Target cell.
            player: 1 (black) or 2 (white). When omitted, the mover is derived
                from ``move_count`` parity (even -> black, odd -> white),
                which is only correct when black moved first and the players
                alternated strictly.

        Forbidden-move rules apply to black only.
        """
        if not (self.is_inside(x, y) and self.is_empty(x, y)):
            return False
        if player is None:
            player = 1 if self.move_count % 2 == 0 else 2
        if player == 1 and self._is_forbidden_move_black(x, y):
            return False
        return True

    def place_stone(self, x: int, y: int, player: int) -> bool:
        """Place a stone for ``player``; return False if the move is invalid.

        Raises:
            ValueError: if ``player`` is neither 1 nor 2.
        """
        if player not in (1, 2):
            raise ValueError("player must be 1 or 2")
        # Validate against the stone actually being placed rather than the
        # parity guess, so out-of-order placement uses the right rule set.
        if not self.is_valid_move(x, y, player):
            return False
        self.board[x][y] = player
        self.move_count += 1
        return True

    def get_cell(self, x: int, y: int) -> int:
        """Return the value stored at (x, y).

        Raises:
            ValueError: if (x, y) is outside the board.
        """
        if not self.is_inside(x, y):
            raise ValueError(f"({x}, {y}) is outside the board")
        return self.board[x][y]

    def check_five_in_a_row(
        self, player: int
    ) -> Optional[List[Tuple[int, int]]]:
        """Return five consecutive cells owned by ``player``, or None.

        Lines are searched horizontally, vertically, along the main diagonal
        and finally along the anti-diagonal; the first window found is
        returned. Windows inside longer runs also match.
        """
        n = self.size
        # (dx, dy, start-x range, start-y range); the ranges keep every
        # five-cell window on the board.
        scans = (
            (1, 0, range(n - 4), range(n)),
            (0, 1, range(n), range(n - 4)),
            (1, 1, range(n - 4), range(n - 4)),
            (-1, 1, range(4, n), range(n - 4)),
        )
        for dx, dy, xs, ys in scans:
            for x in xs:
                for y in ys:
                    if self.board[x][y] != player:
                        continue  # cheap rejection before building the window
                    line = [(x + i * dx, y + i * dy) for i in range(5)]
                    if all(self.board[cx][cy] == player for cx, cy in line):
                        return line
        return None

    def is_full(self) -> bool:
        """Return True once ``move_count`` reaches the number of cells."""
        return self.move_count >= self.size * self.size

    def get_game_result(
        self, with_line: bool = False
    ):
        """Return the game state: 0 in progress, 1/2 winner, 3 draw.

        With ``with_line=True`` a ``(state, cells)`` tuple is returned, where
        ``cells`` is the winning line or ``[(-1, -1)]`` when there is none.
        Player 1 is checked before player 2.
        """
        for player in (1, 2):
            line = self.check_five_in_a_row(player)
            if line is not None:
                return (player, line) if with_line else player

        state = 3 if self.is_full() else 0
        return (state, [(-1, -1)]) if with_line else state

    def to_string(self) -> str:
        """Serialize the board as one digit per cell, x-major then y."""
        return "".join(
            str(self.board[x][y])
            for x in range(self.size)
            for y in range(self.size)
        )


### Write `backend/models/game_engine.py`

In [None]:
%%writefile backend/models/game_engine.py
from typing import List, Tuple, Dict, Any, Optional
from backend.models.board import Board

class GameEngine:
    """Turn and bookkeeping layer on top of a ``Board``.

    Tracks whose turn it is, whether the game has ended, the move history
    and win/draw counters that survive ``reset_game``.
    """

    def __init__(self, size: int = 15, first_player: int = 1) -> None:
        """Create an engine with an empty board of the given size."""
        self.board: Board = Board(size=size)

        # Remember who opens so reset_game can restore the turn order.
        self.first_player: int = first_player
        self.current_player: int = first_player

        # winner: 0 = in progress, 1/2 = that player won, 3 = draw.
        self.game_over: bool = False
        self.winner: int = 0

        # Chronological list of (x, y, player).
        self.move_history: List[Tuple[int, int, int]] = []

        # Counters accumulated over all finished games.
        self.total_games: int = 0
        self.black_wins: int = 0
        self.white_wins: int = 0
        self.draws: int = 0

    def reset_game(self, first_player: Optional[int] = None) -> None:
        """Start a new game while keeping the statistics.

        Passing ``first_player`` changes who opens this and later games.
        """
        self.board.reset()
        if first_player is not None:
            self.first_player = first_player
        self.current_player = self.first_player

        self.game_over = False
        self.winner = 0
        self.move_history.clear()

    def _switch_player(self) -> None:
        """Hand the turn to the other player (1 <-> 2)."""
        self.current_player = 3 - self.current_player

    def _settle_result(self) -> bool:
        """Query the board and record the outcome; return True if the game ended."""
        outcome = self.board.get_game_result(with_line=False)
        if outcome == 0:
            return False
        self.game_over = True
        self.winner = outcome
        self._update_statistics(outcome)
        return True

    def make_move(self, x: int, y: int) -> bool:
        """Play (x, y) for the current player.

        Returns False when the game is already over or the board rejects the
        move. The turn passes to the other player only if the game continues.
        """
        if self.game_over:
            return False

        mover = self.current_player
        if not self.board.place_stone(x, y, mover):
            return False

        self.move_history.append((x, y, mover))
        if not self._settle_result():
            self._switch_player()
        return True

    def make_move_for(self, x: int, y: int, player: int) -> bool:
        """Play (x, y) for an explicitly named player (used by self-play).

        Unlike ``make_move`` this never changes ``current_player``.

        Raises:
            ValueError: if ``player`` is neither 1 nor 2.
        """
        if self.game_over:
            return False
        if player not in (1, 2):
            raise ValueError("player must be 1 or 2")

        if not self.board.place_stone(x, y, player):
            return False

        self.move_history.append((x, y, player))
        self._settle_result()
        return True

    def _update_statistics(self, result: int) -> None:
        """Count one finished game; ``result`` 0 (in progress) is ignored."""
        if result == 0:
            return

        self.total_games += 1
        counter = {1: "black_wins", 2: "white_wins", 3: "draws"}.get(result)
        if counter is not None:
            setattr(self, counter, getattr(self, counter) + 1)

    # ---------- state queries ----------

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot dictionary of the game.

        ``board`` is the live 2-D list (not a copy); ``move_history`` is a
        shallow copy.
        """
        statistics = {
            "total_games": self.total_games,
            "black_wins": self.black_wins,
            "white_wins": self.white_wins,
            "draws": self.draws,
        }
        return {
            "board_size": self.board.size,
            "board": self.board.board,
            "current_player": self.current_player,
            "game_over": self.game_over,
            "winner": self.winner,
            "move_count": self.board.move_count,
            "move_history": list(self.move_history),
            "statistics": statistics,
        }

    def get_last_move(self) -> Optional[Tuple[int, int, int]]:
        """Return the most recent (x, y, player), or None before any move."""
        return self.move_history[-1] if self.move_history else None

    def debug_print_board(self) -> None:
        """Print the board row by row (y outer, x inner) for debugging."""
        symbols = {0: ".", 1: "●"}  # any other value is drawn as white
        size = self.board.size
        for y in range(size):
            print(" ".join(symbols.get(self.board.board[x][y], "○") for x in range(size)))
        print(f"current_player = {self.current_player}, game_over = {self.game_over}, winner = {self.winner}")

### Write `backend/algorithms/mcts_ai.py` (Dependency for Classic AI)

In [None]:
%%writefile backend/algorithms/mcts_ai.py
import copy
import random
from typing import List, Tuple
from backend.models.board import Board

def get_neighbor_moves(board: Board, distance: int = 2) -> List[Tuple[int, int]]:
    """Return the valid moves lying within ``distance`` of any stone.

    On a board with no moves yet, the single centre cell is returned. The
    neighbourhood of a stone is the square of cells whose x and y offsets are
    both at most ``distance``, clipped to the board.

    Each cell is validated at most once: neighbourhoods of nearby stones
    overlap heavily and ``board.is_valid_move`` can be expensive, so cells
    already examined are skipped.
    """
    size = board.size
    if board.move_count == 0:
        return [(size // 2, size // 2)]

    grid = board.board
    moves = set()
    examined = set()  # cells already validated, accepted or not

    for x in range(size):
        for y in range(size):
            if grid[x][y] == 0:
                continue
            for nx in range(max(0, x - distance), min(size, x + distance + 1)):
                for ny in range(max(0, y - distance), min(size, y + distance + 1)):
                    cell = (nx, ny)
                    if cell in examined:
                        continue
                    examined.add(cell)
                    if grid[nx][ny] == 0 and board.is_valid_move(nx, ny):
                        moves.add(cell)

    return list(moves)


### Write `backend/algorithms/classic_ai.py`

In [None]:
%%writefile backend/algorithms/classic_ai.py
import json
import math
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from backend.algorithms.mcts_ai import get_neighbor_moves
from backend.models.board import Board


@dataclass
class SearchMetrics:
    """Lightweight search statistics for benchmarking."""

    # Wall-clock duration of the search, in milliseconds.
    elapsed_ms: float
    # Number of candidate positions that were actually evaluated.
    explored_nodes: int
    # Number of candidate moves the search started from.
    candidate_moves: int


def load_ai_config(path: str) -> Dict[str, Any]:
    """Read a hyperparameter configuration from a UTF-8 JSON file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    config_path = Path(path)
    if config_path.exists():
        with config_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"Config file not found: {config_path}")


def random_move(board: Board, distance: int = 2) -> Tuple[int, int]:
    """Pick a uniformly random move from the neighbourhood of existing stones.

    Candidates come from ``get_neighbor_moves``; when it yields nothing the
    board centre is returned instead.
    """
    nearby = get_neighbor_moves(board, distance=distance)
    if nearby:
        return random.choice(nearby)
    middle = board.size // 2
    return (middle, middle)


class GreedyAgent:
    """One-ply greedy agent that scores only the lines through each candidate.

    For every candidate cell the agent adds an attack score (what the cell
    gives the mover) and a defence score (what the same cell would give the
    opponent), then picks randomly among the top-scoring cells.
    """

    _DIRECTIONS = ((1, 0), (0, 1), (1, 1), (1, -1))

    # run length -> (score with both ends open, score otherwise)
    _ATTACK_SCORES = {4: (10000, 1000), 3: (1000, 100), 2: (100, 10)}
    _DEFENCE_SCORES = {4: (9000, 900), 3: (900, 90)}
    _ATTACK_FIVE = 100_000
    _DEFENCE_FIVE = 90_000

    def __init__(self, distance: int = 2):
        """``distance`` is the neighbourhood radius used to collect candidates."""
        self.distance = distance
        self.last_metrics: Optional[SearchMetrics] = None

    @classmethod
    def _stone_score(cls, board, x, y, stone, five_score, run_scores):
        """Put ``stone`` on (x, y) and score the four lines through it.

        The stone is left on the board; the caller is responsible for
        clearing the cell afterwards. A run of five or more earns
        ``five_score``; shorter runs are looked up in ``run_scores``, where
        the higher value applies only when the cell past each end of the run
        is on the board and empty.
        """
        grid = board.board
        grid[x][y] = stone
        total = 0
        for dx, dy in cls._DIRECTIONS:
            count = 1
            open_ends = 0
            for sx, sy in ((dx, dy), (-dx, -dy)):
                tx, ty = x + sx, y + sy
                while board.is_inside(tx, ty) and grid[tx][ty] == stone:
                    count += 1
                    tx += sx
                    ty += sy
                if board.is_inside(tx, ty) and grid[tx][ty] == 0:
                    open_ends += 1

            if count >= 5:
                total += five_score
            elif count in run_scores:
                both_open, otherwise = run_scores[count]
                total += both_open if open_ends == 2 else otherwise
        return total

    def get_move(self, board: Board, player: int) -> Tuple[int, int]:
        """Return the best-scoring move for ``player``.

        Falls back to the board centre when no candidate is playable.
        ``last_metrics`` is refreshed on every call, including the fallback.
        """
        start = time.perf_counter()
        candidates = get_neighbor_moves(board, distance=self.distance)
        opponent = 3 - player

        best_score = -math.inf
        best_moves = []
        explored = 0

        for mx, my in candidates:
            if not board.is_valid_move(mx, my):
                continue
            explored += 1

            try:
                score = self._stone_score(
                    board, mx, my, player, self._ATTACK_FIVE, self._ATTACK_SCORES
                )
                score += self._stone_score(
                    board, mx, my, opponent, self._DEFENCE_FIVE, self._DEFENCE_SCORES
                )
            finally:
                # Always clear the trial stone, even if scoring raised.
                board.board[mx][my] = 0

            if score > best_score:
                best_score = score
                best_moves = [(mx, my)]
            elif score == best_score:
                best_moves.append((mx, my))

        elapsed_ms = (time.perf_counter() - start) * 1000
        self.last_metrics = SearchMetrics(
            elapsed_ms=elapsed_ms, explored_nodes=explored, candidate_moves=len(candidates)
        )
        if not best_moves:
            return (board.size // 2, board.size // 2)
        return random.choice(best_moves)


### Write `backend/algorithms/qlearning_ai.py`

In [None]:
%%writefile backend/algorithms/qlearning_ai.py
import os
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from typing import Tuple, Optional, Any
from stable_baselines3 import DQN
from backend.models.game_engine import GameEngine
from backend.models.board import Board
from backend.algorithms.classic_ai import GreedyAgent

class GomokuEnv(gym.Env):
    """Gymnasium environment in which the agent plays black against a scripted AI.

    Observation: the board as a ``(board_size, board_size)`` float32 array
    holding 0 (empty), 1 (agent) and 2 (opponent).
    Action: a flat cell index, decoded as ``divmod(action, board_size)``.
    """

    metadata = {"render_modes": ["human"]}

    _DIRECTIONS = ((1, 0), (0, 1), (1, 1), (1, -1))
    # Shaping reward per run length for the agent's own lines and for
    # opponent lines that the move interrupts.
    _OWN_RUN_REWARD = {5: 50.0, 4: 15.0, 3: 5.0, 2: 1.0}
    _BLOCK_RUN_REWARD = {4: 15.0, 3: 5.0}

    def __init__(self, board_size=15, opponent_ai=None, reward_type='heuristic', invalid_penalty=-50.0):
        """
        Args:
            board_size: Side length of the board.
            opponent_ai: Object with ``get_move(board, player)``; defaults to
                a ``GreedyAgent``.
            reward_type: ``'heuristic'`` enables shaping rewards on
                non-terminal moves; any other value gives 0 for them.
            invalid_penalty: Reward returned for a rejected move.
        """
        super().__init__()
        self.board_size = board_size
        self.engine = GameEngine(size=board_size)
        self.opponent_ai = opponent_ai if opponent_ai is not None else GreedyAgent()
        self.reward_type = reward_type
        self.invalid_penalty = invalid_penalty
        self.observation_space = spaces.Box(
            low=0, high=2, shape=(board_size, board_size), dtype=np.float32
        )
        self.action_space = spaces.Discrete(board_size * board_size)

    def reset(self, seed=None, options=None):
        """Start a new game and return ``(observation, info)``."""
        super().reset(seed=seed)
        self.engine.reset_game()
        return self._get_obs(), {}

    def step(self, action: int):
        """Play the agent's move, then the opponent's reply.

        Returns the Gymnasium 5-tuple ``(obs, reward, terminated, truncated,
        info)``. A rejected move ends the episode with ``invalid_penalty``.
        Winning yields +100, losing -100 and any other ending 0.
        """
        x, y = divmod(int(action), self.board_size)

        # make_move validates the move itself, so the forbidden-move check
        # runs once per step. It also returns False once the game is over,
        # which is reported as an invalid action as well.
        if not self.engine.make_move(x, y):
            return self._get_obs(), self.invalid_penalty, True, False, {"error": "Invalid"}

        if self.engine.game_over:
            reward = 100.0 if self.engine.winner == 1 else 0.0
            return self._get_obs(), reward, True, False, {}

        reward = 0.0
        if self.reward_type == 'heuristic':
            reward += self._calculate_heuristic_reward(x, y, player=1)
            # Small bonus for staying near the centre during the opening.
            if self.engine.board.move_count < 10:
                center = self.board_size // 2
                if abs(x - center) + abs(y - center) < 4:
                    reward += 0.5

        # Opponent (player 2) replies; an x of -1 means it has no move.
        opp_x, opp_y = self.opponent_ai.get_move(self.engine.board, 2)
        if opp_x != -1:
            self.engine.make_move(opp_x, opp_y)

        if self.engine.game_over:
            reward = -100.0 if self.engine.winner == 2 else 0.0
            return self._get_obs(), reward, True, False, {}

        return self._get_obs(), reward, False, False, {}

    def _run_through(self, x, y, dx, dy, stone):
        """Length of the ``stone`` run through (x, y), counting (x, y) itself.

        Only the neighbours along the line are inspected; the content of
        (x, y) is never read.
        """
        board = self.engine.board
        count = 1
        for sx, sy in ((dx, dy), (-dx, -dy)):
            tx, ty = x + sx, y + sy
            while board.is_inside(tx, ty) and board.board[tx][ty] == stone:
                count += 1
                tx += sx
                ty += sy
        return count

    def _calculate_heuristic_reward(self, x, y, player):
        """Shaping reward for the stone ``player`` just placed on (x, y).

        Adds a reward for each own run through the cell and for each run the
        opponent would have had through it. The board is not modified.
        """
        score = 0.0
        opponent = 3 - player

        for dx, dy in self._DIRECTIONS:
            score += self._OWN_RUN_REWARD.get(self._run_through(x, y, dx, dy, player), 0.0)

        # Blocking rewards: because the scan never reads (x, y), no
        # temporary opponent stone has to be written to the board.
        for dx, dy in self._DIRECTIONS:
            blocked = self._run_through(x, y, dx, dy, opponent)
            if blocked >= 5:
                score += 50.0
            else:
                score += self._BLOCK_RUN_REWARD.get(blocked, 0.0)

        return score

    def _get_obs(self):
        """Return a float32 copy of the board."""
        return np.array(self.engine.board.board, dtype=np.float32)

class QLearningAgent:
    """Move provider backed by a saved Stable-Baselines3 DQN model.

    When no model could be loaded, or the model proposes an invalid move,
    the agent falls back to ``random_move``.
    """

    def __init__(self, model_path: str = "models/dqn_gomoku", train_mode: bool = False):
        """
        Args:
            model_path: Path of the saved model without the ``.zip`` suffix.
            train_mode: Stored on the instance; not otherwise used here.
        """
        self.model_path = model_path
        self.train_mode = train_mode
        self.model: Optional[DQN] = None
        self.load_model()

    def load_model(self):
        """Load ``<model_path>.zip`` if it exists.

        A missing file is not an error. A file that fails to load leaves
        ``self.model`` unchanged and emits a ``RuntimeWarning`` so the
        silent fallback to random play is visible.
        """
        import warnings

        archive = self.model_path + ".zip"
        if not os.path.exists(archive):
            return
        try:
            self.model = DQN.load(self.model_path)
        except Exception as exc:  # best effort: keep playing without a model
            warnings.warn(
                f"Failed to load DQN model from {archive}: {exc}",
                RuntimeWarning,
            )

    def get_move(self, board: Board, player: int) -> Tuple[int, int]:
        """Return the model's move for the given board, or a random fallback.

        The flat action is decoded with ``board.size``, matching
        ``GomokuEnv.step``. ``player`` is accepted for interface
        compatibility; the observation passed to the model is the raw board.
        """
        from backend.algorithms.classic_ai import random_move

        if self.model is None:
            return random_move(board)

        obs = np.array(board.board, dtype=np.float32)
        action, _ = self.model.predict(obs, deterministic=True)
        x, y = divmod(int(action), board.size)
        if board.is_valid_move(x, y):
            return (x, y)
        return random_move(board)


### Write Training Script

In [None]:
%%writefile scripts/train_dqn_v2.py
import sys
import os
import torch

# Add project root to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.algorithms.qlearning_ai import GomokuEnv
from backend.algorithms.classic_ai import GreedyAgent
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback

def train_dqn_v2(
    total_timesteps: int = 2_000_000,
    model_path: str = "models/dqn_gomoku_v2",
    board_size: int = 15,
    checkpoint_dir: str = './data/models/checkpoints_v2/',
):
    """Train a DQN agent against ``GreedyAgent`` and save the final model.

    Args:
        total_timesteps: Number of environment steps to train for.
        model_path: Where the final model is saved (``.zip`` is appended).
        board_size: Side length of the Gomoku board.
        checkpoint_dir: Directory receiving a checkpoint every 100,000 steps.

    Calling the function without arguments reproduces the original
    hard-coded configuration.
    """
    # A bare filename has no directory part, and os.makedirs("") would raise.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    print(f"Target Timesteps: {total_timesteps}")

    opponent = GreedyAgent(distance=2)
    env = GomokuEnv(
        board_size=board_size,
        opponent_ai=opponent,
        reward_type='heuristic',
        invalid_penalty=-200.0
    )

    policy_kwargs = dict(net_arch=[512, 512, 256])

    model = DQN(
        "MlpPolicy",
        env,
        verbose=1,
        device=device,
        batch_size=512,
        learning_rate=1e-4,
        buffer_size=100_000,
        exploration_fraction=0.1,
        exploration_initial_eps=1.0,
        exploration_final_eps=0.05,
        policy_kwargs=policy_kwargs
    )

    checkpoint_callback = CheckpointCallback(
        save_freq=100_000,
        save_path=checkpoint_dir,
        name_prefix='dqn_v2'
    )

    print("Training started on Kaggle GPU...")
    # NOTE(review): no tensorboard_log is passed to DQN above, so tb_log_name
    # presumably has no effect -- confirm against the SB3 documentation.
    model.learn(
        total_timesteps=total_timesteps,
        callback=checkpoint_callback,
        tb_log_name="run_v2_kaggle"
    )

    model.save(model_path)
    print(f"Training completed. Model saved to {model_path}.zip")

if __name__ == "__main__":
    train_dqn_v2()

## Step 3: Run Training

In [None]:
!python scripts/train_dqn_v2.py

## Step 4: Download Model
Run this cell to zip the model and checkpoints for download.

In [None]:
!zip -r gomoku_models.zip models data