In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from collections import deque
import time
import os
import math
from tqdm import tqdm

# --- Output directory for models and logs (change to suit your environment) ---
save_dir = r"学習データ"
# exist_ok=True creates the directory if needed without the check-then-create
# race of os.path.exists() followed by os.makedirs().
os.makedirs(save_dir, exist_ok=True)

# =========================
# 1. サイコロクラス（変更なし）
# =========================
class Dice:
    """A six-sided die placed on one board cell.

    ``orientation`` is the tuple (top, bottom, left, right, front, back), and
    opposite faces always sum to 7. ``pos`` is the (row, col) cell. The initial
    orientation is random. Handedness is not fixed, so both a standard die and
    its mirror image can be generated.
    """

    def __init__(self, pos):
        top = random.randint(1, 6)
        bottom = 7 - top
        # The four side faces are the two opposite pairs not used by top/bottom.
        side_faces = [n for n in range(1, 7) if n not in (top, bottom)]
        left = random.choice(side_faces)
        right = 7 - left
        # front/back must be the remaining pair. Excluding only `left` would
        # leave `right` as a candidate and could yield a die with front == right
        # (duplicate faces, two numbers missing).
        front = random.choice([n for n in side_faces if n not in (left, right)])
        back = 7 - front
        self.orientation = (top, bottom, left, right, front, back)
        self.pos = pos  # (row, col)

    @property
    def top(self):
        """Value currently on the top face."""
        return self.orientation[0]

    def roll(self, direction):
        """Tip the die toward ``direction`` and update its face orientation.

        ``direction`` is one of 'up', 'down', 'left', 'right'. Only the
        orientation changes; ``pos`` is left for the caller to update.

        Raises:
            ValueError: if ``direction`` is not one of the four directions.
        """
        top, bottom, left, right, front, back = self.orientation
        if direction == 'up':
            self.orientation = (front, back, left, right, bottom, top)
        elif direction == 'down':
            self.orientation = (back, front, left, right, top, bottom)
        elif direction == 'right':
            self.orientation = (left, right, bottom, top, front, back)
        elif direction == 'left':
            self.orientation = (right, left, top, bottom, front, back)
        else:
            raise ValueError("無効な方向です")

# ================================
# 2. ４人対戦環境クラス（状態に現在のプレイヤー番号を追加）
# ================================
class PsychoDiceWarsEnv4:
    """Four-player Psycho Dice Wars environment.

    The board is ``board_size`` x ``board_size``; each cell holds 0 (unpainted)
    or the colour 1-4 of the player who last painted it. Every player owns one
    Dice. When a player's turn starts, their move budget is the die's top face,
    and each ``step`` rolls the die one cell and paints the destination.
    Players act in the order 1 -> 2 -> 3 -> 4. ``turn_count`` increases each
    time play returns to player 1, and the game ends once it reaches
    ``max_turns``.
    """

    def __init__(self, board_size=6, max_turns=20):
        self.board_size = board_size
        self.max_turns = max_turns   # one game = 20 rounds ("sections") by default
        self.reset()

    def reset(self):
        """Start a new game and return the initial state vector.

        The four dice are placed on distinct random cells, and each starting cell
        is painted in its owner's colour. Every player's move budget is set to
        the top face of their die.
        """
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        self.turn_count = 0
        self.current_player = 1  # player 1 always moves first
        self.dice = {}           # player number (1-4) -> Dice
        self.moves_remaining = {}
        self.done = False
        # Keep drawing cells until four distinct starting positions exist.
        positions = set()
        while len(positions) < 4:
            pos = (random.randint(0, self.board_size - 1), random.randint(0, self.board_size - 1))
            positions.add(pos)
        positions = list(positions)
        for player in range(1, 5):
            self.dice[player] = Dice(positions[player - 1])
            self.board[self.dice[player].pos] = player
            self.moves_remaining[player] = self.dice[player].top
        return self.get_state()

    def get_state(self):
        """Return the observation as a float32 vector of 51 elements.

        Layout (default 6x6 board):
          - board cells, row-major (36)
          - player 1's die position (row, col) (2)
          - positions of players 2-4 (6)
          - player 1's remaining moves (1)
          - player 1's top face (1)
          - top faces of players 2-4 (3)
          - turn_count / max_turns (1)
          - current player number (1)
        A missing die is encoded as position (-1, -1) and top face 0, and
        NaN/inf entries are replaced by 0.
        """
        board_flat = self.board.flatten().astype(np.float32)
        pos_self = np.array(self.dice[1].pos, dtype=np.float32) if self.dice.get(1) else np.array([-1, -1], dtype=np.float32)
        pos_opponents = []
        for p in [2, 3, 4]:
            if self.dice.get(p):
                pos_opponents.extend(self.dice[p].pos)
            else:
                pos_opponents.extend([-1, -1])
        pos_opponents = np.array(pos_opponents, dtype=np.float32)
        moves = np.array([self.moves_remaining.get(1, 0)], dtype=np.float32)
        dice_top_self = np.array([self.dice[1].top], dtype=np.float32) if self.dice.get(1) else np.array([0], dtype=np.float32)
        dice_top_opponents = []
        for p in [2, 3, 4]:
            dice_top_opponents.append(self.dice[p].top if self.dice.get(p) else 0)
        dice_top_opponents = np.array(dice_top_opponents, dtype=np.float32)
        turn_norm = np.array([self.turn_count / self.max_turns], dtype=np.float32)
        current_player_arr = np.array([self.current_player], dtype=np.float32)
        # 36+2+6+1+1+3+1+1 = 51 elements for a 6x6 board
        state = np.concatenate([board_flat, pos_self, pos_opponents, moves, dice_top_self, dice_top_opponents, turn_norm, current_player_arr])
        # Keep the fixed 51-element layout. Smaller boards are zero-padded; larger
        # boards would need a negative pad width, which np.pad rejects.
        if state.shape[0] != 51:
            state = np.pad(state, (0, 51 - state.shape[0]), 'constant', constant_values=0.0)[:51]
        state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        return state.astype(np.float32)

    def print_board(self):
        """Print the board (0-4 mapped to the symbols below) and a status line."""
        symbol = {0: "・", 1: "〇", 2: "△", 3: "◇", 4: "☆"}
        print("-" * (self.board_size * 3))
        for r in range(self.board_size):
            row_str = ""
            for c in range(self.board_size):
                row_str += symbol.get(self.board[r, c], "?") + "  "
            print(row_str)
        print("-" * (self.board_size * 3))
        print(f"Turn: {self.turn_count}/{self.max_turns} | Current Player: {self.current_player} | Moves Left: {self.moves_remaining.get(self.current_player, 0)}")

    def valid_move(self, pos, direction):
        """Return (is_on_board, next_pos) for moving one cell from ``pos``.

        An unrecognised ``direction`` leaves the offset at (0, 0).
        """
        r, c = pos
        dr, dc = 0, 0
        if direction == 'up':
            dr = -1
        elif direction == 'down':
            dr = 1
        elif direction == 'left':
            dc = -1
        elif direction == 'right':
            dc = 1
        next_r, next_c = r + dr, c + dc
        return (0 <= next_r < self.board_size and 0 <= next_c < self.board_size), (next_r, next_c)

    def step(self, action):
        """Apply one move for the current player.

        ``action`` indexes ['up', 'down', 'left', 'right']. Returns
        ``(next_state, reward, done, info)``; every reward term is from the
        acting player's point of view:
          - +0.1 for moving onto an unpainted cell, +0.3 onto a rival's cell,
            +0.05 onto one's own cell, -0.5 for a move blocked by the edge
          - at turn end: +5 per kill, +0.5 for a 6 on top, -1 (once) when
            orthogonally adjacent to a rival die
          - at game end: the terminal reward from ``_final_reward``
        """
        actions = ['up', 'down', 'left', 'right']
        chosen_dir = actions[action]
        player = self.current_player
        dice = self.dice[player]
        reward = 0.0
        info = {'battle_result': 'none', 'turn_ended': False, 'final_score': None}
        valid, next_pos = self.valid_move(dice.pos, chosen_dir)
        if valid:
            # Roll only when the die actually tips into the next cell. Rolling a
            # die that stays put would let a player re-pick its top face (its
            # next move budget) for free by bumping into a wall.
            try:
                dice.roll(chosen_dir)
            except ValueError as e:
                print(f"Dice roll error: {e}")
                reward -= 10
                self.moves_remaining[player] = 0
                info['error'] = 'Invalid roll direction'
            target_owner = self.board[next_pos]
            dice.pos = next_pos
            self.board[dice.pos] = player
            if target_owner == 0:
                move_reward = 0.1
            elif target_owner != player:
                move_reward = 0.3
            else:
                move_reward = 0.05
            reward += move_reward
        else:
            reward -= 0.5
        self.moves_remaining[player] -= 1
        if self.moves_remaining[player] <= 0:
            info['turn_ended'] = True
            # Kill: ending the turn on a rival's cell kills that rival.
            for opp in range(1, 5):
                if opp == player:
                    continue
                if self.dice.get(opp) and dice.pos == self.dice[opp].pos:
                    try:
                        self._kill(opp, dice.pos)
                        reward += 5.0
                        info['battle_result'] = 'kill'
                    except Exception as e:
                        print(f"Kill error: {e}")
                        reward -= 10
                        info['error'] = 'Kill failed'
            if dice.top == 6:
                reward += 0.5
            # Penalise (once) ending the turn next to any rival die.
            for opp in range(1, 5):
                if opp == player or not self.dice.get(opp):
                    continue
                dx = abs(dice.pos[0] - self.dice[opp].pos[0])
                dy = abs(dice.pos[1] - self.dice[opp].pos[1])
                if dx + dy == 1:
                    reward -= 1.0
                    break
            next_player = (player % 4) + 1
            self.current_player = next_player
            if next_player == 1:
                self.turn_count += 1
            if self.dice.get(next_player):
                self.moves_remaining[next_player] = self.dice[next_player].top
            else:
                self.moves_remaining[next_player] = 0
        if (not self.done) and (self.turn_count >= self.max_turns):
            self.done = True
            reward += self._final_reward(player, info)
        next_state = self.get_state()
        if np.isnan(reward) or np.isinf(reward):
            reward = 0.0
        if np.isnan(next_state).any() or np.isinf(next_state).any():
            next_state = np.nan_to_num(next_state, nan=0.0, posinf=0.0, neginf=0.0)
        return next_state.astype(np.float32), float(reward), self.done, info

    def _final_reward(self, player, info):
        """Score the finished board and return ``player``'s terminal reward.

        Stores the per-player painted-cell counts in ``info['final_score']``.
        The reward is +15 for a strict win or -15 for a strict loss, plus 0.1
        per cell of margin over the best rival. A tie with the best rival gives 0.
        """
        scores = {p: int(np.sum(self.board == p)) for p in range(1, 5)}
        info['final_score'] = scores
        # Score for the player who made the game-ending move (always player 4,
        # because turn_count only advances when play returns to player 1). This
        # keeps the terminal reward consistent with every other term step()
        # returns; previously player 1's result was credited to player 4's move.
        best_rival = max(score for p, score in scores.items() if p != player)
        margin = scores[player] - best_rival
        if margin > 0:
            return margin * 0.1 + 15.0
        if margin < 0:
            return margin * 0.1 - 15.0
        return 0.0

    def _kill(self, player_to_kill, killer_pos):
        """Resolve a kill by the current player at ``killer_pos``.

        Paints the 3x3 area around ``killer_pos`` (clipped to the board) in the
        killer's colour. The killed player then respawns with a fresh random Dice
        on a random cell other than the killer's, that cell is painted in the
        victim's colour, and the victim's move budget is reset to the new top face.
        """
        killer_color = self.current_player  # colour of the player making the kill
        r_kill, c_kill = killer_pos
        for dr in [-1, 0, 1]:
            for dc in [-1, 0, 1]:
                nr, nc = r_kill + dr, c_kill + dc
                if 0 <= nr < self.board_size and 0 <= nc < self.board_size:
                    self.board[nr, nc] = killer_color
        respawn_attempts = 0
        max_attempts = self.board_size * self.board_size * 2
        killer_dice_pos = self.dice[self.current_player].pos if self.dice.get(self.current_player) else None
        new_pos = None
        while respawn_attempts < max_attempts:
            new_r = random.randint(0, self.board_size - 1)
            new_c = random.randint(0, self.board_size - 1)
            candidate = (new_r, new_c)
            if candidate != killer_dice_pos:
                new_pos = candidate
                break
            respawn_attempts += 1
        if new_pos is None:
            # Practically unreachable fallback: accept any random cell.
            new_r = random.randint(0, self.board_size - 1)
            new_c = random.randint(0, self.board_size - 1)
            new_pos = (new_r, new_c)
        self.dice[player_to_kill] = Dice(new_pos)
        self.board[new_pos] = player_to_kill
        self.moves_remaining[player_to_kill] = self.dice[player_to_kill].top

# ========================================
# 3. DQNネットワークと統合エージェント（全プレイヤー共通のモデル：入力次元 51）
# ========================================
class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    Hidden widths are 256 -> 512 -> 256 -> 128 -> 64, with ReLU after each
    hidden layer. All layers live in ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim=51, output_dim=4):
        super().__init__()
        widths = [input_dim, 256, 512, 256, 128, 64]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x``, cast to float32 with non-finite inputs sanitised."""
        x = x if x.dtype == torch.float32 else x.float()
        if not torch.isfinite(x).all():
            x = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
        return self.net(x)

class DQNAgent:
    """Double-DQN agent with an experience-replay buffer and a target network.

    Actions are chosen epsilon-greedily. Training uses the online network to
    pick the next action and the target network to evaluate it (Double DQN),
    with a Smooth-L1 loss, gradient clipping at norm 1.0, AdamW (amsgrad) and
    an exponential learning-rate schedule.
    """

    def __init__(self, input_dim=51, output_dim=4,
                 lr=5e-5, gamma=0.99, epsilon=1.0, epsilon_min=0.05,
                 epsilon_decay=0.999995, buffer_size=1_000_000, batch_size=256,
                 target_update_freq=1000, lr_scheduler_gamma=0.9998):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.model = DQN(input_dim, output_dim).to(self.device)
        self.target_model = DQN(input_dim, output_dim).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, amsgrad=True)
        self.scheduler = ExponentialLR(self.optimizer, gamma=lr_scheduler_gamma)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)  # large replay buffer (oldest entries evicted)
        self.target_update_freq = target_update_freq
        self.train_step_counter = 0

    def choose_action(self, state):
        """Return an action index, chosen epsilon-greedily from the online network.

        Falls back to a random action if the network output is non-finite or
        inference fails (best-effort: action selection never raises).
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.output_dim - 1)
        if not isinstance(state, np.ndarray):
            state = np.array(state, dtype=np.float32)
        if np.isnan(state).any() or np.isinf(state).any():
            state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        try:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            self.model.eval()
            with torch.no_grad():
                q_values = self.model(state_tensor)
            self.model.train()
            if torch.isnan(q_values).any() or torch.isinf(q_values).any():
                return random.randint(0, self.output_dim - 1)
            return int(torch.argmax(q_values).item())
        except Exception:
            return random.randint(0, self.output_dim - 1)

    def store_transition(self, state, action, reward, next_state, done):
        """Append a sanitised (state, action, reward, next_state, done) tuple to memory.

        NaN/inf values in states and reward are replaced by 0.
        """
        state = np.nan_to_num(np.array(state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        next_state = np.nan_to_num(np.array(next_state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        reward = np.float32(np.nan_to_num(reward, nan=0.0, posinf=0.0, neginf=0.0))
        self.memory.append((state, int(action), reward, next_state, bool(done)))

    def train_step(self):
        """Run one optimisation step on a random minibatch and return the loss.

        Returns 0.0 without training until memory holds 10 batches, and also
        when the loss is non-finite (the update is skipped). On a successful
        step, epsilon decays once and the target network is hard-synced every
        ``target_update_freq`` steps.
        """
        if len(self.memory) < self.batch_size * 10:
            return 0.0
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.BoolTensor(dones).unsqueeze(1).to(self.device)

        self.model.eval()
        self.target_model.eval()
        with torch.no_grad():
            # Double DQN: online net selects the next action, target net scores it.
            next_q_online = self.model(next_states)
            best_next_actions = next_q_online.argmax(dim=1, keepdim=True)
            next_q_target = self.target_model(next_states)
            next_q_value = next_q_target.gather(1, best_next_actions)
            target_q = rewards + self.gamma * next_q_value * (~dones)
        self.model.train()
        current_q = self.model(states).gather(1, actions)
        loss = nn.SmoothL1Loss()(current_q, target_q.detach())
        if torch.isnan(loss) or torch.isinf(loss):
            return 0.0
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.train_step_counter += 1
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        if self.train_step_counter % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        return loss.item()

    def update_scheduler(self):
        """Advance the exponential learning-rate schedule by one step."""
        self.scheduler.step()

    def save_model(self, path):
        """Save the online network's ``state_dict`` to ``path``.

        Parent directories are created when ``path`` has a directory component.
        A bare filename is written to the current working directory.
        """
        parent = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the path actually names one.
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load weights from ``path`` into both networks; start fresh if it is missing."""
        if os.path.exists(path):
            state_dict = torch.load(path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.target_model.load_state_dict(self.model.state_dict())
            print(f"Model loaded from {path}")
        else:
            print(f"Model file not found at {path}. Starting from scratch.")

# ===================================
# 4. 統合学習ループ（全プレイヤーが学習対象）
# ===================================
if __name__ == '__main__':
    NUM_EPISODES = 500
    BOARD_SIZE = 6
    MAX_TURNS = 20         # one game = 20 rounds ("sections")
    NUM_PLAYERS = 4        # the environment is hard-wired for four players
    INPUT_DIM = 51         # (36+2+6+1+1+3+1+1)
    OUTPUT_DIM = 4
    LR = 1e-5
    GAMMA = 0.99
    EPSILON_START = 1.0
    EPSILON_MIN = 0.05
    TRAIN_FREQ = 1
    # Epsilon is multiplied by EPSILON_DECAY_RATE once per training step, i.e.
    # once every TRAIN_FREQ env steps. A player's turn lasts as many env steps
    # as the die's top face (3.5 on average), so an episode is roughly
    # MAX_TURNS * NUM_PLAYERS * 3.5 env steps, not MAX_TURNS. Measuring the
    # horizon in turns made epsilon hit its floor ~14x earlier than the
    # apparent intent of decaying over 75% of training.
    EXPECTED_STEPS_PER_EPISODE = MAX_TURNS * NUM_PLAYERS * 3.5
    TOTAL_DECAY_FRAMES = NUM_EPISODES * EXPECTED_STEPS_PER_EPISODE * 0.75 / TRAIN_FREQ
    EPSILON_DECAY_RATE = math.exp(math.log(EPSILON_MIN / EPSILON_START) / TOTAL_DECAY_FRAMES)
    BUFFER_SIZE = 1_000_000   # large replay buffer
    BATCH_SIZE = 256
    TARGET_UPDATE_FREQ = 1000
    LR_SCHEDULER_GAMMA = 0.9998
    TRAIN_START_FRAMES = BATCH_SIZE * 10
    LOG_INTERVAL = 100
    SAVE_INTERVAL = 5000
    LOAD_EXISTING_MODEL = False
    MODEL_SAVE_PATH = os.path.join(save_dir, "dqn_multiagent_final.pth")

    env = PsychoDiceWarsEnv4(board_size=BOARD_SIZE, max_turns=MAX_TURNS)
    # One agent (one shared network) acts for every player.
    agent = DQNAgent(input_dim=INPUT_DIM, output_dim=OUTPUT_DIM,
                     lr=LR, gamma=GAMMA, epsilon=EPSILON_START, epsilon_min=EPSILON_MIN,
                     epsilon_decay=EPSILON_DECAY_RATE, buffer_size=BUFFER_SIZE, batch_size=BATCH_SIZE,
                     target_update_freq=TARGET_UPDATE_FREQ, lr_scheduler_gamma=LR_SCHEDULER_GAMMA)
    if LOAD_EXISTING_MODEL:
        agent.load_model(MODEL_SAVE_PATH)

    log_path = os.path.join(save_dir, "training_log_multiagent.txt")
    start_time = time.time()
    total_env_steps = 0
    recent_rewards = deque(maxlen=LOG_INTERVAL)
    recent_losses = deque(maxlen=LOG_INTERVAL * MAX_TURNS)

    # `with` closes the log file on every exit path, including errors raised
    # before the training loop starts.
    with open(log_path, "w", encoding="utf-8") as log_file:
        pbar = tqdm(range(NUM_EPISODES), unit="ep", desc="Multi-agent Training...", ncols=120)
        try:
            for ep in pbar:
                state = env.reset()
                done = False
                episode_reward = 0.0
                while not done:
                    # Whichever player is to move, the shared agent picks the
                    # action from the current state (which includes the player number).
                    action = agent.choose_action(state)
                    next_state, reward, done, info = env.step(action)
                    total_env_steps += 1
                    episode_reward += reward
                    # Every step, for every player, goes into the replay buffer.
                    agent.store_transition(state, action, reward, next_state, done)
                    if total_env_steps > TRAIN_START_FRAMES and total_env_steps % TRAIN_FREQ == 0:
                        loss = agent.train_step()
                        if loss > 0:
                            recent_losses.append(loss)
                    state = next_state
                agent.update_scheduler()
                recent_rewards.append(episode_reward)
                if (ep + 1) % LOG_INTERVAL == 0:
                    avg_reward = np.mean(recent_rewards)
                    avg_loss = np.mean(recent_losses) if recent_losses else 0
                    elapsed = time.time() - start_time
                    log_msg = f"Ep {ep+1}/{NUM_EPISODES} | Avg Reward: {avg_reward:.3f} | Avg Loss: {avg_loss:.5f} | Total Steps: {total_env_steps} | Elapsed: {elapsed:.1f}s\n"
                    print(log_msg)
                    log_file.write(log_msg)
                    log_file.flush()
                if (ep + 1) % SAVE_INTERVAL == 0:
                    agent.save_model(MODEL_SAVE_PATH)
        except KeyboardInterrupt:
            print("\nTraining interrupted by user.")
        finally:
            pbar.close()
            agent.save_model(MODEL_SAVE_PATH)
            print("Multi-agent training finished and model saved.")



Using device: cpu


Multi-agent Training...:   2%|█                                                         | 9/500 [00:00<00:18, 26.20ep/s]

一人が学習済みデータを使用

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from collections import deque
import time
import os
import math
from tqdm import tqdm

# --- Output directory for models and logs (change to suit your environment) ---
save_dir = r"学習データ"
# exist_ok=True creates the directory if needed without the check-then-create
# race of os.path.exists() followed by os.makedirs().
os.makedirs(save_dir, exist_ok=True)

# =========================
# 1. サイコロクラス（回転回数カウント付き）
# =========================
class Dice:
    """A six-sided die on one board cell that counts its successful rolls.

    ``orientation`` is the tuple (top, bottom, left, right, front, back), and
    opposite faces always sum to 7. ``pos`` is the (row, col) cell.
    ``rotation_count`` is the number of completed rolls. The initial
    orientation is random, and handedness is not fixed.
    """

    def __init__(self, pos):
        top = random.randint(1, 6)
        bottom = 7 - top
        # The four side faces are the two opposite pairs not used by top/bottom.
        side_faces = [n for n in range(1, 7) if n not in (top, bottom)]
        left = random.choice(side_faces)
        right = 7 - left
        # front/back must be the remaining pair. Excluding only `left` would
        # leave `right` as a candidate and could yield a die with front == right
        # (duplicate faces, two numbers missing).
        front = random.choice([n for n in side_faces if n not in (left, right)])
        back = 7 - front
        self.orientation = (top, bottom, left, right, front, back)
        self.pos = pos  # (row, col)
        self.rotation_count = 0  # number of completed rolls

    @property
    def top(self):
        """Value currently on the top face."""
        return self.orientation[0]

    def roll(self, direction):
        """Tip the die toward ``direction``, update its faces, and count the roll.

        ``direction`` is one of 'up', 'down', 'left', 'right'. Only the
        orientation and ``rotation_count`` change; ``pos`` is left for the caller.

        Raises:
            ValueError: if ``direction`` is invalid (the roll is then not counted).
        """
        top, bottom, left, right, front, back = self.orientation
        if direction == 'up':
            self.orientation = (front, back, left, right, bottom, top)
        elif direction == 'down':
            self.orientation = (back, front, left, right, top, bottom)
        elif direction == 'right':
            self.orientation = (left, right, bottom, top, front, back)
        elif direction == 'left':
            self.orientation = (right, left, top, bottom, front, back)
        else:
            raise ValueError("無効な方向です")
        # Count only rolls that actually happened.
        self.rotation_count += 1

# ================================
# 2. ４人対戦環境クラス（状態に現在のプレイヤー番号を追加）
# ================================
class PsychoDiceWarsEnv4:
    """Four-player Psycho Dice Wars environment.

    The board is ``board_size`` x ``board_size``; each cell holds 0 (unpainted)
    or the colour 1-4 of the player who last painted it. Every player owns one
    Dice. When a player's turn starts, their move budget is the die's top face,
    and each ``step`` rolls the die one cell and paints the destination.
    Players act in the order 1 -> 2 -> 3 -> 4. ``turn_count`` increases each
    time play returns to player 1, and the game ends once it reaches
    ``max_turns``.
    """

    def __init__(self, board_size=6, max_turns=20):
        self.board_size = board_size
        self.max_turns = max_turns   # one game = 20 rounds ("sections") by default
        self.reset()

    def reset(self):
        """Start a new game and return the initial state vector.

        The four dice are placed on distinct random cells, and each starting cell
        is painted in its owner's colour. Every player's move budget is set to
        the top face of their die.
        """
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        self.turn_count = 0
        self.current_player = 1  # player 1 always moves first
        self.dice = {}           # player number (1-4) -> Dice
        self.moves_remaining = {}
        self.done = False
        # Keep drawing cells until four distinct starting positions exist.
        positions = set()
        while len(positions) < 4:
            pos = (random.randint(0, self.board_size - 1), random.randint(0, self.board_size - 1))
            positions.add(pos)
        positions = list(positions)
        for player in range(1, 5):
            self.dice[player] = Dice(positions[player - 1])
            self.board[self.dice[player].pos] = player
            self.moves_remaining[player] = self.dice[player].top
        return self.get_state()

    def get_state(self):
        """Return the observation as a float32 vector of 51 elements.

        Layout (default 6x6 board):
          - board cells, row-major (36)
          - player 1's die position (row, col) (2)
          - positions of players 2-4 (6)
          - player 1's remaining moves (1)
          - player 1's top face (1)
          - top faces of players 2-4 (3)
          - turn_count / max_turns (1)
          - current player number (1)
        A missing die is encoded as position (-1, -1) and top face 0, and
        NaN/inf entries are replaced by 0.
        """
        board_flat = self.board.flatten().astype(np.float32)
        pos_self = np.array(self.dice[1].pos, dtype=np.float32) if self.dice.get(1) else np.array([-1, -1], dtype=np.float32)
        pos_opponents = []
        for p in [2, 3, 4]:
            if self.dice.get(p):
                pos_opponents.extend(self.dice[p].pos)
            else:
                pos_opponents.extend([-1, -1])
        pos_opponents = np.array(pos_opponents, dtype=np.float32)
        moves = np.array([self.moves_remaining.get(1, 0)], dtype=np.float32)
        dice_top_self = np.array([self.dice[1].top], dtype=np.float32) if self.dice.get(1) else np.array([0], dtype=np.float32)
        dice_top_opponents = []
        for p in [2, 3, 4]:
            dice_top_opponents.append(self.dice[p].top if self.dice.get(p) else 0)
        dice_top_opponents = np.array(dice_top_opponents, dtype=np.float32)
        turn_norm = np.array([self.turn_count / self.max_turns], dtype=np.float32)
        current_player_arr = np.array([self.current_player], dtype=np.float32)
        state = np.concatenate([
            board_flat, pos_self, pos_opponents, moves,
            dice_top_self, dice_top_opponents, turn_norm, current_player_arr
        ])
        # Keep the fixed 51-element layout. Smaller boards are zero-padded; larger
        # boards would need a negative pad width, which np.pad rejects.
        if state.shape[0] != 51:
            state = np.pad(state, (0, 51 - state.shape[0]), 'constant', constant_values=0.0)[:51]
        state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        return state.astype(np.float32)

    def print_board(self):
        """Print the board (0-4 mapped to the symbols below) and a status line."""
        symbol = {0: "・", 1: "〇", 2: "△", 3: "◇", 4: "☆"}
        print("-" * (self.board_size * 3))
        for r in range(self.board_size):
            row_str = ""
            for c in range(self.board_size):
                row_str += symbol.get(self.board[r, c], "?") + "  "
            print(row_str)
        print("-" * (self.board_size * 3))
        print(f"Turn: {self.turn_count}/{self.max_turns} | Current Player: {self.current_player} | Moves Left: {self.moves_remaining.get(self.current_player, 0)}")

    def valid_move(self, pos, direction):
        """Return (is_on_board, next_pos) for moving one cell from ``pos``.

        An unrecognised ``direction`` leaves the offset at (0, 0).
        """
        r, c = pos
        dr, dc = 0, 0
        if direction == 'up':
            dr = -1
        elif direction == 'down':
            dr = 1
        elif direction == 'left':
            dc = -1
        elif direction == 'right':
            dc = 1
        next_r, next_c = r + dr, c + dc
        return (0 <= next_r < self.board_size and 0 <= next_c < self.board_size), (next_r, next_c)

    def step(self, action):
        """Apply one move for the current player.

        ``action`` indexes ['up', 'down', 'left', 'right']. Returns
        ``(next_state, reward, done, info)``. Move and turn-end rewards are for
        the acting player:
          - +0.1 for moving onto an unpainted cell, +0.3 onto a rival's cell,
            +0.05 onto one's own cell, -0.5 for a move blocked by the edge
          - at turn end: +5 per kill, +0.5 for a 6 on top, -1 (once) when
            orthogonally adjacent to a rival die
          - at game end: the terminal reward from ``_final_reward``
        """
        actions = ['up', 'down', 'left', 'right']
        chosen_dir = actions[action]
        player = self.current_player
        dice = self.dice[player]
        reward = 0.0
        info = {'battle_result': 'none', 'turn_ended': False, 'final_score': None}
        valid, next_pos = self.valid_move(dice.pos, chosen_dir)
        if valid:
            # Roll only when the die actually tips into the next cell. Rolling a
            # die that stays put would let a player re-pick its top face (its
            # next move budget) for free by bumping into a wall.
            try:
                dice.roll(chosen_dir)
            except ValueError as e:
                print(f"Dice roll error: {e}")
                reward -= 10
                self.moves_remaining[player] = 0
                info['error'] = 'Invalid roll direction'
            target_owner = self.board[next_pos]
            dice.pos = next_pos
            self.board[dice.pos] = player
            if target_owner == 0:
                move_reward = 0.1
            elif target_owner != player:
                move_reward = 0.3
            else:
                move_reward = 0.05
            reward += move_reward
        else:
            reward -= 0.5
        self.moves_remaining[player] -= 1
        if self.moves_remaining[player] <= 0:
            info['turn_ended'] = True
            # Kill: ending the turn on a rival's cell kills that rival.
            for opp in range(1, 5):
                if opp == player:
                    continue
                if self.dice.get(opp) and dice.pos == self.dice[opp].pos:
                    try:
                        self._kill(opp, dice.pos)
                        reward += 5.0
                        info['battle_result'] = 'kill'
                    except Exception as e:
                        print(f"Kill error: {e}")
                        reward -= 10
                        info['error'] = 'Kill failed'
            if dice.top == 6:
                reward += 0.5
            # Penalise (once) ending the turn next to any rival die.
            for opp in range(1, 5):
                if opp == player or not self.dice.get(opp):
                    continue
                dx = abs(dice.pos[0] - self.dice[opp].pos[0])
                dy = abs(dice.pos[1] - self.dice[opp].pos[1])
                if dx + dy == 1:
                    reward -= 1.0
                    break
            next_player = (player % 4) + 1
            self.current_player = next_player
            if next_player == 1:
                self.turn_count += 1
            if self.dice.get(next_player):
                self.moves_remaining[next_player] = self.dice[next_player].top
            else:
                self.moves_remaining[next_player] = 0
        if (not self.done) and (self.turn_count >= self.max_turns):
            self.done = True
            reward += self._final_reward(info)
        next_state = self.get_state()
        if np.isnan(reward) or np.isinf(reward):
            reward = 0.0
        if np.isnan(next_state).any() or np.isinf(next_state).any():
            next_state = np.nan_to_num(next_state, nan=0.0, posinf=0.0, neginf=0.0)
        return next_state.astype(np.float32), float(reward), self.done, info

    def _final_reward(self, info):
        """Score the finished board and return player 1's terminal reward.

        Stores the per-player painted-cell counts in ``info['final_score']``.
        The reward is +15 for a strict win or -15 for a strict loss, plus 0.1
        per cell of margin over the best rival. A tie with the best rival gives 0.
        """
        # NOTE(review): this scores player 1, but the step that ends the game is
        # always player 4's (turn_count only advances when play returns to player
        # 1). Confirm that this cell's training loop routes the reward to player 1.
        scores = {p: int(np.sum(self.board == p)) for p in range(1, 5)}
        info['final_score'] = scores
        margin = scores[1] - max(scores[q] for q in [2, 3, 4])
        if margin > 0:
            return margin * 0.1 + 15.0
        if margin < 0:
            return margin * 0.1 - 15.0
        return 0.0

    def _kill(self, player_to_kill, killer_pos):
        """Resolve a kill by the current player at ``killer_pos``.

        Paints the 3x3 area around ``killer_pos`` (clipped to the board) in the
        killer's colour. The killed player then respawns with a fresh random Dice
        on a random cell other than the killer's, that cell is painted in the
        victim's colour, and the victim's move budget is reset to the new top face.
        """
        killer_color = self.current_player  # colour of the player making the kill
        r_kill, c_kill = killer_pos
        for dr in [-1, 0, 1]:
            for dc in [-1, 0, 1]:
                nr, nc = r_kill + dr, c_kill + dc
                if 0 <= nr < self.board_size and 0 <= nc < self.board_size:
                    self.board[nr, nc] = killer_color
        respawn_attempts = 0
        max_attempts = self.board_size * self.board_size * 2
        killer_dice_pos = self.dice[self.current_player].pos if self.dice.get(self.current_player) else None
        new_pos = None
        while respawn_attempts < max_attempts:
            new_r = random.randint(0, self.board_size - 1)
            new_c = random.randint(0, self.board_size - 1)
            candidate = (new_r, new_c)
            if candidate != killer_dice_pos:
                new_pos = candidate
                break
            respawn_attempts += 1
        if new_pos is None:
            # Practically unreachable fallback: accept any random cell.
            new_r = random.randint(0, self.board_size - 1)
            new_c = random.randint(0, self.board_size - 1)
            new_pos = (new_r, new_c)
        self.dice[player_to_kill] = Dice(new_pos)
        self.board[new_pos] = player_to_kill
        self.moves_remaining[player_to_kill] = self.dice[player_to_kill].top

# ========================================
# 3. DQNネットワークと統合エージェント（全プレイヤー共通のモデル：入力次元 51）
# ========================================
class DQN(nn.Module):
    """Fully connected Q-network mapping a state vector to one Q-value per action.

    Hidden widths are 256 -> 512 -> 256 -> 128 -> 64, with ReLU after each
    hidden layer. All layers live in ``self.net`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim=51, output_dim=4):
        super().__init__()
        widths = [input_dim, 256, 512, 256, 128, 64]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x``, cast to float32 with non-finite inputs sanitised."""
        x = x if x.dtype == torch.float32 else x.float()
        if not torch.isfinite(x).all():
            x = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
        return self.net(x)

class DQNAgent:
    """Double-DQN agent with experience replay and a periodically synced target network.

    Args:
        input_dim: Length of the state vector fed to the network.
        output_dim: Number of discrete actions.
        lr: AdamW (AMSGrad) learning rate.
        gamma: Discount factor for bootstrapped targets.
        epsilon, epsilon_min, epsilon_decay: Epsilon-greedy schedule; epsilon is
            multiplied by ``epsilon_decay`` after each successful gradient step
            while it is above ``epsilon_min``.
        buffer_size: Maximum number of transitions kept in replay memory.
        batch_size: Mini-batch size sampled per training step.
        target_update_freq: Copy online weights into the target net every this
            many gradient steps.
        lr_scheduler_gamma: Decay factor applied by each ``update_scheduler`` call.
    """

    def __init__(self, input_dim=51, output_dim=4,
                 lr=5e-5, gamma=0.99, epsilon=1.0, epsilon_min=0.05,
                 epsilon_decay=0.999995, buffer_size=1_000_000, batch_size=256,
                 target_update_freq=1000, lr_scheduler_gamma=0.9998):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.model = DQN(input_dim, output_dim).to(self.device)
        self.target_model = DQN(input_dim, output_dim).to(self.device)
        # The target net starts as an exact copy and is only used for inference.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, amsgrad=True)
        self.scheduler = ExponentialLR(self.optimizer, gamma=lr_scheduler_gamma)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)  # oldest transitions are dropped first
        self.target_update_freq = target_update_freq
        self.train_step_counter = 0

    def choose_action(self, state):
        """Return an action index chosen epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the argmax of the online network's Q-values is returned.
        NaN/Inf entries in ``state`` are zeroed first. Non-finite Q-values, or
        any failure during inference, fall back to a random action.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.output_dim - 1)
        if not isinstance(state, np.ndarray):
            state = np.array(state, dtype=np.float32)
        if np.isnan(state).any() or np.isinf(state).any():
            state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        try:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            self.model.eval()
            with torch.no_grad():
                q_values = self.model(state_tensor)
            self.model.train()
            if torch.isnan(q_values).any() or torch.isinf(q_values).any():
                return random.randint(0, self.output_dim - 1)
            return int(torch.argmax(q_values).item())
        except Exception:
            # Deliberate best-effort fallback: a single failed forward pass
            # should not abort a training/evaluation loop.
            return random.randint(0, self.output_dim - 1)

    def store_transition(self, state, action, reward, next_state, done):
        """Append a sanitized (state, action, reward, next_state, done) tuple to replay memory.

        NaN and +/-Inf in the states and the reward are replaced by 0.
        """
        state = np.nan_to_num(np.array(state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        next_state = np.nan_to_num(np.array(next_state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        reward = np.float32(np.nan_to_num(reward, nan=0.0, posinf=0.0, neginf=0.0))
        self.memory.append((state, int(action), reward, next_state, bool(done)))

    def train_step(self):
        """Run one Double-DQN gradient step on a random mini-batch.

        Returns the loss value. Returns 0.0 without updating when memory holds
        fewer than ``10 * batch_size`` transitions (warm-up) or when the loss is
        non-finite.
        """
        if len(self.memory) < self.batch_size * 10:
            return 0.0
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.BoolTensor(dones).unsqueeze(1).to(self.device)

        self.model.eval()
        self.target_model.eval()
        with torch.no_grad():
            # Double DQN: the online net picks the next action and the target
            # net evaluates it. Terminal transitions do not bootstrap.
            next_q_online = self.model(next_states)
            best_next_actions = next_q_online.argmax(dim=1, keepdim=True)
            next_q_target = self.target_model(next_states)
            next_q_value = next_q_target.gather(1, best_next_actions)
            target_q = rewards + self.gamma * next_q_value * (~dones)
        self.model.train()
        current_q = self.model(states).gather(1, actions)
        loss = nn.SmoothL1Loss()(current_q, target_q.detach())
        if torch.isnan(loss) or torch.isinf(loss):
            return 0.0
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.train_step_counter += 1
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        if self.train_step_counter % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        return loss.item()

    def update_scheduler(self):
        """Advance the exponential learning-rate scheduler by one step."""
        self.scheduler.step()

    def save_model(self, path):
        """Save the online network's state_dict to ``path``.

        Missing parent directories are created. A bare filename such as
        ``"model.pth"`` is written to the current working directory.
        """
        directory = os.path.dirname(path)
        # os.path.dirname("model.pth") == "" and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load a state_dict from ``path`` into both the online and target networks.

        If the file does not exist, a message is printed and the current
        (freshly initialised) weights are kept.
        """
        if os.path.exists(path):
            state_dict = torch.load(path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.target_model.load_state_dict(self.model.state_dict())
            print(f"Model loaded from {path}")
        else:
            print(f"Model file not found at {path}. Starting from scratch.")

# ========================================
# 4. 統計を出力するシミュレーション関数（対象プレイヤーについて学習済みエージェントを使用）
# ========================================
def simulate_single_trained_match_for_player(target_player, num_matches=1000, model_path="dqn_model_final.pth", block=1000, show_boards=True):
    """Evaluate the trained agent in one seat against three uniformly random players.

    ``target_player`` acts greedily (epsilon = 0) with the weights loaded from
    ``model_path``. Every other seat picks ``random.randint(0, 3)``. After each
    match the target player's board control, win flag, kill count and action
    count ("rotations") are accumulated. Interim statistics are printed every
    ``block`` matches.

    Args:
        target_player: Seat (1-4) driven by the trained agent.
        num_matches: Number of matches to play; must be positive.
        model_path: Path of the saved state_dict to load.
        block: Print interim statistics every ``block`` matches; must be positive.
        show_boards: If True (the default, matching the original behaviour),
            print the final board after every match and once more at the end.
            Pass False for long runs to avoid flooding the notebook output.

    Returns:
        ``(avg_control_percent, win_rate_percent, avg_kills, avg_rotations)``.
        A match counts as a win when the target player's cell count equals the
        maximum, so ties count as wins.

    Raises:
        ValueError: If ``target_player`` is not 1-4, or ``num_matches`` or
            ``block`` is not positive.
    """
    if target_player not in (1, 2, 3, 4):
        raise ValueError(f"target_player must be 1-4, got {target_player!r}")
    if num_matches <= 0 or block <= 0:
        raise ValueError(f"num_matches and block must be positive, got {num_matches!r} and {block!r}")

    env = PsychoDiceWarsEnv4(board_size=6, max_turns=20)
    agent = DQNAgent(input_dim=51, output_dim=4)
    agent.load_model(model_path)
    agent.epsilon = 0.0  # greedy (fixed) policy for evaluation

    total_control = 0.0
    total_wins = 0
    total_kills = 0
    total_rotations = 0

    for match in range(1, num_matches + 1):
        state = env.reset()
        done = False
        match_kills = 0
        match_rotations = 0  # number of actions taken by the target player
        while not done:
            # Record who acts before stepping: at the end of a turn step() hands
            # the turn to the next seat, so env.current_player afterwards is not
            # necessarily the player who made the kill.
            actor = env.current_player
            if actor == target_player:
                action = agent.choose_action(state)
                match_rotations += 1
            else:
                action = random.randint(0, 3)
            state, _reward, done, info = env.step(action)
            if actor == target_player and info.get("battle_result") == "kill":
                match_kills += 1
        control = np.sum(env.board == target_player) / (env.board_size ** 2) * 100
        total_control += control
        total_kills += match_kills
        total_rotations += match_rotations
        scores = {p: np.sum(env.board == p) for p in range(1, 5)}
        # Ties for the highest cell count are counted as wins.
        total_wins += int(scores[target_player] == max(scores.values()))

        if show_boards:
            print(f"\n=== 試合 {match} の盤面 ===")
            env.print_board()

        if match % block == 0:
            avg_control = total_control / match
            win_rate = total_wins / match * 100
            avg_kills = total_kills / match
            avg_rotations = total_rotations / match
            print(f"\n【プレイヤー {target_player}：{match} 試合までの中間統計】")
            print(f"平均色支配率: {avg_control:.2f}%")
            print(f"勝率: {win_rate:.2f}%")
            print(f"平均キル数: {avg_kills:.2f}")
            print(f"平均回転数: {avg_rotations:.2f}")
            print("=" * 30)

    avg_control = total_control / num_matches
    win_rate = total_wins / num_matches * 100
    avg_kills = total_kills / num_matches
    avg_rotations = total_rotations / num_matches

    if show_boards:
        # Board of the final match, shown once more after all matches.
        print("\n=== 最終試合の盤面（6x6）の表示 ===")
        env.print_board()

    return avg_control, win_rate, avg_kills, avg_rotations

if __name__ == '__main__':
    # Trained model to evaluate (example: "dqn_model_final.pth").
    MODEL_PATH = r"dqn_model_final.pth"
    all_stats = {}
    # Let each seat in turn be played by the trained agent for 1000 matches.
    for p in range(1, 5):
        print(f"\n--- プレイヤー {p} のシミュレーション開始 ---")
        all_stats[p] = simulate_single_trained_match_for_player(
            target_player=p,
            num_matches=1000,
            model_path=MODEL_PATH,
            block=1000,
        )

    print("\n=== 各プレイヤーの最終統計 ===")
    for p, (avg_control, win_rate, avg_kills, avg_rotations) in all_stats.items():
        print(f"プレイヤー {p}:")
        print(f"  平均色支配率: {avg_control:.2f}%")
        print(f"  勝率: {win_rate:.2f}%")
        print(f"  平均キル数: {avg_kills:.2f}")
        print(f"  平均回転数: {avg_rotations:.2f}")
        print("-" * 30)


ストリーミング出力は最後の 5000 行に切り捨てられました。
◇  △  △  △  ◇  ◇  
◇  △  △  △  〇  ◇  
△  △  △  △  ◇  ◇  
△  △  △  △  △  ◇  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 3

=== 試合 551 の盤面 ===
------------------
・  ◇  ◇  ◇  ◇  ☆  
・  ◇  ◇  ◇  ◇  ☆  
△  ◇  ◇  ◇  ◇  ☆  
△  △  △  〇  △  ☆  
〇  △  △  △  〇  △  
〇  △  △  △  △  △  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 6

=== 試合 552 の盤面 ===
------------------
・  ・  △  ◇  ◇  △  
・  ◇  ・  ◇  ◇  ◇  
・  ◇  ・  △  ◇  ◇  
〇  〇  〇  ・  △  ◇  
〇  〇  〇  〇  △  ◇  
〇  〇  〇  ◇  ◇  ◇  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 5

=== 試合 553 の盤面 ===
------------------
◇  ◇  ◇  ☆  ・  △  
◇  ◇  ◇  ☆  ◇  △  
◇  ◇  △  ☆  〇  △  
◇  ◇  △  ☆  〇  〇  
△  △  △  〇  〇  〇  
△  〇  〇  〇  〇  〇  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 1

=== 試合 554 の盤面 ===
------------------
☆  ☆  △  △  △  △  
☆  〇  ◇  〇  △  △  
☆  〇  ◇  〇  △  ◇  
◇  〇  ◇  ◇  △  ◇  
・  〇  〇  ◇  ◇  ◇  
・  〇  〇  〇  〇  〇 

４人が学習済みデータを使用

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from collections import deque
import time
import os
import math
from tqdm import tqdm

# --- Output directory for training artifacts (change to suit your environment) ---
save_dir = r"学習データ"
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory appeared between the check and the call.
os.makedirs(save_dir, exist_ok=True)

# =========================
# 1. サイコロクラス（回転回数カウント付き）
# =========================
class Dice:
    """A six-sided die on the board that tracks its full orientation.

    ``orientation`` is the tuple ``(top, bottom, left, right, front, back)``.
    Opposite faces always sum to 7 and all six values 1-6 appear exactly once.
    ``rotation_count`` counts successful ``roll`` calls.
    """

    def __init__(self, pos):
        top = random.randint(1, 6)
        bottom = 7 - top
        # Faces around the top/bottom axis: two opposite pairs remain.
        side_faces = [n for n in range(1, 7) if n not in (top, bottom)]
        left = random.choice(side_faces)
        right = 7 - left
        # Exclude BOTH faces of the left/right pair. Removing only `left` let
        # `front` be drawn equal to `right` (and `back` equal to `left`), which
        # produced an impossible die with repeated faces.
        front_candidates = [n for n in side_faces if n not in (left, right)]
        front = random.choice(front_candidates)
        back = 7 - front
        self.orientation = (top, bottom, left, right, front, back)
        self.pos = pos  # (row, col)
        self.rotation_count = 0  # number of successful rolls

    @property
    def top(self):
        """Face value currently on top."""
        return self.orientation[0]

    def roll(self, direction):
        """Tip the die one cell toward ``direction`` ('up', 'down', 'left', 'right').

        Only the orientation changes; ``pos`` is updated by the caller.

        Raises:
            ValueError: For any other direction. In that case neither the
                orientation nor ``rotation_count`` is modified.
        """
        top, bottom, left, right, front, back = self.orientation
        if direction == 'up':
            new_orientation = (front, back, left, right, bottom, top)
        elif direction == 'down':
            new_orientation = (back, front, left, right, top, bottom)
        elif direction == 'right':
            new_orientation = (left, right, bottom, top, front, back)
        elif direction == 'left':
            new_orientation = (right, left, top, bottom, front, back)
        else:
            raise ValueError("無効な方向です")
        self.orientation = new_orientation
        # Count only rolls that actually happened.
        self.rotation_count += 1

# ================================
# 2. ４人対戦環境クラス（状態に現在のプレイヤー番号を追加）
# ================================
class PsychoDiceWarsEnv4:
    """Four-player dice-painting environment on a square board.

    Board cells hold 0 (unpainted) or the colour 1-4 of the player who last
    painted them. Each player owns one ``Dice``. On their turn a player makes
    as many single-cell moves as the die's top face showed when the turn
    started. ``turn_count`` advances when play returns to player 1, and the
    game ends once it reaches ``max_turns``.
    """

    def __init__(self, board_size=6, max_turns=20):
        self.board_size = board_size
        self.max_turns = max_turns   # one game = 20 rounds by default
        self.reset()

    def reset(self):
        """Start a new game and return the initial state vector."""
        # Board cells: 0 = unpainted, 1-4 = player colours.
        self.board = np.zeros((self.board_size, self.board_size), dtype=int)
        self.turn_count = 0
        self.current_player = 1  # player 1 moves first
        self.dice = {}           # player number (1-4) -> Dice
        self.moves_remaining = {}
        self.done = False
        # Place the four dice on pairwise-distinct cells.
        positions = set()
        while len(positions) < 4:
            pos = (random.randint(0, self.board_size - 1), random.randint(0, self.board_size - 1))
            positions.add(pos)
        positions = list(positions)
        for player in range(1, 5):
            self.dice[player] = Dice(positions[player-1])
            self.board[self.dice[player].pos] = player
            self.moves_remaining[player] = self.dice[player].top
        return self.get_state()

    def get_state(self):
        """Return the 51-element float32 observation vector.

        Layout: board (36) + player 1 die position (2) + players 2-4 positions (6)
        + player 1 moves remaining (1) + player 1 die top (1)
        + players 2-4 die tops (3) + normalised turn (1) + current player (1).
        The layout is always from player 1's point of view.
        """
        board_flat = self.board.flatten().astype(np.float32)
        pos_self = np.array(self.dice[1].pos, dtype=np.float32) if self.dice.get(1) else np.array([-1,-1], dtype=np.float32)
        pos_opponents = []
        for p in [2,3,4]:
            if self.dice.get(p):
                pos_opponents.extend(self.dice[p].pos)
            else:
                pos_opponents.extend([-1,-1])
        pos_opponents = np.array(pos_opponents, dtype=np.float32)
        moves = np.array([self.moves_remaining.get(1,0)], dtype=np.float32)
        dice_top_self = np.array([self.dice[1].top], dtype=np.float32) if self.dice.get(1) else np.array([0], dtype=np.float32)
        dice_top_opponents = []
        for p in [2,3,4]:
            dice_top_opponents.append(self.dice[p].top if self.dice.get(p) else 0)
        dice_top_opponents = np.array(dice_top_opponents, dtype=np.float32)
        turn_norm = np.array([self.turn_count/self.max_turns], dtype=np.float32)
        current_player_arr = np.array([self.current_player], dtype=np.float32)
        state = np.concatenate([board_flat, pos_self, pos_opponents, moves, dice_top_self, dice_top_opponents, turn_norm, current_player_arr])
        if state.shape[0] != 51:
            state = np.pad(state, (0,51 - state.shape[0]), 'constant', constant_values=0.0)[:51]
        state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        return state.astype(np.float32)

    def print_board(self):
        """Print the board with one symbol per colour, followed by a status line."""
        # Symbols: 0 -> "・", 1 -> "〇", 2 -> "△", 3 -> "◇", 4 -> "☆"
        symbol = {0:"・", 1:"〇", 2:"△", 3:"◇", 4:"☆"}
        print("-" * (self.board_size * 3))
        for r in range(self.board_size):
            row_str = ""
            for c in range(self.board_size):
                row_str += symbol.get(self.board[r,c], "?") + "  "
            print(row_str)
        print("-" * (self.board_size * 3))
        print(f"Turn: {self.turn_count}/{self.max_turns} | Current Player: {self.current_player} | Moves Left: {self.moves_remaining.get(self.current_player,0)}")

    def valid_move(self, pos, direction):
        """Return ``(is_on_board, next_pos)`` for a one-cell move from ``pos``.

        An unrecognised direction yields no displacement, i.e. ``(True, pos)``.
        """
        r, c = pos
        dr, dc = 0, 0
        if direction == 'up':
            dr = -1
        elif direction == 'down':
            dr = 1
        elif direction == 'left':
            dc = -1
        elif direction == 'right':
            dc = 1
        next_r, next_c = r + dr, c + dc
        return (0 <= next_r < self.board_size and 0 <= next_c < self.board_size), (next_r, next_c)

    def step(self, action):
        """Apply one single-cell move for the current player.

        Args:
            action: Index into ['up', 'down', 'left', 'right'].

        Returns:
            ``(next_state, reward, done, info)``. ``info`` holds
            'battle_result' ('none' or 'kill'), 'turn_ended', 'final_score'
            (cell count per player once the game ends) and optionally 'error'.

        The die is rolled even when the move would leave the board. In that
        case the die stays where it is and a -0.5 penalty applies.
        """
        actions = ['up','down','left','right']
        chosen_dir = actions[action]
        player = self.current_player
        dice = self.dice[player]
        orig_pos = dice.pos
        reward = 0.0
        info = {'battle_result':'none', 'turn_ended':False, 'final_score':None}
        try:
            dice.roll(chosen_dir)
        except ValueError as e:
            print(f"Dice roll error: {e}")
            reward -= 10
            self.moves_remaining[player] = 0
            info['error'] = 'Invalid roll direction'
        valid, next_pos = self.valid_move(orig_pos, chosen_dir)
        if valid:
            target_owner = self.board[next_pos]
            dice.pos = next_pos
            self.board[dice.pos] = player
            # Painting reward: +0.1 unpainted, +0.3 opponent's cell, +0.05 own cell.
            move_reward = 0.1 if target_owner == 0 else (0.3 if target_owner != player else 0.05)
            reward += move_reward
        else:
            reward -= 0.5
        self.moves_remaining[player] -= 1
        turn_end = (self.moves_remaining[player] <= 0)
        if turn_end:
            info['turn_ended'] = True
            # Kill check: another player's die sits on the cell where our turn ended.
            for opp in range(1,5):
                if opp == player:
                    continue
                if self.dice.get(opp) and dice.pos == self.dice[opp].pos:
                    try:
                        self._kill(opp, dice.pos)
                        reward += 5.0
                        info['battle_result'] = 'kill'
                    except Exception as e:
                        print(f"Kill error: {e}")
                        reward -= 10
                        info['error'] = 'Kill failed'
            if dice.top == 6:
                reward += 0.5
            # Penalty for ending the turn orthogonally adjacent to an opponent's die.
            for opp in range(1,5):
                if opp == player or not self.dice.get(opp):
                    continue
                dx = abs(dice.pos[0] - self.dice[opp].pos[0])
                dy = abs(dice.pos[1] - self.dice[opp].pos[1])
                if dx + dy == 1:
                    reward -= 1.0
                    break
            next_player = (player % 4) + 1
            self.current_player = next_player
            if next_player == 1:
                self.turn_count += 1
            if self.dice.get(next_player):
                self.moves_remaining[next_player] = self.dice[next_player].top
            else:
                self.moves_remaining[next_player] = 0
        if (not self.done) and (self.turn_count >= self.max_turns):
            self.done = True
            scores = {p: np.sum(self.board == p) for p in range(1,5)}
            info['final_score'] = scores
            # NOTE: the terminal bonus is computed for player 1 regardless of who
            # made the final move (an example tuning for player 1; adjust as needed).
            best_opponent = max(scores[q] for q in [2,3,4])
            margin = scores[1] - best_opponent
            if margin > 0:
                final_reward = margin * 0.1 + 15.0
            elif margin < 0:
                final_reward = margin * 0.1 - 15.0
            else:
                final_reward = 0.0
            reward += final_reward
        next_state = self.get_state()
        if np.isnan(reward) or np.isinf(reward):
            reward = 0.0
        if np.isnan(next_state).any() or np.isinf(next_state).any():
            next_state = np.nan_to_num(next_state, nan=0.0, posinf=0.0, neginf=0.0)
        return next_state.astype(np.float32), float(reward), self.done, info

    def _kill(self, player_to_kill, killer_pos):
        """Resolve a kill at ``killer_pos`` made by the current player.

        The killer's colour is painted on every on-board cell within one step
        (including diagonals) of ``killer_pos``. The victim then receives a
        brand-new random ``Dice`` on a cell not occupied by any other player's
        die, the same no-overlap rule ``reset`` enforces. Its remaining moves are
        reset to the new die's top face.
        """
        killer_color = self.current_player
        r_kill, c_kill = killer_pos
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r_kill + dr, c_kill + dc
                if 0 <= nr < self.board_size and 0 <= nc < self.board_size:
                    self.board[nr, nc] = killer_color
        # Cells held by the surviving dice (the killer included). Respawning the
        # victim on one of them would stack two dice and could trigger a
        # spurious "kill" later.
        occupied = {
            other.pos
            for p, other in self.dice.items()
            if p != player_to_kill and other is not None
        }
        free_cells = [
            (r, c)
            for r in range(self.board_size)
            for c in range(self.board_size)
            if (r, c) not in occupied
        ]
        if free_cells:
            new_pos = random.choice(free_cells)
        else:
            # Only reachable on boards too small to hold every die separately.
            new_pos = (random.randint(0, self.board_size - 1), random.randint(0, self.board_size - 1))
        self.dice[player_to_kill] = Dice(new_pos)
        self.board[new_pos] = player_to_kill
        self.moves_remaining[player_to_kill] = self.dice[player_to_kill].top

# ========================================
# 3. DQNネットワークと統合エージェント（全プレイヤー共通のモデル：入力次元 51）
# ========================================
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    Layer widths are input_dim -> 256 -> 512 -> 256 -> 128 -> 64 -> output_dim
    with a ReLU after every hidden layer. The layers are held in ``self.net``
    (an ``nn.Sequential``) so state_dict keys stay ``net.0.*`` ... ``net.10.*``
    and existing checkpoints keep loading.
    """

    def __init__(self, input_dim=51, output_dim=4):
        super().__init__()
        hidden_widths = (256, 512, 256, 128, 64)
        layers = []
        fan_in = input_dim
        for width in hidden_widths:
            layers.append(nn.Linear(fan_in, width))
            layers.append(nn.ReLU())
            fan_in = width
        layers.append(nn.Linear(fan_in, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x``, casting to float32 and sanitizing NaN/Inf first."""
        x = x if x.dtype == torch.float32 else x.float()
        # NaN -> 0 and +/-Inf -> +/-1e6 so one corrupted observation cannot
        # turn every output into NaN.
        contains_non_finite = torch.isnan(x).any() or torch.isinf(x).any()
        if contains_non_finite:
            x = torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)
        return self.net(x)

class DQNAgent:
    """Double-DQN agent with experience replay and a periodically synced target network.

    Args:
        input_dim: Length of the state vector fed to the network.
        output_dim: Number of discrete actions.
        lr: AdamW (AMSGrad) learning rate.
        gamma: Discount factor for bootstrapped targets.
        epsilon, epsilon_min, epsilon_decay: Epsilon-greedy schedule; epsilon is
            multiplied by ``epsilon_decay`` after each successful gradient step
            while it is above ``epsilon_min``.
        buffer_size: Maximum number of transitions kept in replay memory.
        batch_size: Mini-batch size sampled per training step.
        target_update_freq: Copy online weights into the target net every this
            many gradient steps.
        lr_scheduler_gamma: Decay factor applied by each ``update_scheduler`` call.
    """

    def __init__(self, input_dim=51, output_dim=4,
                 lr=5e-5, gamma=0.99, epsilon=1.0, epsilon_min=0.05,
                 epsilon_decay=0.999995, buffer_size=1_000_000, batch_size=256,
                 target_update_freq=1000, lr_scheduler_gamma=0.9998):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.model = DQN(input_dim, output_dim).to(self.device)
        self.target_model = DQN(input_dim, output_dim).to(self.device)
        # The target net starts as an exact copy and is only used for inference.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr, amsgrad=True)
        self.scheduler = ExponentialLR(self.optimizer, gamma=lr_scheduler_gamma)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)  # oldest transitions are dropped first
        self.target_update_freq = target_update_freq
        self.train_step_counter = 0

    def choose_action(self, state):
        """Return an action index chosen epsilon-greedily.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the argmax of the online network's Q-values is returned.
        NaN/Inf entries in ``state`` are zeroed first. Non-finite Q-values, or
        any failure during inference, fall back to a random action.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.output_dim - 1)
        if not isinstance(state, np.ndarray):
            state = np.array(state, dtype=np.float32)
        if np.isnan(state).any() or np.isinf(state).any():
            state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)
        try:
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            self.model.eval()
            with torch.no_grad():
                q_values = self.model(state_tensor)
            self.model.train()
            if torch.isnan(q_values).any() or torch.isinf(q_values).any():
                return random.randint(0, self.output_dim - 1)
            return int(torch.argmax(q_values).item())
        except Exception:
            # Deliberate best-effort fallback: a single failed forward pass
            # should not abort a training/evaluation loop.
            return random.randint(0, self.output_dim - 1)

    def store_transition(self, state, action, reward, next_state, done):
        """Append a sanitized (state, action, reward, next_state, done) tuple to replay memory.

        NaN and +/-Inf in the states and the reward are replaced by 0.
        """
        state = np.nan_to_num(np.array(state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        next_state = np.nan_to_num(np.array(next_state, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        reward = np.float32(np.nan_to_num(reward, nan=0.0, posinf=0.0, neginf=0.0))
        self.memory.append((state, int(action), reward, next_state, bool(done)))

    def train_step(self):
        """Run one Double-DQN gradient step on a random mini-batch.

        Returns the loss value. Returns 0.0 without updating when memory holds
        fewer than ``10 * batch_size`` transitions (warm-up) or when the loss is
        non-finite.
        """
        if len(self.memory) < self.batch_size * 10:
            return 0.0
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states)).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.BoolTensor(dones).unsqueeze(1).to(self.device)

        self.model.eval()
        self.target_model.eval()
        with torch.no_grad():
            # Double DQN: the online net picks the next action and the target
            # net evaluates it. Terminal transitions do not bootstrap.
            next_q_online = self.model(next_states)
            best_next_actions = next_q_online.argmax(dim=1, keepdim=True)
            next_q_target = self.target_model(next_states)
            next_q_value = next_q_target.gather(1, best_next_actions)
            target_q = rewards + self.gamma * next_q_value * (~dones)
        self.model.train()
        current_q = self.model(states).gather(1, actions)
        loss = nn.SmoothL1Loss()(current_q, target_q.detach())
        if torch.isnan(loss) or torch.isinf(loss):
            return 0.0
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.train_step_counter += 1
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        if self.train_step_counter % self.target_update_freq == 0:
            self.target_model.load_state_dict(self.model.state_dict())
        return loss.item()

    def update_scheduler(self):
        """Advance the exponential learning-rate scheduler by one step."""
        self.scheduler.step()

    def save_model(self, path):
        """Save the online network's state_dict to ``path``.

        Missing parent directories are created. A bare filename such as
        ``"model.pth"`` is written to the current working directory.
        """
        directory = os.path.dirname(path)
        # os.path.dirname("model.pth") == "" and os.makedirs("") raises
        # FileNotFoundError, so only create a directory when there is one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.model.state_dict(), path)
        print(f"Model saved to {path}")

    def load_model(self, path):
        """Load a state_dict from ``path`` into both the online and target networks.

        If the file does not exist, a message is printed and the current
        (freshly initialised) weights are kept.
        """
        if os.path.exists(path):
            state_dict = torch.load(path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.target_model.load_state_dict(self.model.state_dict())
            print(f"Model loaded from {path}")
        else:
            print(f"Model file not found at {path}. Starting from scratch.")

# ========================================
# 4. 全プレイヤーが学習済みエージェントを使用するシミュレーション関数（統計出力付き）
# ========================================
def _average_stats(player_totals, n):
    """Turn one player's accumulated totals into (avg_control, win_rate%, avg_kills, avg_rotations)."""
    return (
        player_totals["control"] / n,
        player_totals["wins"] / n * 100,
        player_totals["kills"] / n,
        player_totals["rotations"] / n,
    )


def simulate_all_trained_match_with_stats(num_matches=1000, model_path="dqn_multiagent_final.pth", block=1000, show_boards=True):
    """Play matches in which all four seats use the trained model greedily.

    After each match every player's board control (% of cells), win flag,
    kill count and action count ("rotations") are accumulated. Interim
    statistics are printed every ``block`` matches.

    Args:
        num_matches: Number of matches to play; must be positive.
        model_path: Path of the saved state_dict shared by all seats.
        block: Print interim statistics every ``block`` matches; must be positive.
        show_boards: If True (the default, matching the original behaviour),
            print the final board after every match. Pass False for long runs
            to avoid flooding the notebook output.

    Returns:
        Dict mapping player (1-4) to
        ``(avg_control_percent, win_rate_percent, avg_kills, avg_rotations)``.
        Every player tied for the highest cell count is credited with a win, so
        win rates can sum to more than 100%.

    Raises:
        ValueError: If ``num_matches`` or ``block`` is not positive.

    NOTE(review): ``PsychoDiceWarsEnv4.get_state`` always encodes player 1's die
    as "self", so seats 2-4 receive player-1-centric observations (only the
    current-player field differs). Confirm this is intended for evaluation.
    """
    if num_matches <= 0 or block <= 0:
        raise ValueError(f"num_matches and block must be positive, got {num_matches!r} and {block!r}")

    players = (1, 2, 3, 4)
    env = PsychoDiceWarsEnv4(board_size=6, max_turns=20)
    # Every seat uses the same weights with epsilon = 0, so a single agent is
    # enough. Four copies only multiplied memory use and load time.
    shared_agent = DQNAgent(input_dim=51, output_dim=4)
    shared_agent.load_model(model_path)
    shared_agent.epsilon = 0.0
    agents = {p: shared_agent for p in players}

    # Running totals per player.
    totals = {p: {"control": 0.0, "wins": 0, "kills": 0, "rotations": 0} for p in players}
    cell_count = env.board_size ** 2

    for match in range(1, num_matches + 1):
        state = env.reset()
        done = False
        match_kills = dict.fromkeys(players, 0)
        match_rotations = dict.fromkeys(players, 0)
        while not done:
            # Capture the actor before stepping; step() may pass the turn on.
            current_player = env.current_player
            action = agents[current_player].choose_action(state)
            match_rotations[current_player] += 1
            state, _reward, done, info = env.step(action)
            if info.get("battle_result") == "kill":
                match_kills[current_player] += 1
        # Board control from each player's final cell count.
        scores = {p: np.sum(env.board == p) for p in players}
        best_score = max(scores.values())
        for p in players:
            totals[p]["control"] += scores[p] / cell_count * 100
            totals[p]["kills"] += match_kills[p]
            totals[p]["rotations"] += match_rotations[p]
            if scores[p] == best_score:
                totals[p]["wins"] += 1

        if show_boards:
            print(f"\n=== 試合 {match} の盤面 ===")
            env.print_board()

        if match % block == 0:
            print(f"\n【全プレイヤー：{match}試合までの中間統計】")
            for p in players:
                avg_control, win_rate, avg_kills, avg_rotations = _average_stats(totals[p], match)
                print(f"プレイヤー {p}: 平均色支配率 {avg_control:.2f}%, 勝率 {win_rate:.2f}%, 平均キル数 {avg_kills:.2f}, 平均回転数 {avg_rotations:.2f}")
            print("=" * 40)

    return {p: _average_stats(totals[p], num_matches) for p in players}

if __name__ == '__main__':
    # Model to evaluate. dqn_multiagent_final.pth is an example of a model
    # trained jointly for all players.
    MODEL_PATH = r"dqn_multiagent_final.pth"
    stats_all = simulate_all_trained_match_with_stats(
        num_matches=1000,
        model_path=MODEL_PATH,
        block=1000,
    )
    print("\n=== 各プレイヤーの最終統計 ===")
    for p in range(1, 5):
        avg_control, win_rate, avg_kills, avg_rotations = stats_all[p]
        report_lines = (
            f"プレイヤー {p}:",
            f"  平均色支配率: {avg_control:.2f}%",
            f"  勝率: {win_rate:.2f}%",
            f"  平均キル数: {avg_kills:.2f}",
            f"  平均回転数: {avg_rotations:.2f}",
            "-" * 30,
        )
        for line in report_lines:
            print(line)


ストリーミング出力は最後の 5000 行に切り捨てられました。
・  ・  △  ・  ・  ・  
・  ・  △  ・  ◇  ・  
・  ・  △  ・  ◇  ・  
・  ・  △  〇  ◇  〇  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 4

=== 試合 550 の盤面 ===
------------------
・  ☆  ・  ・  ・  ・  
〇  〇  〇  〇  〇  〇  
◇  ・  ・  ・  △  ・  
◇  ・  ・  ・  △  ・  
◇  ・  ・  ・  △  ・  
◇  ・  ・  ・  △  ・  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 1

=== 試合 551 の盤面 ===
------------------
・  〇  〇  ◇  ☆  ☆  
・  ・  △  ◇  ☆  ☆  
・  ・  △  〇  〇  〇  
・  ・  △  ・  ・  ☆  
・  ・  △  ・  ・  ・  
・  ・  △  ・  ・  ・  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 1

=== 試合 552 の盤面 ===
------------------
・  ・  △  ・  ☆  ・  
・  〇  △  〇  ☆  〇  
・  △  △  △  ☆  ・  
・  △  △  △  ☆  ・  
・  △  △  △  ☆  ・  
・  ・  △  ◇  ◇  ・  
------------------
Turn: 20/20 | Current Player: 1 | Moves Left: 4

=== 試合 553 の盤面 ===
------------------
☆  △  ・  ・  ◇  ◇  
・  △  〇  〇  〇  〇  
・  △  ・  ・  ・  ◇  
・  △  ・  ・  ・  ◇  
・  ・  ・  ・  ◇  ◇  
・  ・  ・  〇  ◇  ◇ 