## DQNでエージェントを構築
- Othelloの環境と報酬設計は `utils` ディレクトリ（`utils/mori/`）内に保存している

In [1]:
import os
import sys
import math
import copy
import random
import datetime
import inspect
from collections import deque
from dataclasses import dataclass
from typing import List, Optional, Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.notebook import tqdm
import torchinfo

sys.path.append("../")

# ローカル環境/報酬
from utils.mori.othello_env import OthelloEnv
from utils.mori.othello_game import OthelloGame
from utils.mori.othello_reward import ShapedReward

## DQNのネットワーク定義

In [2]:
class ResBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN, identity skip, ReLU after the sum.

    Args:
        ch: Number of input/output channels (the block preserves the shape).
        bn_eps: Epsilon of both BatchNorm layers.
        zero_init: If True, the second BatchNorm's scale is zero-initialised so the
            residual branch initially outputs zero and the block reduces to ReLU(x).
    """

    def __init__(self, ch: int, bn_eps: float = 1e-5, zero_init: bool = True):
        super().__init__()
        # Modules are created in the same order as before so random init is unchanged.
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch, eps=bn_eps)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch, eps=bn_eps)
        if zero_init:
            # Start close to an identity mapping for more stable early training.
            nn.init.zeros_(self.bn2.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block; output has the same shape as ``x``."""
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return F.relu(x + residual)

class DQN(nn.Module):
    """Q-network for Othello: (B, in_ch, 8, 8) input -> (B, 65) action values.

    Actions 0..63 index board squares (row * 8 + col); action 64 is "pass".

    Args:
        in_ch: Input channels (board plane + side-to-move plane by default).
        width: Channel width of the stem and residual tower.
        num_res_blocks: Number of ``ResBlock`` layers.
        bn_eps: Epsilon for the stem GroupNorm and the residual BatchNorms.
    """

    def __init__(
        self,
        in_ch: int = 2,
        width: int = 32,
        num_res_blocks: int = 3,
        bn_eps: float = 1e-5,
    ):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, width, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, width, eps=bn_eps),
            nn.ReLU(inplace=True),
        )
        tower = (ResBlock(width, bn_eps=bn_eps, zero_init=False) for _ in range(num_res_blocks))
        self.res_blocks = nn.Sequential(*tower)
        # 1x1 conv squeezes the tower to 2 planes: (B, 2, 8, 8) -> flatten -> 65 logits.
        self.policy_head = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=2, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.logits_fc = nn.Linear(2 * 8 * 8, 65)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape (B, 65)."""
        features = self.policy_head(self.res_blocks(self.stem(x)))
        flat = torch.flatten(features, start_dim=1)  # (B, 128)
        return self.logits_fc(flat)

In [3]:
# Architecture smoke test: one (board, side-to-move) sample should yield 65 Q-values.
dqn = DQN()
dummy_board, dummy_player = torch.zeros((1, 1, 8, 8)), torch.ones((1, 1, 8, 8))  # side to move = +1
dummy_input = torch.cat((dummy_board, dummy_player), dim=1)  # (1, 2, 8, 8)
q_values_shape = dqn(dummy_input).shape
print(q_values_shape)
torchinfo.summary(dqn, input_size=(1, 2, 8, 8))

torch.Size([1, 65])


Layer (type:depth-idx)                   Output Shape              Param #
DQN                                      [1, 65]                   --
├─Sequential: 1-1                        [1, 32, 8, 8]             --
│    └─Conv2d: 2-1                       [1, 32, 8, 8]             576
│    └─GroupNorm: 2-2                    [1, 32, 8, 8]             64
│    └─ReLU: 2-3                         [1, 32, 8, 8]             --
├─Sequential: 1-2                        [1, 32, 8, 8]             --
│    └─ResBlock: 2-4                     [1, 32, 8, 8]             --
│    │    └─Conv2d: 3-1                  [1, 32, 8, 8]             9,216
│    │    └─BatchNorm2d: 3-2             [1, 32, 8, 8]             64
│    │    └─Conv2d: 3-3                  [1, 32, 8, 8]             9,216
│    │    └─BatchNorm2d: 3-4             [1, 32, 8, 8]             64
│    └─ResBlock: 2-5                     [1, 32, 8, 8]             --
│    │    └─Conv2d: 3-5                  [1, 32, 8, 8]             9,216
│    

## 対戦相手用の定義

In [4]:
class CriticNet(nn.Module):
    """Value network: maps a board encoding (B, in_ch, 8, 8) to a score in [-1, 1].

    Pipeline: conv stem -> residual tower -> 1x1 value conv -> optional norm -> ReLU
    -> global average pool (``use_gap``) or flatten -> Linear -> ReLU -> Linear -> tanh.

    ``norm_head`` selects the value-head normalisation:
        "bn"  -> ``GroupNorm(1, 1, eps=bn_eps)`` (note: despite the name, not BatchNorm)
        "ln"  -> ``LayerNorm((1, 8, 8))``
        "gn"  -> ``GroupNorm(1, 1)`` with the default eps
        other -> ``Identity``
    """

    def __init__(
        self,
        in_ch: int = 2,
        width: int = 32,
        num_res_blocks: int = 3,
        bn_eps: float = 1e-5,
        head_hidden_size: int = 32,
        use_gap: bool = True,
        norm_head: str = "ln",
    ):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(in_ch, width, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, width, eps=bn_eps),
            nn.ReLU(inplace=True),
        )
        tower = (ResBlock(width, bn_eps=bn_eps, zero_init=True) for _ in range(num_res_blocks))
        self.res_blocks = nn.Sequential(*tower)

        # Value head: 1x1 conv collapses the tower to a single (1, 8, 8) plane.
        self.value_conv = nn.Conv2d(width, 1, kernel_size=1, bias=False)
        if norm_head == "ln":
            self.value_norm = nn.LayerNorm((1, 8, 8))
        elif norm_head == "bn":
            self.value_norm = nn.GroupNorm(1, 1, eps=bn_eps)
        elif norm_head == "gn":
            self.value_norm = nn.GroupNorm(1, 1)
        else:
            self.value_norm = nn.Identity()

        self.use_gap = use_gap
        # GAP reduces the plane to one scalar; otherwise all 64 squares feed the MLP.
        fc_in = 1 if use_gap else 8 * 8
        self.value_fc1 = nn.Linear(fc_in, head_hidden_size)
        self.value_fc2 = nn.Linear(head_hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the value estimate of shape (B, 1), squashed to [-1, 1] by tanh."""
        trunk = self.res_blocks(self.stem(x))
        plane = F.relu(self.value_norm(self.value_conv(trunk)))  # (B, 1, 8, 8)
        pooled = plane.mean(dim=(2, 3)) if self.use_gap else plane.view(plane.size(0), -1)
        hidden = F.relu(self.value_fc1(pooled))
        return torch.tanh(self.value_fc2(hidden))

## DQNの学習

In [5]:
class ReplayBuffer:
    """Fixed-capacity FIFO experience store with uniform random sampling.

    Once ``memory_size`` items are held, appending evicts the oldest one.
    """

    def __init__(self, memory_size):
        self.memory_size = memory_size
        self.memory = deque(maxlen=memory_size)

    def append(self, transition):
        """Add one transition (evicting the oldest when full)."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn uniformly without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return len(self.memory)

In [6]:
class TrainDoubleDQN:
    """Train an Othello Q-network with Double DQN against a rotating set of opponents.

    Episode ``ep`` of :meth:`train` plays ``max_games_per_episode`` games against
    ``opponent_cycle[ep % len(opponent_cycle)]``:

    * ``"self"``   - the agent chooses every move for both colours (epsilon-greedy).
    * ``"random"`` - the agent (random colour) plays a uniformly random opponent.
    * ``"critic"`` - the agent plays an opponent driven by a pretrained ``CriticNet``.

    Replay transitions are dicts with keys ``board``, ``action``, ``reward``,
    ``next_board``, ``done``, ``player``, ``next_player``, ``legal_actions`` and
    ``next_legal_actions``. ``reward`` is the shaped reward from ``player``'s point
    of view.

    When ``next_player != player`` (the other colour is to move in ``next_board``),
    the bootstrap value is negated (negamax). This assumes the shaped reward is
    zero-sum between the colours. TODO confirm against ``ShapedReward``.

    Against fixed opponents ("random"/"critic"), the opponent's reply is folded into
    the agent's transition. So ``next_board`` is the agent's next turn, or the
    terminal position.
    """

    # Terminal transitions are stored this many times to over-sample game outcomes.
    TERMINAL_REPEAT = 3

    def __init__(
        self,
        dqn: nn.Module,
        gamma: float = 0.99,
        lr: float = 1e-3,
        batch_size: int = 64,
        init_memory_size: int = 5000,
        memory_size: int = 50000,
        target_update_freq: int = 1000,
        tau: float = 0.005,
        num_episodes: int = 1000,
        max_games_per_episode: int = 5,
        train_freq: int = 1,
        gradient_steps: int = 1,
        learning_starts: int = 1000,
        epsilon_start: float = 1.0,
        epsilon_end: float = 0.05,
        epsilon_decay_steps: int = 30000,
        seed: int = 42,
        device: Optional[torch.device] = None,
        ReplayBufferCls=ReplayBuffer,
        rolling_window: int = 100,
        save_best_path: Optional[str] = None,
        pretrained_opp_path: Optional[str] = None,
        opponent_cycle: Optional[Tuple[str, ...]] = None,
        progress_callback = None,
    ):
        """Set up networks, optimiser and opponents, then pre-fill the replay buffer.

        Args:
            dqn: Online Q-network mapping (B, 2, 8, 8) inputs to (B, 65) Q-values.
            gamma: Discount factor.
            lr: Adam learning rate.
            batch_size: Mini-batch size per gradient step.
            init_memory_size: Random-play transitions collected before training
                (capped at ``memory_size``).
            memory_size: Replay buffer capacity.
            target_update_freq: Hard target-sync period in gradient updates. Only
                used when ``tau`` is 0/None.
            tau: Polyak coefficient of the soft target update applied after each
                agent step.
            num_episodes: Episodes run by :meth:`train`.
            max_games_per_episode: Games played per episode.
            train_freq: Train every ``train_freq`` agent steps.
            gradient_steps: Gradient updates per training trigger.
            learning_starts: Minimum buffer size before training starts.
            epsilon_start, epsilon_end, epsilon_decay_steps: Linear epsilon schedule.
            seed: Seed for ``random``, ``numpy`` and ``torch``.
            device: Torch device. Defaults to CUDA, then MPS, then CPU.
            ReplayBufferCls: Buffer class, constructed as ``ReplayBufferCls(memory_size)``.
            rolling_window: Window of the rolling mean reward used to select the best model.
            save_best_path: If given, the best state_dict is saved here on every improvement.
            pretrained_opp_path: State-dict file for a ``CriticNet`` opponent.
            opponent_cycle: Opponent modes cycled per episode. Defaults to
                ("self", "random") plus "critic" when a critic is loaded.
            progress_callback: Called after every episode as
                ``cb(ep, ep_reward, rolling_avg, last_loss)``. It is called with an
                extra ``info`` dict when it declares more than four parameters.

        Raises:
            ValueError: ``opponent_cycle`` is empty, contains an unknown mode, or
                requests "critic" without ``pretrained_opp_path``.
        """
        assert ReplayBufferCls is not None, "ReplayBufferCls を渡してください"
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )

        # Reproducibility
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # Online and target networks (the target always runs in eval mode).
        self.dqn = dqn.to(self.device)
        self.target_dqn = copy.deepcopy(dqn).to(self.device)
        self.target_dqn.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=lr)
        self.loss_fn = nn.SmoothL1Loss()

        # Optional pretrained opponent. map_location allows loading a checkpoint
        # saved on another device (e.g. CUDA) onto CPU/MPS.
        self.opp_model = None
        if pretrained_opp_path is not None:
            self.opp_model = CriticNet().to(self.device)
            opp_state = torch.load(pretrained_opp_path, map_location=self.device)
            self.opp_model.load_state_dict(opp_state)
            self.opp_model.eval()

        # Hyper-parameters
        self.gamma = gamma
        self.batch_size = int(batch_size)
        self.init_memory_size = int(init_memory_size)
        self.max_games_per_episode = int(max_games_per_episode)
        self.num_episodes = int(num_episodes)
        self.train_freq = max(1, int(train_freq))
        self.gradient_steps = max(1, int(gradient_steps))
        self.learning_starts = int(learning_starts)

        self.tau = float(tau)
        self.target_update_freq = int(target_update_freq)
        self._num_updates = 0     # gradient updates performed
        self._num_env_steps = 0   # moves chosen by the learning agent

        self.epsilon_start = float(epsilon_start)
        self.epsilon_end = float(epsilon_end)
        self.epsilon_decay_steps = max(1, int(epsilon_decay_steps))

        self.replay_buffer = ReplayBufferCls(int(memory_size))

        self.rewards: List[float] = []
        self.losses: List[float] = []
        self.rolling_window = int(rolling_window)
        self.best_score = -float("inf")
        self.best_state_dict = copy.deepcopy(self.dqn.state_dict())
        self.save_best_path = save_best_path
        self.progress_callback = progress_callback

        # Opponent schedule
        default_cycle = ["self", "random"]
        if self.opp_model is not None:
            default_cycle.append("critic")

        if opponent_cycle is None:
            opponent_cycle = tuple(default_cycle)
        else:
            validated = []
            for mode in opponent_cycle:
                if mode == "critic" and self.opp_model is None:
                    raise ValueError("critic を指定していますが pretrained_opp_path がありません。")
                if mode not in ("self", "random", "critic"):
                    raise ValueError(f"未知の opponent '{mode}'")
                validated.append(mode)
            if not validated:
                raise ValueError("opponent_cycle が空です。")
            opponent_cycle = tuple(validated)

        self.opponent_cycle = opponent_cycle
        self.opponent_reward_logs: Dict[str, List[float]] = {m: [] for m in self.opponent_cycle}
        self.latest_opponent_reward: Dict[str, float] = {m: float("nan") for m in self.opponent_cycle}

        # Pre-fill the replay buffer with random-play transitions.
        self._init_replay_buffer()

    # --------- Public API ---------
    def train(self, return_best_model: bool = False):
        """Run ``num_episodes`` episodes, tracking the best rolling-average model.

        Args:
            return_best_model: Also return an eval-mode copy of the best model.

        Returns:
            ``{"rewards": [...], "losses": [...]}``, or ``(metrics, best_model)``
            when ``return_best_model`` is True.

        Raises:
            RuntimeError: An unknown opponent mode is encountered.
        """
        runners = {
            "self": self._run_episodes_with_self,
            "random": self._run_episode_with_random,
            "critic": self._run_episode_with_pretrained_critic,
        }
        # Inspect the callback once: callbacks with >4 parameters also receive `info`.
        callback_wants_info = (
            self.progress_callback is not None
            and len(inspect.signature(self.progress_callback).parameters) > 4
        )

        pbar = tqdm(total=self.num_episodes, desc="Train Double DQN")
        try:
            for ep in range(self.num_episodes):
                opponent = self.opponent_cycle[ep % len(self.opponent_cycle)]
                runner = runners.get(opponent)
                if runner is None:
                    raise RuntimeError(f"未知の opponent '{opponent}'")
                ep_reward = runner(ep)

                self.rewards.append(ep_reward)
                self.opponent_reward_logs[opponent].append(ep_reward)
                self.latest_opponent_reward[opponent] = ep_reward

                # The slice covers all rewards while fewer than rolling_window exist.
                rolling_avg = float(np.mean(self.rewards[-self.rolling_window:]))

                if rolling_avg > self.best_score:
                    self.best_score = rolling_avg
                    self.best_state_dict = copy.deepcopy(self.dqn.state_dict())
                    if self.save_best_path is not None:
                        torch.save(self.best_state_dict, self.save_best_path)

                last_loss = self.losses[-1] if self.losses else float("nan")
                # All reward figures are shown with the same x64 scaling
                # (presumably converting a normalised reward to a disc count).
                pbar.set_postfix_str(
                    f"EpR {ep_reward*64:.2f}, Roll@{self.rolling_window}: {rolling_avg*64:.2f}, Best: {self.best_score*64:.2f}, Loss: {last_loss:.3f}, ε: {self._epsilon_by_step(self._num_env_steps):.3f}"
                )
                pbar.update(1)

                if self.progress_callback is not None:
                    if callback_wants_info:
                        info = {
                            "opponent": opponent,
                            "latest_rewards": {
                                mode: (logs[-1] if logs else float("nan"))
                                for mode, logs in self.opponent_reward_logs.items()
                            },
                            "history": {mode: list(logs) for mode, logs in self.opponent_reward_logs.items()},
                        }
                        self.progress_callback(ep, ep_reward, rolling_avg, last_loss, info)
                    else:
                        self.progress_callback(ep, ep_reward, rolling_avg, last_loss)
        finally:
            pbar.close()

        metrics = {"rewards": self.rewards, "losses": self.losses}
        if return_best_model:
            best_model = self._clone_model_with_state(self.best_state_dict)
            return metrics, best_model
        return metrics

    def get_best_model(self) -> nn.Module:
        """Return an eval-mode copy of the network loaded with the best weights."""
        return self._clone_model_with_state(self.best_state_dict)

    def _clone_model_with_state(self, state_dict) -> nn.Module:
        """Deep-copy the online network, load ``state_dict`` into it and set eval mode."""
        model = copy.deepcopy(self.dqn).to(self.device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    # --------- Episode runners ---------
    def _run_episodes_with_self(self, episode_idx: int) -> float:
        """Self-play: the agent picks every move (epsilon-greedy) for both colours.

        Plays ``max_games_per_episode`` games, each on a fresh environment. Each move
        is stored from the mover's perspective (``ShapedReward(player)``).

        Returns:
            Sum of the shaped rewards of all moves, each from its mover's perspective.
        """
        total_reward = 0.0
        for _ in range(self.max_games_per_episode):
            env = OthelloEnv()
            # Randomise the starting colour (a pass flips the side to move).
            if random.random() < 0.5:
                env.step(64)

            done = False
            while not done:
                epsilon = self._epsilon_by_step(self._num_env_steps)
                player = env.player
                state = env.get_state()
                action = self._select_action_by_epsilon_greedy(env, epsilon)

                # Shaped reward from the mover's perspective
                reward_fn = ShapedReward(player)
                next_state, reward, done, _ = env.step(action, reward_fn=reward_fn)
                next_player = env.player

                self._store_transition_with_oversampling(
                    board=state[0],
                    action=action,
                    reward=reward,
                    next_board=next_state[0],
                    done=done,
                    player=player,
                    next_player=next_player,
                    next_legal_actions=self._legal_actions_from_board(next_state[0], next_player),
                )
                self._num_env_steps += 1
                self._train_step()

                total_reward += float(reward)

        return float(total_reward)

    def _run_episode_with_random(self, episode_idx: int) -> float:
        """Play ``max_games_per_episode`` games against a uniformly random opponent."""
        return self._play_games_vs_fixed_opponent(
            lambda env: random.choice(env.legal_actions())
        )

    def _run_episode_with_pretrained_critic(
        self,
        episode_idx: int,
    ) -> float:
        """Play ``max_games_per_episode`` games against the pretrained critic opponent.

        Raises:
            ValueError: No critic was loaded.
        """
        if not hasattr(self, "opp_model") or self.opp_model is None:
            raise ValueError("pretrained critic がロードされていません。")

        self.opp_model.eval()
        critic = self.opp_model
        return self._play_games_vs_fixed_opponent(
            lambda env: self._select_action_by_critic(env, critic)
        )

    def _play_games_vs_fixed_opponent(self, opponent_policy) -> float:
        """Play ``max_games_per_episode`` games against a fixed (non-learning) opponent.

        The agent's colour is randomised per game. Only the agent's decisions are
        stored. Each stored transition spans the agent's move plus the opponent's
        reply, so:

        * ``next_board`` is the next position with the agent to move, or the
          terminal position.
        * ``reward`` is the sum of both steps' shaped rewards from the agent's
          perspective.

        This keeps outcomes decided by the opponent's last move in the buffer.
        Training is triggered once per agent move.

        Args:
            opponent_policy: Callable ``env -> action`` used on the opponent's turns.

        Returns:
            Sum of the shaped rewards (agent's perspective) of every step in every game.
        """
        total_reward = 0.0
        for _ in range(self.max_games_per_episode):
            env = OthelloEnv()
            # Randomise the agent's colour (a pass flips the side to move).
            if random.random() < 0.5:
                env.step(64)
            my_color = env.player

            # [board, action, accumulated reward] of the agent's move not yet stored.
            pending = None
            next_state = None
            done = False
            while not done:
                player = env.player
                state = env.get_state()
                my_turn = player == my_color

                if my_turn:
                    if pending is not None:
                        # The agent is to move again: close its previous transition here.
                        self._store_transition_with_oversampling(
                            board=pending[0],
                            action=pending[1],
                            reward=pending[2],
                            next_board=state[0],
                            done=False,
                            player=my_color,
                            next_player=player,
                            next_legal_actions=self._legal_actions_from_board(state[0], player),
                        )
                    epsilon = self._epsilon_by_step(self._num_env_steps)
                    action = self._select_action_by_epsilon_greedy(env, epsilon)
                else:
                    action = opponent_policy(env)

                # Rewards are always shaped from the agent's perspective.
                reward_fn = ShapedReward(my_color)
                next_state, reward, done, _ = env.step(action, reward_fn=reward_fn)
                reward = float(reward)
                total_reward += reward

                if my_turn:
                    pending = [state[0], action, reward]
                    self._num_env_steps += 1
                    self._train_step()
                elif pending is not None:
                    # Opponent's reply belongs to the agent's pending transition.
                    pending[2] += reward

            if pending is not None:
                final_player = env.player
                self._store_transition_with_oversampling(
                    board=pending[0],
                    action=pending[1],
                    reward=pending[2],
                    next_board=next_state[0],
                    done=True,
                    player=my_color,
                    next_player=final_player,
                    next_legal_actions=self._legal_actions_from_board(next_state[0], final_player),
                )

        return float(total_reward)

    def _store_transition_with_oversampling(self, **transition) -> None:
        """Store a transition once, or ``TERMINAL_REPEAT`` times when it is terminal."""
        repeats = self.TERMINAL_REPEAT if transition["done"] else 1
        for _ in range(repeats):
            self._store_transition(**transition)

    def _train_step(self) -> None:
        """Run ``gradient_steps`` updates when due, then update the target network."""
        warmed_up = len(self.replay_buffer) >= max(self.batch_size, self.learning_starts)
        if warmed_up and self._num_env_steps % self.train_freq == 0:
            for _ in range(self.gradient_steps):
                loss = self._update_dqn_double()
                if not math.isnan(loss):
                    self.losses.append(loss)
        self._maybe_update_target()

    def _select_action_by_critic(
        self,
        env: OthelloEnv,
        critic: nn.Module,
    ) -> int:
        """Choose the opponent's move by maximising the critic's score of the result.

        Each legal move is applied to a cloned game and the resulting state is scored.
        Moves whose simulation raises are scored -inf (best-effort).

        NOTE(review): the critic scores the position *after* the move, where the DQN
        agent is presumably to move. If ``CriticNet`` scores from the side-to-move's
        perspective, maximising favours the agent rather than the opponent. Confirm
        the critic's convention; argmin or negation may be intended.
        """
        legal_actions = env.legal_actions()
        values = []

        for action in legal_actions:
            sim_env = OthelloEnv()
            sim_env.game = env.game.clone()
            sim_env.player = env.player

            try:
                sim_env.step(action)
            except Exception:
                values.append(-float("inf"))
                continue

            sim_state = sim_env.get_state()
            input_tensor = torch.as_tensor(sim_state, dtype=torch.float32, device=self.device).unsqueeze(0)
            with torch.no_grad():
                values.append(critic(input_tensor).item())

        best_idx = int(np.array(values, dtype=np.float32).argmax())
        return int(legal_actions[best_idx])

    # --------- Learning ---------
    def _update_dqn_double(self) -> float:
        """Run one Double DQN gradient step on a sampled mini-batch.

        The online network picks a* = argmax over the legal actions of ``next_board``,
        and the target network evaluates it:

            target = r + gamma * s * Q_target(next, a*) * (1 - done)

        Here s = +1 if the same player moves in ``next_board`` and -1 otherwise
        (negamax). This assumes zero-sum shaped rewards; TODO confirm.

        Returns:
            The SmoothL1 loss, or NaN (without an optimiser step) if the loss is NaN.
        """
        batch = self.replay_buffer.sample(self.batch_size)

        # (B, 2, 8, 8)
        board = torch.stack([self._to_input(b["board"], b["player"]) for b in batch]).to(self.device)
        next_board = torch.stack([self._to_input(b["next_board"], b["next_player"]) for b in batch]).to(self.device)

        action = torch.tensor([b["action"] for b in batch], dtype=torch.int64, device=self.device)
        reward = torch.tensor([b["reward"] for b in batch], dtype=torch.float32, device=self.device)
        done = torch.tensor([b["done"] for b in batch], dtype=torch.float32, device=self.device)
        same_player = torch.tensor(
            [b["next_player"] == b["player"] for b in batch], dtype=torch.float32, device=self.device
        )
        sign = 2.0 * same_player - 1.0  # +1 same mover, -1 opponent to move
        next_legal_actions_list = [b["next_legal_actions"] for b in batch]

        # Q(s, a)
        q_all = self.dqn(board)                                 # (B, 65)
        q_sa = q_all.gather(1, action.unsqueeze(1)).squeeze(1)  # (B,)

        with torch.no_grad():
            # Mask illegal actions before the online argmax.
            next_masks = self._build_masks_from_indices(next_legal_actions_list, fill_value=-1e9)
            q_next_online = self.dqn(next_board) + next_masks
            next_actions_online = q_next_online.argmax(dim=1)  # (B,)

            q_next_target = self.target_dqn(next_board)
            next_q = q_next_target.gather(1, next_actions_online.unsqueeze(1)).squeeze(1)

            target = reward + self.gamma * sign * next_q * (1.0 - done)

        loss = self.loss_fn(q_sa, target)
        if torch.isnan(loss):
            return float("nan")

        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.dqn.parameters(), max_norm=1.0)
        self.optimizer.step()

        self._num_updates += 1
        return float(loss.item())

    # --------- Action selection ---------
    def _select_action_by_epsilon_greedy(self, env: OthelloEnv, epsilon: float) -> int:
        """With probability ``epsilon`` play a random legal move, otherwise the greedy one."""
        legal_actions = env.legal_actions()
        if random.random() < epsilon:
            return random.choice(legal_actions)
        return self._select_action_by_greedy(env)

    def _select_action_by_greedy(self, env: OthelloEnv) -> int:
        """Return the legal action with the highest online Q-value.

        The network is temporarily switched to eval mode. Otherwise BatchNorm would
        normalise with the statistics of this single sample and update its running
        statistics with it. The previous mode is restored afterwards.
        """
        legal_actions = list(env.legal_actions())
        board_tensor = self._to_input(env.get_state(), env.player).unsqueeze(0).to(self.device)  # (1,2,8,8)

        was_training = self.dqn.training
        self.dqn.eval()
        try:
            with torch.no_grad():
                q_all = self.dqn(board_tensor).squeeze(0)  # (65,)
        finally:
            self.dqn.train(was_training)

        mask = torch.full((65,), -1e9, device=self.device)
        legal_idx = torch.as_tensor(legal_actions, dtype=torch.long, device=self.device)
        mask[legal_idx] = 0.0
        return int((q_all + mask).argmax().item())

    # --------- Replay buffer initialisation ---------
    def _init_replay_buffer(self):
        """Pre-fill the replay buffer with transitions from random-vs-random games.

        Full games are played until both quotas are met. One third of
        ``min(init_memory_size, memory_size)`` is terminal transitions (each game's
        last move); the rest is non-terminal transitions (earlier moves).
        """
        target = min(self.init_memory_size, self.replay_buffer.memory_size)
        terminal_quota = target // 3
        non_terminal_quota = target - terminal_quota

        num_terminal = 0
        num_non_terminal = 0

        pbar = tqdm(total=target, desc='Init replay buffer')

        # Keep playing until both the terminal and non-terminal quotas are filled.
        while num_terminal < terminal_quota or num_non_terminal < non_terminal_quota:
            env = OthelloEnv()
            done = False
            episode = []

            # Play one game to the end with uniformly random moves.
            while not done:
                player = env.player
                state = env.get_state()
                action = random.choice(env.legal_actions())
                reward_fn = ShapedReward(player)
                next_state, reward, done, _ = env.step(action, reward_fn=reward_fn)
                next_player = env.player
                next_legal_actions = self._legal_actions_from_board(next_state[0], next_player)

                episode.append({
                    'board': state[0], 'action': action, 'reward': reward,
                    'next_board': next_state[0], 'done': done, 'player': player,
                    'next_player': next_player, 'next_legal_actions': next_legal_actions,
                })

            # Terminal transition (the game's last move)
            if num_terminal < terminal_quota and episode:
                t = episode[-1]
                self._store_transition(
                    board=t['board'], action=t['action'], reward=t['reward'],
                    next_board=t['next_board'], done=t['done'],
                    player=t['player'], next_player=t['next_player'],
                    next_legal_actions=t['next_legal_actions'],
                )
                num_terminal += 1
                pbar.update(1)
                pbar.set_postfix({"terminal": num_terminal, "non_terminal": num_non_terminal})

            # Non-terminal transitions (every move except the last)
            for t in episode[:-1]:
                if num_non_terminal >= non_terminal_quota:
                    break
                self._store_transition(
                    board=t['board'], action=t['action'], reward=t['reward'],
                    next_board=t['next_board'], done=False,
                    player=t['player'], next_player=t['next_player'],
                    next_legal_actions=t['next_legal_actions'],
                )
                num_non_terminal += 1

                pbar.set_postfix({"terminal": num_terminal, "non_terminal": num_non_terminal})
                pbar.update(1)

        pbar.close()

    def _maybe_update_target(self):
        """Update the target network.

        If ``tau > 0``, apply a Polyak (soft) update to the parameters and copy the
        buffers (e.g. BatchNorm running statistics) verbatim. The target runs in eval
        mode, so it relies on those buffers. Otherwise, hard-copy the state_dict every
        ``target_update_freq`` gradient updates.
        """
        if self.tau and self.tau > 0.0:
            with torch.no_grad():
                for tp, p in zip(self.target_dqn.parameters(), self.dqn.parameters()):
                    tp.mul_(1.0 - self.tau).add_(p, alpha=self.tau)
                for tb, b in zip(self.target_dqn.buffers(), self.dqn.buffers()):
                    tb.copy_(b)
        elif self._num_updates > 0 and self._num_updates % max(1, self.target_update_freq) == 0:
            self.target_dqn.load_state_dict(self.dqn.state_dict())

    # --------- Utilities ---------
    def _epsilon_by_step(self, step: int) -> float:
        """Linearly anneal epsilon from ``epsilon_start`` to ``epsilon_end`` over ``epsilon_decay_steps``."""
        if step >= self.epsilon_decay_steps:
            return self.epsilon_end
        span = self.epsilon_start - self.epsilon_end
        return self.epsilon_start - span * (step / self.epsilon_decay_steps)

    def _build_masks_from_indices(self, batch_next_legal_actions: List[List[int]], fill_value: float = -1e9):
        """Return a (B, 65) additive mask: 0 at legal actions, ``fill_value`` elsewhere."""
        B = len(batch_next_legal_actions)
        masks = torch.full((B, 65), float(fill_value), device=self.device)
        for i, acts in enumerate(batch_next_legal_actions):
            for a in acts:
                masks[i, a] = 0.0
        return masks

    def _to_input(self, board_like, player_scalar: int) -> torch.Tensor:
        """Build a (2, 8, 8) network input: board plane plus a constant side-to-move plane.

        Accepts (2, 8, 8) (only channel 0 is used), (8, 8), (1, 8, 8), or anything
        reshapeable to (1, 8, 8).
        """
        t = torch.as_tensor(board_like, dtype=torch.float32)
        if t.dim() == 3 and t.shape == (2, 8, 8):
            t = t[0:1]
        elif t.dim() == 2 and t.shape == (8, 8):
            t = t.unsqueeze(0)
        elif t.dim() == 3 and t.shape == (1, 8, 8):
            pass
        else:
            t = t.reshape(1, 8, 8)
        player_plane = torch.full_like(t, float(player_scalar))
        return torch.cat([t, player_plane], dim=0)

    def _legal_actions_from_board(self, board_np: np.ndarray, player: int) -> List[int]:
        """Return legal action indices (r * 8 + c) for ``player``, or ``[64]`` (pass) if none."""
        # Build a fresh OthelloGame and load the board into it.
        g = OthelloGame()
        if board_np.ndim == 3 and board_np.shape == (2, 8, 8):
            board_np = board_np[0]
        # Board values are -1/0/1 floats; convert to int8 for OthelloGame.
        g.board = np.asarray(board_np, dtype=np.int8)
        g.player = int(player)
        moves = g.legal_moves(g.player)
        return [64] if not moves else [r * 8 + c for r, c in moves]

    def _store_transition(
        self,
        board,
        action: int,
        reward: float,
        next_board,
        done: bool,
        player: int,
        next_player: int,
        next_legal_actions: List[int],
    ):
        """Normalise one transition's types and append it to the replay buffer."""
        board_arr = np.array(board, dtype=np.float32)
        transition = {
            "board": board_arr,
            "action": int(action),
            "reward": float(reward),
            "next_board": np.array(next_board, dtype=np.float32),
            "done": bool(done),
            "player": int(player),
            "next_player": int(next_player),
            "legal_actions": self._legal_actions_from_board(board_arr.copy(), int(player)),
            "next_legal_actions": next_legal_actions,
        }
        self.replay_buffer.append(transition)


In [7]:
# Live training dashboard: reward, loss and per-opponent reward plots.
reward_out = widgets.Output()
loss_out = widgets.Output()
opponent_out = widgets.Output()
display(widgets.VBox([reward_out, loss_out, opponent_out]))

rewards_history = []
losses_history = []
opponent_histories = {}

# Window length of the moving average drawn on the reward plot.
MA_WINDOW = 20


def _finish_plot(fig, ax):
    """Add grid/legend, show the figure and close it so figures do not accumulate.

    The legend is only drawn when something is labelled; otherwise matplotlib warns
    "No artists with labels found to put in legend" (e.g. before the first loss).
    """
    ax.grid(True)
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    plt.show()
    plt.close(fig)


def live_callback(ep, reward, rolling_avg, latest_loss, info=None):
    """Progress callback for ``TrainDoubleDQN`` that redraws the dashboard.

    Args:
        ep: Zero-based episode index.
        reward: Total reward of the episode.
        rolling_avg: Rolling mean reward (unused here; part of the callback signature).
        latest_loss: Latest training loss (NaN before training starts).
        info: Optional dict with "opponent" and per-mode reward "history".
    """
    rewards_history.append(reward)
    if not math.isnan(latest_loss):
        losses_history.append(latest_loss)

    # Overall episode rewards
    with reward_out:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(6, 3))
        ax.plot(rewards_history, label="Episode Reward")
        ax.axhline(0.0, color="red", linestyle="--", linewidth=1, alpha=0.8)  # zero baseline
        if len(rewards_history) >= MA_WINDOW:
            rolling = np.convolve(rewards_history, np.ones(MA_WINDOW) / MA_WINDOW, mode="valid")
            start = MA_WINDOW - 1
            ax.plot(range(start, start + len(rolling)), rolling, label=f"MA@{MA_WINDOW}", color="tab:orange")
        ax.set_title(f"Episode {ep + 1}")
        _finish_plot(fig, ax)

    # Training loss
    with loss_out:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(6, 3))
        if losses_history:
            ax.plot(losses_history, label="Training Loss", color="tab:red")
        _finish_plot(fig, ax)

    # Rewards per opponent mode
    if info is not None:
        opponent_histories.clear()
        for mode, hist in info.get("history", {}).items():
            opponent_histories[mode] = list(hist)

        with opponent_out:
            clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(6, 3))
            for mode, hist in opponent_histories.items():
                if hist:
                    ax.plot(hist, label=f"{mode} vs reward")
            ax.axhline(0.0, color="red", linestyle="--", linewidth=1, alpha=0.8)  # zero baseline
            ax.set_title(f"Latest opponent: {info.get('opponent', 'unknown')}")
            _finish_plot(fig, ax)

VBox(children=(Output(), Output(), Output()))

In [None]:
save_dir = "models"
# Create the output directory *before* training: the trainer writes
# models/best.pth whenever the rolling reward improves.
os.makedirs(save_dir, exist_ok=True)

progress_callback = live_callback
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Device: {device}")

dqn = DQN().to(device)
trainer = TrainDoubleDQN(
    dqn=dqn,
    device=device,
    ReplayBufferCls=ReplayBuffer,
    num_episodes=2000,
    batch_size=256,
    gamma=0.99,
    lr=5e-4,
    target_update_freq=500,  # only used when tau is 0 (tau defaults to soft updates)
    epsilon_start=1.0,
    epsilon_end=0.05,
    epsilon_decay_steps=100_000,
    max_games_per_episode=1,
    init_memory_size=5000,
    memory_size=10_000,
    learning_starts=1000,
    rolling_window=20,
    save_best_path=os.path.join(save_dir, "best.pth"),
    pretrained_opp_path=os.path.join(save_dir, "critic", "pretrained_criticnet.pt"),
    progress_callback=progress_callback,
)

metrics, best_model = trainer.train(return_best_model=True)

# Save the best model under a timestamped name.
ts = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_path = os.path.join(save_dir, f"best_{ts}.pth")
torch.save(best_model.state_dict(), save_path)

print(f"Best model saved to {save_path}")

Device: mps


Init replay buffer:   0%|          | 0/5000 [00:00<?, ?it/s]

Train Double DQN:   0%|          | 0/2000 [00:00<?, ?it/s]