In [1]:
!pip install gymnasium numpy torch ray

Collecting ray
  Downloading ray-2.47.1-cp311-cp311-manylinux2014_x86_64.whl.metadata (20 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)


In [59]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from collections import deque
import numpy as np
import random
import os
import time
import math
import gymnasium as gym
from gymnasium import spaces
import datetime
import shutil
import gc
from typing import List, Dict, Tuple

In [60]:

# ========== PYTORCH SETUP ==========
# เพิ่มความเร็ว
torch.backends.cudnn.benchmark = True
# สำหรับ GPU ใหม่
torch.set_float32_matmul_precision('high')


In [61]:
# ========== ENHANCED METRICS & LOGGING ==========
class MetricsLogger:
    def __init__(self):
        self.iteration_metrics = []

    def log_iteration_start(self, iteration: int):
        self.current_iter = {
            'iteration': iteration,
            'start_time': time.perf_counter(),
            'self_play_time': 0,
            'training_time': 0,
            'evaluation_time': 0
        }

    def log_self_play_metrics(self, stats: List[Dict], examples_count: int, self_play_time: float):
        self.current_iter['self_play_time'] = self_play_time
        self.current_iter['examples_count'] = examples_count
        self.current_iter['game_stats'] = stats

    def log_training_metrics(self, epoch_metrics: List[Dict], training_time: float):
        self.current_iter['training_time'] = training_time
        self.current_iter['epoch_metrics'] = epoch_metrics

    def log_evaluation_metrics(self, eval_results: Dict, eval_time: float):
        self.current_iter['evaluation_time'] = eval_time
        self.current_iter['eval_results'] = eval_results

    def finish_iteration(self):
        self.current_iter['total_time'] = time.perf_counter() - self.current_iter['start_time']
        self.iteration_metrics.append(self.current_iter.copy())
        return self.current_iter

In [62]:
# ========== NEURAL NETWORK ==========
class ResNetBlock(nn.Module):
    def __init__(self, num_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)

class MakNeebNet(nn.Module):
    def __init__(self, board_size=8, num_res_blocks=4, num_channels=32, action_size=4096):
        super(MakNeebNet, self).__init__()
        self.conv_in = nn.Conv2d(1, num_channels, kernel_size=3, padding=1)
        self.bn_in = nn.BatchNorm2d(num_channels)
        self.res_blocks = nn.ModuleList([ResNetBlock(num_channels) for _ in range(num_res_blocks)])

        # Policy head
        self.policy_conv = nn.Conv2d(num_channels, 2, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * board_size * board_size, action_size)

        # Value head
        self.value_conv = nn.Conv2d(num_channels, 1, kernel_size=1)
        # BatchNorm is generally more stable than LayerNorm for this type of input
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(1 * board_size * board_size, 128)
        self.value_fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = x.unsqueeze(1) # Add channel dimension
        x = F.relu(self.bn_in(self.conv_in(x)))

        for block in self.res_blocks:
            x = block(x)

        # Policy head
        policy = F.relu(self.policy_bn(self.policy_conv(x)))
        policy = policy.view(policy.size(0), -1)
        policy = self.policy_fc(policy)

        # Value head
        value = F.relu(self.value_bn(self.value_conv(x)))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = torch.tanh(self.value_fc2(value))

        return F.log_softmax(policy, dim=1), value


In [63]:
# ========== ENVIRONMENT ==========
class MakNeebRLEnv(gym.Env):
    def __init__(self):
        super(MakNeebRLEnv, self).__init__()
        self.board_size = 8
        self.action_space = spaces.Discrete(self.board_size * self.board_size * self.board_size * self.board_size)
        self.observation_space = spaces.Box(low=-1, high=1, shape=(self.board_size, self.board_size), dtype=np.int8)
        self.reset()

    def _encode_action(self, from_row, from_col, to_row, to_col):
        return from_row * (8*8*8) + from_col * (8*8) + to_row * 8 + to_col

    def _decode_action(self, action):
        from_row = action // (8*8*8)
        action %= (8*8*8)
        from_col = action // (8*8)
        action %= (8*8)
        to_row = action // 8
        to_col = action % 8
        return (from_row, from_col), (to_row, to_col)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.board[0, :] = 1
        self.board[7, :] = -1
        self.current_player = 1
        self.turns_without_capture = 0
        self.max_turns_without_capture = 50
        return self.board.copy(), {"current_player": self.current_player}

    def get_legal_actions(self):
        legal_actions = []
        for r_from in range(self.board_size):
            for c_from in range(self.board_size):
                if self.board[r_from, c_from] == self.current_player:
                    for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                        for i in range(1, self.board_size):
                            r_to, c_to = r_from + dr * i, c_from + dc * i
                            if not (0 <= r_to < self.board_size and 0 <= c_to < self.board_size):
                                break
                            if self.board[r_to, c_to] != 0:
                                break
                            action = self._encode_action(r_from, c_from, r_to, c_to)
                            legal_actions.append(action)
        return legal_actions

    def step(self, action):
        legal_actions = self.get_legal_actions()
        if action not in legal_actions:
            # Assign a large penalty for an illegal move and end the game
            return self.board.copy(), -10.0, True, False, {"error": "Illegal move"}

        (from_row, from_col), (to_row, to_col) = self._decode_action(action)
        self.board[to_row, to_col] = self.current_player
        self.board[from_row, from_col] = 0

        captured_count = self._check_and_capture(to_row, to_col)
        reward = captured_count * 2.0

        if captured_count > 0:
            self.turns_without_capture = 0
        else:
            self.turns_without_capture += 1

        done, winner = self.get_game_status()
        if done:
            if winner == self.current_player:
                reward += 20.0
            elif winner == -self.current_player:
                reward -= 20.0
            elif winner == 0: # Draw
                reward += 5.0

        self.current_player *= -1
        info = {"current_player": self.current_player, "captured": captured_count > 0}
        return self.board.copy(), float(reward), done, False, info

    def _check_and_capture(self, r, c):
        total_captured = 0
        opponent = -self.current_player

        # Check both horizontal and vertical lines passing through (r, c)
        for dr, dc in [(1, 0), (0, 1)]:
            line_pieces = []
            if dr == 1: # Vertical line
                for i in range(self.board_size):
                    if self.board[i, c] != 0:
                        line_pieces.append({'player': self.board[i, c], 'pos': (i, c)})
            else: # Horizontal line
                for i in range(self.board_size):
                    if self.board[r, i] != 0:
                        line_pieces.append({'player': self.board[r, i], 'pos': (r, i)})

            if len(line_pieces) < 2:
                continue

            captured_in_this_line = set()
            my_indices = [i for i, p in enumerate(line_pieces) if p['player'] == self.current_player]

            if len(my_indices) >= 2:
                for i in range(len(my_indices) - 1):
                    start_idx, end_idx = my_indices[i], my_indices[i+1]
                    if end_idx > start_idx + 1:
                        # Check if all pieces between are opponent's pieces
                        is_all_opponent = all(line_pieces[k]['player'] == opponent
                                              for k in range(start_idx + 1, end_idx))
                        if is_all_opponent:
                            for k in range(start_idx + 1, end_idx):
                                captured_in_this_line.add(line_pieces[k]['pos'])

            # Apply captures
            for pos_r, pos_c in captured_in_this_line:
                if self.board[pos_r, pos_c] == opponent: # Ensure we don't accidentally remove our own piece
                    self.board[pos_r, pos_c] = 0
                    total_captured += 1

        return total_captured


    def get_game_status(self):
        player1_pieces = np.sum(self.board == 1)
        player_minus_1_pieces = np.sum(self.board == -1)

        if player1_pieces == 0:
            return True, -1
        if player_minus_1_pieces == 0:
            return True, 1
        if not self.get_legal_actions_for_player(-self.current_player):
            # The next player has no legal moves, so the current player wins.
            return True, self.current_player

        if self.turns_without_capture >= self.max_turns_without_capture:
            if player1_pieces > player_minus_1_pieces:
                return True, 1
            elif player_minus_1_pieces > player1_pieces:
                return True, -1
            else: # Draw
                return True, 0

        return False, 0 # Game not over

    def get_legal_actions_for_player(self, player):
        original_player = self.current_player
        self.current_player = player
        legal_actions = self.get_legal_actions()
        self.current_player = original_player
        return legal_actions

    def copy(self):
        new_env = MakNeebRLEnv()
        new_env.board = self.board.copy()
        new_env.current_player = self.current_player
        new_env.turns_without_capture = self.turns_without_capture
        return new_env


In [64]:
# ========== MCTS ==========
class MCTSNode:
    def __init__(self, parent=None, prior_p=1.0):
        self.parent = parent
        self.children = {}
        self.visit_count = 0
        self.value_sum = 0
        self.prior_p = prior_p

    def expand(self, action_priors: List[Tuple[int, float]]):
        for action, prior in action_priors:
            if action not in self.children:
                self.children[action] = MCTSNode(parent=self, prior_p=prior)

    def select(self, c_puct: float) -> Tuple[int, 'MCTSNode']:
        best_score = -float('inf')
        best_action = -1
        best_child = None

        for action, child in self.children.items():
            score = self._get_ucb_score(child, c_puct)
            if score > best_score:
                best_score, best_action, best_child = score, action, child

        return best_action, best_child

    def _get_ucb_score(self, child: 'MCTSNode', c_puct: float) -> float:
        # Q-value is from the child's perspective
        q_value = -child.value()

        # U-value encourages exploration
        u_value = c_puct * child.prior_p * math.sqrt(self.visit_count) / (1 + child.visit_count)

        return q_value + u_value

    def value(self) -> float:
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0

    def backpropagate(self, value: float):
        self.visit_count += 1
        # The value is always from the perspective of the current player at this node
        self.value_sum += value
        if self.parent:
            # The value must be inverted for the parent, as it's the other player's turn
            self.parent.backpropagate(-value)

class MCTS:
    def __init__(self, model: nn.Module, device: torch.device, c_puct: float = 1.5, num_simulations: int = 50):
        self.model = model
        self.device = device
        self.c_puct = c_puct
        self.num_simulations = num_simulations

    @torch.no_grad()
    def search(self, env: MakNeebRLEnv) -> np.ndarray:
        root = MCTSNode()

        # At the start of a search, the root represents the state for the current player
        initial_board_tensor = torch.tensor(
            env.board * env.current_player, dtype=torch.float32
        ).unsqueeze(0).to(self.device)

        log_policy, value_tensor = self.model(initial_board_tensor)
        policy = torch.exp(log_policy).squeeze(0).cpu().numpy()
        value = value_tensor.item()

        legal_actions = env.get_legal_actions()
        if not legal_actions:
            return np.zeros(env.action_space.n)

        action_priors = [(action, policy[action]) for action in legal_actions]
        root.expand(action_priors)

        for _ in range(self.num_simulations):
            node = root
            sim_env = env.copy()

            # --- Selection ---
            while node.children:
                action, node = node.select(self.c_puct)
                sim_env.step(action)

            # --- Expansion & Evaluation ---
            done, winner = sim_env.get_game_status()
            value = 0.0

            if not done:
                board_tensor = torch.tensor(
                    sim_env.board * sim_env.current_player, dtype=torch.float32
                ).unsqueeze(0).to(self.device)

                log_policy, value_tensor = self.model(board_tensor)
                policy = torch.exp(log_policy).squeeze(0).cpu().numpy()
                value = value_tensor.item()

                legal_actions_sim = sim_env.get_legal_actions()
                if legal_actions_sim:
                    action_priors_sim = [(action, policy[action]) for action in legal_actions_sim]
                    node.expand(action_priors_sim)
            else:
                # Terminal node, value is determined by game result
                if winner == sim_env.current_player:
                    value = 1.0
                elif winner == -sim_env.current_player:
                    value = -1.0
                # if winner is 0 (draw), value remains 0.0

            # --- Backpropagation ---
            # Value is from the perspective of the player at the expanded node.
            # backpropagate handles negating it for parents.
            node.backpropagate(value)

        # Return visit count distribution as the policy
        visit_counts = np.zeros(env.action_space.n)
        for action, child in root.children.items():
            visit_counts[action] = child.visit_count

        if np.sum(visit_counts) == 0:
            return np.ones(env.action_space.n) / env.action_space.n

        action_probs = visit_counts / np.sum(visit_counts)
        return action_probs


In [66]:
# ========== PARALLEL SELF-PLAY WORKER ==========

def play_game(model_state_dict: dict, mcts_sims: int, c_puct: float, temperature: float, noise_alpha: float) -> Tuple[List[Tuple], Dict]:
    """
    Plays a single game of self-play on the CPU.
    This function is executed by each worker process.
    """
    # 1. สร้างโมเดลและ MCTS ภายใน Worker แต่ละตัว (ทำงานบน CPU)
    device = torch.device('cpu')
    model = MakNeebNet().to(device)
    model.load_state_dict(model_state_dict)
    model.eval()
    mcts = MCTS(model, device, c_puct, mcts_sims)
    env = MakNeebRLEnv()

    game_history = []
    policy_entropies = []

    # 2. วนลูปเล่นเกมจนจบ
    while True:
        board_state = env.board * env.current_player
        action_probs = mcts.search(env)

        # เพิ่ม Dirichlet noise เพื่อการสำรวจ
        legal_actions = env.get_legal_actions()
        if legal_actions:
            noise = np.random.dirichlet([noise_alpha] * len(legal_actions))
            for i, action in enumerate(legal_actions):
                action_probs[action] = 0.75 * action_probs[action] + 0.25 * noise[i]
            action_probs /= np.sum(action_probs)

        # ใช้ Temperature เพื่อควบคุมการเลือกตาเดิน
        temp = temperature if len(game_history) < 15 else 0
        if temp > 0:
            temp_probs = action_probs ** (1.0 / temp)
            action = np.random.choice(len(temp_probs), p=temp_probs/np.sum(temp_probs))
        else:
            action = np.argmax(action_probs)

        valid_probs = action_probs[action_probs > 1e-8]
        policy_entropies.append(-np.sum(valid_probs * np.log(valid_probs)))

        game_history.append((board_state, action_probs, env.current_player))

        _, _, done, _, _ = env.step(action)
        if done:
            break

    # 3. สร้างข้อมูลการฝึกหลังจากเกมจบ
    _, winner = env.get_game_status()
    training_examples = []
    for board, probs, player in game_history:
        reward = 0.0 if winner == 0 else (1.0 if winner == player else -1.0)
        training_examples.append((board, probs, reward))

    game_stats = {
        'winner': winner,
        'game_length': len(game_history),
        'avg_policy_entropy': np.mean(policy_entropies) if policy_entropies else 0.0,
    }

    return training_examples, game_stats

class SelfPlayWorker(mp.Process):
    """
    Worker process ที่รับผิดชอบการเล่นเกม (เรียกใช้ play_game)
    """
    def __init__(self, job_queue, data_queue, model_state_dict, mcts_params):
        super().__init__()
        self.job_queue = job_queue
        self.data_queue = data_queue
        self.model_state_dict = model_state_dict
        self.mcts_params = mcts_params

    def run(self):
        while True:
            game_idx = self.job_queue.get()
            if game_idx is None:  # สัญญาณให้หยุดทำงาน
                break

            examples, stats = play_game(self.model_state_dict, **self.mcts_params)
            self.data_queue.put((examples, stats))

def parallel_self_play(model, num_games, num_workers, mcts_params):
    """
    ฟังก์ชันหลักสำหรับจัดการ Self-Play แบบ Parallel
    """
    model.eval()
    # ส่ง state_dict แทน model object ทั้งหมดเพื่อความปลอดภัยใน multiprocessing
    model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

    job_queue = mp.Queue()
    data_queue = mp.Queue()

    for i in range(num_games):
        job_queue.put(i)
    for _ in range(num_workers):
        job_queue.put(None) # สัญญาณให้ Worker หยุด

    workers = [SelfPlayWorker(job_queue, data_queue, model_state_dict, mcts_params) for _ in range(num_workers)]

    print(f"🚀 Launching {num_workers} parallel workers for {num_games} games...")
    for w in workers:
        w.start()

    all_examples, all_stats = [], []
    for _ in range(num_games):
        examples, stats = data_queue.get()
        all_examples.extend(examples)
        all_stats.append(stats)
        if (len(all_stats)) % 10 == 0:
             print(f"  ...collected results from {len(all_stats)}/{num_games} games")


    for w in workers:
        w.join()

    return all_examples, all_stats

In [68]:
# ========== TRAINING & EVALUATION ==========
def enhanced_train_model(model, replay_buffer, optimizer, device, batch_size, num_epochs, entropy_weight):
    model.train()
    buffer_list = list(replay_buffer)
    epoch_metrics_list = []

    for epoch in range(num_epochs):
        random.shuffle(buffer_list)
        epoch_losses = {'total': 0, 'policy': 0, 'value': 0}
        num_batches = 0

        for i in range(0, len(buffer_list), batch_size):
            batch = buffer_list[i:i + batch_size]
            if not batch: continue
            boards, target_policies, target_values = zip(*batch)
            boards = torch.tensor(np.array(boards), dtype=torch.float32).to(device)
            target_policies = torch.tensor(np.array(target_policies), dtype=torch.float32).to(device)
            target_values = torch.tensor(target_values, dtype=torch.float32).unsqueeze(1).to(device)

            optimizer.zero_grad()
            log_policies, predicted_values = model(boards)
            policy_loss = -(target_policies * log_policies).sum(dim=1).mean()
            value_loss = F.mse_loss(predicted_values, target_values)
            policy_probs_clamped = torch.clamp(torch.exp(log_policies), 1e-8, 1.0)
            entropy = -(policy_probs_clamped * torch.log(policy_probs_clamped)).sum(dim=1).mean()
            total_loss = policy_loss + value_loss - entropy * entropy_weight
            total_loss.backward()
            optimizer.step()

            epoch_losses['total'] += total_loss.item()
            epoch_losses['policy'] += policy_loss.item()
            epoch_losses['value'] += value_loss.item()
            num_batches += 1

        if num_batches > 0:
            for key in epoch_losses: epoch_losses[key] /= num_batches
            epoch_metrics_list.append(epoch_losses)
            print(f"  Epoch {epoch + 1}/{num_epochs} | Avg Loss: {epoch_losses['total']:.4f} [P: {epoch_losses['policy']:.4f}, V: {epoch_losses['value']:.4f}]")

    overall_metrics = {k: np.mean([e[k] for e in epoch_metrics_list]) for k in epoch_metrics_list[0]} if epoch_metrics_list else {}
    return overall_metrics

def evaluate_against_baseline(model, device, baseline_agent_func, num_games=20):
    model.eval()
    mcts = MCTS(model, device, c_puct=1.0, num_simulations=40)
    wins, draws, losses = 0, 0, 0
    for _ in range(num_games):
        env = MakNeebRLEnv()
        model_is_player1 = random.choice([True, False])
        while True:
            done, winner = env.get_game_status()
            if done:
                if winner == 0: draws += 1
                elif (winner == 1 and model_is_player1) or (winner == -1 and not model_is_player1): wins += 1
                else: losses += 1
                break
            is_model_turn = (env.current_player == 1 and model_is_player1) or (env.current_player == -1 and not model_is_player1)
            action = np.argmax(mcts.search(env)) if is_model_turn else baseline_agent_func(env)
            if action not in env.get_legal_actions(): action = random.choice(env.get_legal_actions())
            env.step(action)
    return wins, draws, losses

def enhanced_evaluate_model(model, device, num_games_per_opponent=50):
    print("\n🏆 === MODEL EVALUATION ===")
    def run_eval(agent_name, agent_func):
        print(f"Evaluating against {agent_name} Agent ({num_games_per_opponent} games)...")
        wins, draws, losses = evaluate_against_baseline(model, device, agent_func, num_games_per_opponent)
        total = wins + draws + losses
        winrate = wins / total if total > 0 else 0
        print(f"  vs {agent_name}: {wins}W - {draws}D - {losses}L (Win Rate: {winrate:.1%})")
        return (wins + 0.5 * draws) / total if total > 0 else 0

    random_score = run_eval("Random", lambda env: random.choice(env.get_legal_actions()) if env.get_legal_actions() else 0)
    # Re-implement greedy logic here for simplicity
    def greedy_agent(env: MakNeebRLEnv):
        actions = env.get_legal_actions()
        if not actions: return 0
        best_action, max_captures = actions[0], -1
        for action in actions:
            test_env = env.copy()
            board_before = np.sum(np.abs(test_env.board))
            test_env.step(action)
            captures = board_before - np.sum(np.abs(test_env.board))
            if captures > max_captures: max_captures, best_action = captures, action
        return best_action if max_captures > 0 else random.choice(actions)

    greedy_score = run_eval("Greedy", greedy_agent)
    strength = (random_score * 0.4 + greedy_score * 0.6) * 1000
    print(f"  Strength Score: {strength:.1f} / 1000")
    return {'strength': strength, 'random_winrate': random_score, 'greedy_winrate': greedy_score}


In [69]:
# ========== CHECKPOINTING  ==========
def save_checkpoint(iteration, model, optimizer, replay_buffer, path):
    torch.save({
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'replay_buffer': list(replay_buffer)
    }, path)
    print(f"💾 Checkpoint saved to {path}")

def load_checkpoint(model, optimizer, path, device):
    if not os.path.exists(path):
        print("No checkpoint found, starting from scratch.")
        return 0, deque(maxlen=30000)
    try:
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        replay_buffer = deque(checkpoint['replay_buffer'], maxlen=30000)
        start_iter = checkpoint.get('iteration', -1) + 1
        print(f"✅ Checkpoint loaded from iteration {start_iter - 1}. Replay buffer size: {len(replay_buffer)}")
        return start_iter, replay_buffer
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting from scratch.")
        return 0, deque(maxlen=30000)


In [70]:
# ========== MAIN TRAINING LOOP ==========
def main():
    # Hyperparameters
    TOTAL_GAMES = 10000
    GAMES_PER_ITER = 48
    NUM_WORKERS = 6  # *** จำนวน Worker ที่จะใช้ *** (ปรับได้ระหว่าง 4-8)
    MCTS_SIMS = 80
    C_PUCT = 1.5
    INITIAL_LR = 0.001
    BATCH_SIZE_TRAIN = 256
    EPOCHS = 3
    REPLAY_SIZE = 40000
    CHECKPOINT_PATH = "/content/drive/MyDrive/AlphaZero_Backups"
    ENTROPY_WEIGHT = 0.01
    TEMPERATURE = 1.0
    NOISE_ALPHA = 0.3
    EVAL_GAMES = 40
    EVAL_INTERVAL = 3

    # Setup
    try:
        mp.set_start_method('spawn', force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        pass # Already set

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Main process training on {device}")

    model = MakNeebNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=INITIAL_LR, weight_decay=1e-5)
    replay = deque(maxlen=REPLAY_SIZE)

    start_iter, replay = load_checkpoint(model, optimizer, CHECKPOINT_PATH, device)

    mcts_params = {
        'mcts_sims': MCTS_SIMS,
        'c_puct': C_PUCT,
        'temperature': TEMPERATURE,
        'noise_alpha': NOISE_ALPHA
    }

    num_iters = TOTAL_GAMES // GAMES_PER_ITER
    for it in range(start_iter, num_iters):
        print(f"\n{'='*25} Iteration {it}/{num_iters} {'='*25}")

        # --- Self-play phase ---
        self_play_start = time.perf_counter()
        training_examples, game_stats = parallel_self_play(
            model, GAMES_PER_ITER, NUM_WORKERS, mcts_params
        )
        replay.extend(training_examples)
        self_play_time = time.perf_counter() - self_play_start

        # --- Logging ---
        player1_wins = sum(1 for s in game_stats if s['winner'] == 1)
        winrate = player1_wins / len(game_stats) if game_stats else 0
        avg_len = np.mean([s['game_length'] for s in game_stats]) if game_stats else 0
        print(f"\n📊 Self-Play Stats ({self_play_time:.1f}s):")
        print(f"  P1 Win Rate: {winrate:.2%} | Avg Len: {avg_len:.1f} | New Examples: {len(training_examples)} | Buffer: {len(replay)}/{REPLAY_SIZE}")

        # --- Training phase ---
        if len(replay) >= BATCH_SIZE_TRAIN:
            print("\n🧠 Training model...")
            training_start = time.perf_counter()
            enhanced_train_model(model, replay, optimizer, device, BATCH_SIZE_TRAIN, EPOCHS, ENTROPY_WEIGHT)
            print(f"  Training Time: {time.perf_counter() - training_start:.1f}s")
        else:
            print(f"\n⚠️ Replay buffer too small for training ({len(replay)}/{BATCH_SIZE_TRAIN}). Skipping.")

        # --- Evaluation phase ---
        if (it + 1) % EVAL_INTERVAL == 0 or it == num_iters - 1:
            enhanced_evaluate_model(model, device, EVAL_GAMES)

        save_checkpoint(it, model, optimizer, replay, CHECKPOINT_PATH)
        gc.collect()


In [None]:
if __name__ == "__main__":
    main()


Multiprocessing start method set to 'spawn'.
Main process training on cuda
Error loading checkpoint: [Errno 21] Is a directory: '/content/drive/MyDrive/AlphaZero_Backups'. Starting from scratch.

🚀 Launching 6 parallel workers for 48 games...
