# Google Colab用セットアップ

In [None]:
# Clone the cat-brain repository under /content and move into the checkout.
%cd /content/
!git clone https://github.com/nekoneko02/cat-brain.git
%cd cat-brain
# Checking out the remote-tracking ref directly leaves the clone in detached-HEAD state.
!git checkout origin/cnn


In [None]:
%cd /content
# Rename the hyphenated directories to underscore names
# (presumably so they can be imported as the package path cat_brain.cat_dqn).
!mv /content/cat-brain /content/cat_brain
!mv /content/cat_brain/cat-dqn /content/cat_brain/cat_dqn
# Replace the relative config path inside cat_toy_env.py with an absolute path in the renamed checkout.
!sed -i 's|\.\./cat-game/config/common\.json|/content/cat_brain/cat-game/config/common.json|g' /content/cat_brain/cat_dqn/cat_toy_env.py

# 強化学習モデルの学習 (main.py)

以下のセルでは、DQNアルゴリズムを用いて、自作のマルチエージェント環境`CatToyEnv`でモデルを学習させます。

In [None]:
!apt install cmake swig zlib1g-dev
%pip install torch torchvision
%pip install numpy onnx
%pip install pettingzoo[all]
%pip install torchrl
%pip install tensordict


In [None]:
from pettingzoo.test import api_test
from cat_toy_env import CatToyEnv

# Build a single (non-parallel) environment and run PettingZoo's API test against it.
env_kwargs = {"render_mode": None, "max_steps": 1000}
env = CatToyEnv(**env_kwargs)
api_test(
    env,
    num_cycles=1000,
    verbose_progress=False,
)

In [None]:
import gymnasium as gym
import torch

from cat_toy_env import CatToyEnv
#from cat_brain.cat_dqn.cat_toy_env import CatToyEnv # Google Colab用

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
from tensordict import TensorDict
import torchrl.modules as rlnn
import numpy as np
import random
from collections import deque
import os
import json

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
import importlib
import cat_toy_env

# Re-execute the module so the current source of cat_toy_env.py is used,
# then rebind the class name to the freshly loaded definition.
CatToyEnv = importlib.reload(cat_toy_env).CatToyEnv


In [None]:
# --- Training schedule -------------------------------------------------------
num_iterations = 100               # outer training iterations
num_episodes_per_iteration = 1     # episodes played per iteration
num_steps_per_episode = 100000     # passed to the environment as max_steps
# num_epoches = 1
# num_replays_per_episode = num_epoches * num_episodes_per_iteration * num_steps_per_episode
update_target_steps = 10           # target-network sync / logging period, counted in iterations
replay_interval = 6                # run a replay update whenever env.step_count is a multiple of this
buffer_size = 10000                # replay buffer capacity
batch_size = 64                    # transitions sampled per replay update
sequence_length = 1                # observations per stored state sequence

# --- Model settings shared through the game's config file ---------------------
# JSON is UTF-8 by specification; pass the encoding explicitly so the platform's
# default locale encoding cannot affect how the file is decoded.
with open('../cat-game/config/common.json', 'r', encoding='utf-8') as f:
    config = json.load(f)
v_max = config["model"]["v_max"]
v_min = config["model"]["v_min"]
num_atoms = config["model"]["num_atoms"]
hidden_dim = config["model"]["hidden_size"]
  

In [None]:
env_kwargs = dict(render_mode=None, max_steps=num_steps_per_episode)
# Throwaway environment used only to inspect what reset() returns.
env_preview = CatToyEnv(**env_kwargs)

obs = env_preview.reset()

# Print the shape first (the original printed the observation twice), then the raw value.
print("観測の形:", np.shape(obs))
print("観測の中身:", obs)
# Environment used for training.
env_learning = CatToyEnv(**env_kwargs)

In [None]:
class DQN(nn.Module):
    """Recurrent dueling network that outputs a categorical value distribution per action.

    The input sequence goes through a GRU; the output of the last time step is
    fed to a two-layer MLP and then split into a value stream and an advantage
    stream built from noisy linear layers. Each action gets a softmax
    distribution over ``num_atoms`` atoms spaced evenly on ``[v_min, v_max]``.

    Args:
        input_dim: Size of one observation vector.
        output_dim: Number of discrete actions.
        num_atoms: Number of atoms in the value distribution.
        v_min: Value of the first atom.
        v_max: Value of the last atom.
        rnn_hidden_dim: Hidden size of the GRU.
    """

    def __init__(self, input_dim, output_dim, num_atoms=num_atoms, v_min=v_min, v_max=v_max, rnn_hidden_dim=hidden_dim):
        super().__init__()
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.delta_z = (v_max - v_min) / (num_atoms - 1)
        self.output_dim = output_dim

        # Register the support as a buffer so that `.to(device)` moves it together
        # with the parameters. It is non-persistent, so the state_dict layout (and
        # therefore existing checkpoints) is unchanged.
        self.register_buffer(
            "z_support",
            torch.linspace(self.v_min, self.v_max, self.num_atoms),
            persistent=False,
        )

        # Recurrent layer
        self.rnn = nn.GRU(input_dim, rnn_hidden_dim, batch_first=True)

        # Feature extractor. The input width is the GRU hidden size, so plain Linear
        # layers are used: their weights exist right after construction, which lets
        # a target network copy them before the first forward pass.
        self.feature = nn.Sequential(
            nn.Linear(rnn_hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )

        # State-value stream V(s): one distribution per state
        self.value_stream = nn.Sequential(
            rlnn.NoisyLinear(256, 128),
            nn.ReLU(),
            rlnn.NoisyLinear(128, num_atoms)
        )

        # Advantage stream A(s, a): one distribution per action
        self.advantage_stream = nn.Sequential(
            rlnn.NoisyLinear(256, 128),
            nn.ReLU(),
            rlnn.NoisyLinear(128, output_dim * num_atoms)
        )

    def forward(self, x, hidden_state=None):
        """Compute the per-action atom probabilities.

        Args:
            x: Batch-first input sequence, indexed as [batch, sequence, input_dim].
            hidden_state: Optional initial GRU hidden state; ``None`` starts from zeros.

        Returns:
            Tuple ``(probabilities, hidden_state)`` where ``probabilities`` has shape
            [batch, output_dim, num_atoms] and ``hidden_state`` is the GRU state
            after the sequence.
        """
        x, hidden_state = self.rnn(x, hidden_state)
        x = x[:, -1, :]  # keep only the output of the last time step

        x = self.feature(x)

        value = self.value_stream(x).view(-1, 1, self.num_atoms)
        advantage = self.advantage_stream(x).view(-1, self.output_dim, self.num_atoms)

        # Dueling combination on the logits; broadcasting yields [batch, output_dim, num_atoms].
        q_atoms = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Normalise over the atom axis to obtain a distribution per action.
        probabilities = F.softmax(q_atoms, dim=2)
        return probabilities, hidden_state

    def get_support(self):
        """Return the atom values, a tensor of shape [num_atoms]."""
        return self.z_support


class DQNAgent:
    """Agent that trains a distributional ``DQN`` with a prioritized replay buffer.

    Args:
        agent_name: Key of this agent in ``env.action_spaces`` / ``env.observation_spaces``.
        env: Environment exposing per-agent action and observation spaces.
        learning_rate: Adam learning rate.
        gamma: Discount factor.
        epsilon: Initial probability of taking a random action in ``act``.
        epsilon_min: Lower bound at which epsilon stops decaying.
        epsilon_decay: Multiplicative decay applied to epsilon after each replay update.
    """

    def __init__(self, agent_name, env, learning_rate=1e-4, gamma=0.995, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995):
        self.agent_name = agent_name
        self.action_space = env.action_spaces[self.agent_name]
        self.state_shape = env.observation_spaces[self.agent_name].shape[0]

        self.model = DQN(self.state_shape, self.action_space.n).to(device)
        self.target_model = DQN(self.state_shape, self.action_space.n).to(device)

        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        # Kept so the attribute stays available; replay() uses a KL loss instead.
        self.loss_fn = nn.MSELoss()

        self.memory = TensorDictPrioritizedReplayBuffer(storage=LazyTensorStorage(buffer_size), alpha=0.6, beta=0.4)
        self.batch_size = batch_size
        self.update_target_model()

        self.hidden_state = None

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def store_experience(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer.

        Args:
            state: Sequence of observations preceding the action.
            action: Index of the action that was taken.
            reward: Reward received for the action.
            next_state: Sequence of observations following the action.
            done: Terminal flag (truthy when the episode ended).
        """
        # Stack through NumPy first: building a tensor directly from a list of
        # arrays converts element by element.
        self.memory.add(TensorDict({
            'state': torch.as_tensor(np.asarray(state), dtype=torch.float32),
            'action': torch.LongTensor([action]),
            'reward': torch.FloatTensor([reward]),
            'next_state': torch.as_tensor(np.asarray(next_state), dtype=torch.float32),
            'done': torch.FloatTensor([done]),
            'td_error': 1.0  # initial priority for a new transition
        }))

    def act(self, state):
        """Choose an action for a single observation (epsilon-greedy).

        With probability ``epsilon`` a random action is sampled; otherwise the
        action with the highest expected value under the online network is
        returned. The GRU hidden state is carried in ``self.hidden_state``.
        """
        if random.random() <= self.epsilon:
            return self.action_space.sample()
        # Add batch and sequence dimensions.
        state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(device)
        # Acting never needs gradients; without no_grad the stored hidden state
        # would keep the autograd graph of every step alive.
        with torch.no_grad():
            probabilities, self.hidden_state = self.model(state, self.hidden_state)
            support = self.model.get_support().to(probabilities.device)
            # Expected value per action: [batch, num_actions]
            q_values = torch.sum(probabilities * support, dim=-1)
        return torch.argmax(q_values).item()

    def reset_hidden_state(self):
        """Forget the GRU hidden state carried between ``act`` calls."""
        self.hidden_state = None

    def replay(self):
        """Run one prioritized-replay update of the online network.

        Returns immediately while the buffer holds fewer than ``batch_size``
        transitions. Afterwards epsilon is decayed once (down to ``epsilon_min``).
        """
        if len(self.memory) < self.batch_size:
            return

        batch, info = self.memory.sample(self.batch_size, return_info=True)
        # Importance-sampling weights supplied by the prioritized sampler.
        weights = torch.as_tensor(info['_weight'], dtype=torch.float32, device=device)

        states = batch['state'].to(device)
        # reshape(-1) instead of squeeze() so a batch of one keeps its batch axis.
        actions = batch['action'].to(device).reshape(-1)
        rewards = batch['reward'].to(device).reshape(-1)
        next_states = batch['next_state'].to(device)
        dones = batch['done'].to(device).reshape(-1)

        # Distribution predicted for the actions that were actually taken.
        probabilities, _ = self.model(states, None)  # [batch, num_actions, num_atoms]
        n_batch = probabilities.shape[0]
        batch_indices = torch.arange(n_batch, device=probabilities.device)
        selected_probs = probabilities[batch_indices, actions]  # [batch, num_atoms]

        # The target is a constant for this update, so build it without autograd;
        # otherwise backward() would also traverse the target network.
        with torch.no_grad():
            next_probabilities, _ = self.target_model(next_states, None)
            support = self.model.get_support().to(next_probabilities.device)
            next_q_values = torch.sum(next_probabilities * support, dim=-1)  # [batch, num_actions]
            next_actions = torch.argmax(next_q_values, dim=1)  # [batch]
            next_dist = next_probabilities[batch_indices, next_actions]
            projected_distribution = self.project_distribution(rewards, dones, next_dist)

        # Per-sample KL divergence between the projected target and the prediction.
        kl_div = F.kl_div(torch.log(selected_probs + 1e-8), projected_distribution, reduction='none').sum(dim=1)

        # New priorities, clamped to [1.0, max_priority].
        # NOTE(review): the lower bound of 1.0 floors every small error to the same
        # priority — confirm this is intended.
        td_errors = kl_div.detach()
        max_priority = 1e3
        td_errors = torch.clamp(td_errors, min=1.0, max=max_priority)
        batch.set("td_error", td_errors)
        self.memory.update_tensordict_priority(batch)

        # Importance-weighted loss and optimisation step.
        loss = (weights * kl_div).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Epsilon decay
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def project_distribution(self, rewards, dones, next_dist):
        """Project the discounted target distribution onto the fixed support (C51).

        Each atom ``z`` of the support is mapped to
        ``r + gamma * z * (1 - done)``, clamped to ``[v_min, v_max]``, and its
        probability is split linearly between the two neighbouring atoms.

        Args:
            rewards (Tensor): [batch_size] rewards.
            dones (Tensor): [batch_size] terminal flags (1.0 when terminal).
            next_dist (Tensor): [batch_size, num_atoms] next-state distribution.

        Returns:
            Tensor: [batch_size, num_atoms] projected distribution. Each row keeps
            the total probability mass of the corresponding ``next_dist`` row.
        """
        z_support = self.model.get_support().to(next_dist.device)  # [num_atoms]
        num_atoms = z_support.size(0)

        # Shifted/shrunk atom positions, limited to the support range.
        target_z = rewards.unsqueeze(1) + self.gamma * z_support.unsqueeze(0) * (1 - dones.unsqueeze(1))
        target_z = target_z.clamp(min=self.model.v_min, max=self.model.v_max)

        # Fractional atom index and its integer neighbours.
        b = (target_z - self.model.v_min) / self.model.delta_z
        lower = b.floor().long().clamp(0, num_atoms - 1)
        upper = b.ceil().long().clamp(0, num_atoms - 1)

        # When b falls exactly on an atom, lower == upper and both linear weights
        # are zero, which would silently discard the mass (this always happens
        # for terminal or clamped targets). Give that mass to the lower index.
        on_atom = (lower == upper).to(b.dtype)
        lower_weight = (upper.to(b.dtype) - b) + on_atom
        upper_weight = b - lower.to(b.dtype)

        projected_distribution = torch.zeros_like(next_dist)
        projected_distribution.scatter_add_(1, lower, next_dist * lower_weight)
        projected_distribution.scatter_add_(1, upper, next_dist * upper_weight)

        return projected_distribution

    def save_model(self, filepath):
        """Write the online network's state_dict to ``filepath``."""
        torch.save(self.model.state_dict(), filepath)

    def load_model(self, filepath):
        """Load weights from ``filepath`` into the online and target networks."""
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
        self.model.load_state_dict(torch.load(filepath, map_location=device))
        self.target_model.load_state_dict(self.model.state_dict())


In [None]:
def train_dqn(agent_dict, env, num_iterations, num_episodes_per_iteration):
    """Train the agents in ``agent_dict`` by playing episodes in ``env``.

    For every non-dummy agent turn the previous transition of that agent is
    stored in its replay buffer, replay updates are run every
    ``replay_interval`` environment steps, and a new action is chosen. Every
    ``update_target_steps`` iterations the averages are printed and all target
    networks are synchronised.

    Args:
        agent_dict: Mapping from agent name to ``DQNAgent``.
        env: Environment driven through ``agent_iter`` / ``last`` / ``step``.
        num_iterations: Number of outer iterations.
        num_episodes_per_iteration: Episodes played in each iteration.
    """
    total_rewards = {agent: 0.0 for agent in env.agents}
    # NOTE(review): `steps` grows by env.step_count once for every agent that is
    # seen finished, so it is a per-agent sum rather than a per-episode count.
    steps = 0
    episodes_since_log = 0
    for iteration in range(num_iterations):
        for _episode in range(num_episodes_per_iteration):
            env.reset()
            # Sliding window of the last sequence_length + 1 observations per agent.
            seq_obs = {agent: deque(maxlen=sequence_length+1) for agent in env.agents}
            prev_action = {agent: None for agent in env.agents}
            prev_total_reward = {agent: 0.0 for agent in env.agents}

            for agent in env.agent_iter():
                if agent == "dummy":
                    # The dummy agent never acts.
                    env.step(None)
                    continue

                obs, total_reward, terminated, truncated, _ = env.last()
                done = terminated or truncated
                seq_obs[agent].append(obs)
                # Reward obtained since this agent's previous turn.
                step_reward = total_reward - prev_total_reward[agent]

                # Enough steps must have passed to hold sequence_length + 1 observations.
                if env.step_count/len(env.agents) > sequence_length+1:
                    list_obs = list(seq_obs[agent])
                    # The outcome of the previous action is known now, so the
                    # transition can be stored. Skip it if the window is not yet
                    # full or no action has been taken, which would otherwise
                    # store a malformed transition.
                    if prev_action[agent] is not None and len(list_obs) == sequence_length + 1:
                        agent_dict[agent].store_experience(
                            list_obs[0:-1],      # s
                            prev_action[agent],  # a
                            step_reward,         # r
                            list_obs[1:],        # s'
                            float(terminated)    # done
                        )
                    if env.step_count % replay_interval == 0:
                        for replay_agent in ["cat", "toy"]:
                            agent_dict[replay_agent].replay()

                if done or env.step_count % 1000 == 0:
                    print(f"{agent} with  steps {env.step_count}, reward {step_reward: 2f}, action: {prev_action}, state is {obs}")

                if done:
                    action = None  # a finished agent must step with None
                    total_rewards[agent] = total_rewards.get(agent, 0.0) + total_reward
                    steps += env.step_count
                else:
                    action = agent_dict[agent].act(obs)
                    # Drops the GRU hidden state so every decision starts from a
                    # fresh state, matching replay(), which also passes no hidden state.
                    # NOTE(review): the original comment described this as a noise
                    # reset; nothing visible here resamples the NoisyLinear noise.
                    agent_dict[agent].reset_hidden_state()

                env.step(action)

                prev_action[agent] = action
                prev_total_reward[agent] = total_reward

            episodes_since_log += 1

        if iteration % update_target_steps == 0:
            # Average over the episodes actually accumulated since the last log
            # line (only one iteration has run when iteration == 0).
            n_logged = max(episodes_since_log, 1)
            print(f"+++++++ Iteration {iteration}: " + ", ".join([f"{a}: {r / n_logged:.2f}" for a, r in total_rewards.items()]), steps / n_logged)
            total_rewards = {agent: 0.0 for agent in total_rewards}
            steps = 0
            episodes_since_log = 0

            # Synchronise the target networks on the same schedule.
            for agent in agent_dict.values():
                agent.update_target_model()

def evaluate_model(agent_dict, eval_env, n_eval_episodes=10):
    """Play evaluation episodes and return reward statistics per agent.

    Exploration is switched off for the duration of the evaluation (each
    agent's ``epsilon`` is set to 0 and restored afterwards) so the result
    reflects the learned policy rather than random actions.

    Args:
        agent_dict: Mapping from agent name to an agent providing ``act``,
            ``reset_hidden_state`` and an ``epsilon`` attribute.
        eval_env: Environment to evaluate in; it is reset at the start of each episode.
        n_eval_episodes: Number of episodes to play.

    Returns:
        dict mapping agent name to ``(mean, std)`` of the per-episode reward sums.
    """
    reward_sums = {agent_name: [] for agent_name in agent_dict}

    saved_epsilon = {name: agent.epsilon for name, agent in agent_dict.items()}
    for agent_obj in agent_dict.values():
        agent_obj.epsilon = 0.0

    try:
        # No gradients are needed while evaluating.
        with torch.no_grad():
            for _ in range(n_eval_episodes):
                env = eval_env  # the same environment object is reused for every episode
                env.reset()
                episode_rewards = {agent_name: 0.0 for agent_name in agent_dict}

                for agent in env.agent_iter():
                    if agent == "dummy":
                        # The dummy agent never acts.
                        env.step(None)
                        continue
                    obs, reward, termination, truncation, info = env.last()
                    done = termination or truncation

                    if done:
                        action = None  # a finished agent must step with None
                    else:
                        action = agent_dict[agent].act(obs)
                        # Start every decision from a fresh GRU hidden state.
                        agent_dict[agent].reset_hidden_state()

                    env.step(action)
                    episode_rewards[agent] += reward

                for agent_name in reward_sums:
                    reward_sums[agent_name].append(episode_rewards[agent_name])
    finally:
        # Restore the exploration rates even if an episode raised.
        for name, agent_obj in agent_dict.items():
            agent_obj.epsilon = saved_epsilon[name]

    # Mean and standard deviation of the episode reward sums.
    mean_std_rewards = {
        agent: (np.mean(rewards), np.std(rewards))
        for agent, rewards in reward_sums.items()
    }

    return mean_std_rewards

def save_dqn(agent_dict, base_path = "models"):
    """Save every agent's model as ``<base_path>/<agent_name>_model.pth``.

    The directory is created first when it does not exist yet.
    """
    os.makedirs(base_path, exist_ok=True)
    for name in agent_dict:
        target = os.path.join(base_path, f"{name}_model.pth")
        agent_dict[name].save_model(target)

def load_dqn(env, agents=("cat", "toy"), base_path="models"):
    """Create one ``DQNAgent`` per name and load its saved weights.

    Args:
        env: Environment used to size each agent's networks.
        agents: Names of the agents to load (any iterable of strings).
        base_path: Directory containing ``<agent_name>_model.pth`` files.

    Returns:
        dict mapping agent name to the loaded ``DQNAgent``.
    """
    # NOTE: the default for `agents` is a tuple so no mutable object is shared
    # between calls. The former warm-up call `agent.act(env.reset())` was dropped:
    # a new agent has epsilon == 1.0, so act() always returned a random action
    # without ever running the network, and it reset the caller's environment.
    agent_dict = {}
    for agent_name in agents:
        filepath = os.path.join(base_path, f"{agent_name}_model.pth")
        agent = DQNAgent(agent_name, env)
        agent.load_model(filepath)
        agent_dict[agent_name] = agent
    return agent_dict

In [None]:
# Create one DQNAgent for every agent name reported by the training environment.
agent_dict = dict(
    (agent_name, DQNAgent(agent_name, env_learning))
    for agent_name in env_learning.agents
)


In [None]:
# Train all agents in the training environment with the schedule defined above.
train_dqn(agent_dict, env_learning, num_iterations, num_episodes_per_iteration)


In [None]:
# Evaluation environment, created with render_mode="human".
env_kwargs=dict(render_mode="human", max_steps=3000)
env_eval = CatToyEnv(**env_kwargs)

# Evaluate the agents. Each entry of the result is a (mean, std) pair, so print
# one line per agent instead of mixing the cat and toy tuples in one message.
mean_std_rewards = evaluate_model(agent_dict, env_eval, n_eval_episodes=1)
for agent_name, (mean_reward, std_reward) in mean_std_rewards.items():
    print(f"{agent_name} mean_reward: {mean_reward} +/- {std_reward}")

In [None]:
# Save every agent's model into the "models" directory.
save_dqn(agent_dict, "models")

In [None]:
"""
# Google Colab用 Artifact保存
%cd /content/cat_brain/cat_dqn
save_dqn(agent_dict, "models")
!git config --global user.email "taka.flemish.giant@gmail.com"
!git config --global user.name "nekoneko02"
!git pull
!git add models/*
!git commit -m "Model保存 from Google Colab"
!git push origin HEAD:google-colab-artifact
"""

In [None]:
# Dummy observation: three (1, 2) tensors concatenated along the feature axis.
toy = torch.randn(1, 2)
cat = torch.randn(1, 2)
dum = torch.randn(1, 2)

# Environment used only to size the networks while loading.
env_kwargs=dict(render_mode="human", max_steps=1000)
env_dummy = CatToyEnv(**env_kwargs)

# Load the saved agents and take the cat agent's network for export.
loaded_model = load_dqn(env_dummy, ["cat", "toy"], "models")
policy_net = loaded_model["cat"].model

# Export in eval mode so training-only layer behaviour is not traced into the
# graph; the previous mode is restored afterwards.
was_training = policy_net.training
policy_net.eval()

# Build the example inputs on the same device as the network and take the hidden
# size from the GRU itself instead of hard-coding it.
export_device = next(policy_net.parameters()).device
concat_input = torch.cat([toy, cat, dum], dim=1).unsqueeze(0).to(export_device)  # (1, 1, obs_dim)
hidden_state = torch.randn(1, 1, policy_net.rnn.hidden_size, device=export_device)  # (num_layers, batch, hidden)

# ONNX export
torch.onnx.export(
    policy_net,
    (concat_input, hidden_state),  # forward(x, hidden_state)
    "cat_dqn_policy.onnx",
    export_params=True,
    opset_version=11,
    input_names=["obs", "hidden_state"],
    output_names=["probabilities", "next_hidden_state"],
    dynamic_axes={
        "obs": {0: "batch_size"},
        # A GRU hidden state is laid out (num_layers, batch, hidden) even with
        # batch_first=True, so its batch axis is 1, not 0.
        "hidden_state": {1: "batch_size"},
        "probabilities": {0: "batch_size"},
        "next_hidden_state": {1: "batch_size"}
    }
)

policy_net.train(was_training)

In [None]:
# Sanity check: run the loaded cat network on one observation from a fresh environment.
env_dummy = CatToyEnv(**env_kwargs)
check_model = loaded_model["cat"].model
# The network indexes its input as [batch, sequence, features], so add both
# leading axes and place the tensor on the network's device.
obs = torch.FloatTensor(env_dummy.reset()).unsqueeze(0).unsqueeze(0)
obs = obs.to(next(check_model.parameters()).device)
print("obs:", obs)
with torch.no_grad():
    print(check_model(obs))

In [None]:
# Close every environment this notebook created. Names are looked up
# dynamically so a cell that was skipped does not make the cleanup fail.
for _env_name in ("env_learning", "env_eval", "env_preview", "env_dummy"):
    _env = globals().get(_env_name)
    if _env is not None:
        _env.close()