In [None]:
# ===============================================================
# Minimal Evo-Neuro Benchmark (Colab-ready, single cell)
# Models: CnidarianNerveNet / SegmentedGanglia / FishBrain / HumanExecutive
# Tasks : HD-Jellyfish (SL, 3-way) + Reversal (RL, 2AFC)
# ===============================================================

import math, random
from dataclasses import dataclass
from typing import Dict, Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from collections import deque

# -----------------------------
# Reproducibility
# -----------------------------
def set_seed(seed=0):
    """Seed every random number generator used by the benchmark.

    Sets ``PYTHONHASHSEED`` (only interpreters launched afterwards see it).
    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available). Switches cuDNN to deterministic, non-benchmarking
    mode.
    """
    import os, numpy as np, torch, random
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# -----------------------------
# Global sensory dimensions (fixed across models)
# -----------------------------
# Per-modality observation widths shared by every model and task in this file.
D_VISION, D_OLFACT, D_SOMATO, D_AUDIT, D_PROP = 128, 32, 60, 64, 16
# Width of all five modalities concatenated (input size of CnidarianNerveNet's projection).
INPUT_DIM = D_VISION + D_OLFACT + D_SOMATO + D_AUDIT + D_PROP

# -----------------------------
# Obs helper (all modalities; allow overwrite)
# -----------------------------
def make_obs(batch=1, device="cpu",
             vision: torch.Tensor=None,
             olfaction: torch.Tensor=None,
             somato: torch.Tensor=None,
             auditory: torch.Tensor=None,
             proprio: torch.Tensor=None) -> Dict[str, torch.Tensor]:
    """Build an observation dict containing all five sensory modalities.

    Each modality starts as standard-normal noise of shape ``[batch, D_*]`` on
    ``device``. Any tensor passed explicitly then replaces that modality as-is,
    with no shape or device check. Noise is drawn for all five modalities in a
    fixed order even when overrides are given, so the global RNG always
    advances the same way.
    """
    widths = (
        ("vision", D_VISION),
        ("olfaction", D_OLFACT),
        ("somatosensory", D_SOMATO),
        ("auditory", D_AUDIT),
        ("proprioception", D_PROP),
    )
    obs = {key: torch.randn(batch, width, device=device) for key, width in widths}
    overrides = {
        "vision": vision,
        "olfaction": olfaction,
        "somatosensory": somato,
        "auditory": auditory,
        "proprioception": proprio,
    }
    # Updating existing keys keeps the dict's original insertion order.
    obs.update({key: tensor for key, tensor in overrides.items() if tensor is not None})
    return obs

# -----------------------------
# Tiny block
# -----------------------------
class MLP(nn.Module):
    """Feed-forward stack of ``depth`` Linear layers with ``act`` in between.

    ``depth`` counts Linear layers. Any value below 2 still produces the minimal
    in -> hidden -> out pair. No activation follows the final Linear.
    """
    def __init__(self, in_dim, hidden, out_dim, depth=2, act=nn.ReLU):
        super().__init__()
        n_middle = max(0, depth - 2)
        stack = [nn.Linear(in_dim, hidden), act()]
        for _ in range(n_middle):
            stack.extend((nn.Linear(hidden, hidden), act()))
        stack.append(nn.Linear(hidden, out_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        return self.net(x)

# ===============================================================
# Models (each forward returns dict(motor=Tensor[B, motor_dim]))
# ===============================================================

# 1) Cnidarian: fixed random features (+small plastic head)
class CnidarianNerveNet(nn.Module):
    """
    Crude approximation of a diffuse (cnidarian) nerve net.

    All modalities are concatenated and passed through a FIXED random
    projection and tanh; only a small linear head is trainable. Capacity is
    deliberately limited so the model is a poor universal approximator.
    """
    def __init__(self, motor_dim=8, feat_dim=48):
        super().__init__()
        # Random projection scaled by 1/sqrt(fan_in). Registered as buffers, so
        # it moves with .to() and is saved in state_dict but never trained.
        with torch.no_grad():
            projection = torch.randn(INPUT_DIM, feat_dim) / math.sqrt(INPUT_DIM)
            offset = torch.zeros(feat_dim)
        self.register_buffer("W", projection)
        self.register_buffer("b", offset)
        self.head = nn.Linear(feat_dim, motor_dim)

    def forward(self, x):
        order = ("vision", "olfaction", "somatosensory", "auditory", "proprioception")
        z = torch.cat([x[key] for key in order], -1)
        features = torch.tanh(z @ self.W + self.b)  # [B, feat_dim]
        return {"motor": torch.tanh(self.head(features))}

# 2) Segmented ganglia: somatoを体節に分割し局所制御+全身座標
class SegmentedGanglia(nn.Module):
    """Segmented ganglia with a shared whole-body coordinator.

    The somatosensory vector is split into ``segments`` equal slices. Each
    slice is joined with a 16-d coordination signal computed from vision,
    olfaction, audition and proprioception, then fed to that segment's own MLP
    controller. Output width is ``segments * motor_per_seg``, squashed by tanh.
    """
    def __init__(self, segments=6, motor_per_seg=2):
        super().__init__()
        assert D_SOMATO % segments == 0
        self.segments = segments
        self.local_dim = D_SOMATO // segments
        coord_in = D_VISION + D_OLFACT + D_AUDIT + D_PROP
        self.coord = MLP(coord_in, 64, 16, depth=2)
        self.controllers = nn.ModuleList([
            MLP(self.local_dim + 16, 64, motor_per_seg, depth=2)
            for _ in range(segments)
        ])

    def forward(self, x):
        somato = x["somatosensory"]
        per_seg = somato.view(somato.size(0), self.segments, self.local_dim)
        global_ctx = self.coord(torch.cat(
            [x["vision"], x["olfaction"], x["auditory"], x["proprioception"]], -1))
        outs = []
        for idx, controller in enumerate(self.controllers):
            outs.append(controller(torch.cat([per_seg[:, idx, :], global_ctx], -1)))
        return {"motor": torch.tanh(torch.cat(outs, -1))}  # [B, segments*motor_per_seg]

# helper modules
class BasalGanglia(nn.Module):
    """Action proposal ``pi(z)`` weighted element-wise by a softmax gate over ``g(z)``."""
    def __init__(self, in_dim, motor_dim):
        super().__init__()
        self.pi = MLP(in_dim, 128, motor_dim, depth=2)
        self.g  = MLP(in_dim, 64,  motor_dim, depth=2)

    def forward(self, z):
        proposal = self.pi(z)
        weights = torch.softmax(self.g(z), dim=-1)
        return proposal * weights

class Cerebellum(nn.Module):
    """Bounded corrective signal computed from (sensory, intended motor).

    Returns ``0.2 * tanh(MLP([sensory, intended]))``, so every element lies in
    [-0.2, 0.2].
    """
    def __init__(self, sensory_dim, motor_dim):
        super().__init__()
        self.fm = MLP(sensory_dim + motor_dim, 128, motor_dim, depth=2)

    def forward(self, sensory, intended):
        joint = torch.cat((sensory, intended), dim=-1)
        return 0.2 * torch.tanh(self.fm(joint))

# 2) Segmented ganglia (修正版): 各体節が somato の局所情報だけで制御
class SegmentedGangliaRestricted(nn.Module):
    """
    Revised segmented ganglia with purely local control.

    - Each segment's controller sees only its own slice of the somatosensory input.
    - Global modalities (vision, olfaction, auditory, proprioception) are ignored.
    - Intended to be closer to biological segmental ganglia. By design it
      should not manage spatial-reasoning tasks such as Detour.
    """
    def __init__(self, segments=6, motor_per_seg=2):
        super().__init__()
        assert D_SOMATO % segments == 0
        self.segments = segments
        self.local_dim = D_SOMATO // segments
        # One small MLP (hidden width 32) per segment.
        self.controllers = nn.ModuleList(
            [MLP(self.local_dim, 32, motor_per_seg, depth=2) for _ in range(segments)]
        )

    def forward(self, x):
        somato = x["somatosensory"]
        slices = somato.view(somato.size(0), self.segments, self.local_dim).unbind(1)
        outs = [controller(part) for controller, part in zip(self.controllers, slices)]
        return {"motor": torch.tanh(torch.cat(outs, -1))}  # [B, segments*motor_per_seg]

import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------------
# ReflexArc: 局所反射弓（somatosensory → motor）
# ---------------------------------------------
class ReflexArc(nn.Module):
    """Local reflex arc: afferent Linear -> ReLU interneuron -> efferent Linear -> tanh.

    Maps a ``sensory_dim`` vector to a ``motor_dim`` vector with values in [-1, 1].
    """

    def __init__(self, sensory_dim, motor_dim, hidden=16):
        super().__init__()
        self.afferent = nn.Linear(sensory_dim, hidden)
        self.interneuron = nn.ReLU()
        self.efferent = nn.Linear(hidden, motor_dim)

    def forward(self, sensory_input):
        relayed = self.interneuron(self.afferent(sensory_input))
        return torch.tanh(self.efferent(relayed))


# ---------------------------------------------
# BrachialGanglion: 腕ガングリオン
# ReflexArc と中央脳入力を学習可能パラメータで統合
# ---------------------------------------------
class BrachialGanglion(nn.Module):
    """Arm ganglion: weighted sum of a local reflex pathway and a central-command pathway.

    ``motor = reflex_weight * ReflexArc(sensory) + central_weight * tanh(Linear(central))``.
    Both weights are learnable scalars initialised to 1.0. The sum is not squashed.
    """
    def __init__(self, sensory_dim, central_dim, motor_dim):
        super().__init__()
        self.reflex_arc = ReflexArc(sensory_dim, motor_dim)
        self.reflex_weight = nn.Parameter(torch.tensor(1.0))    # reflex synaptic strength
        self.central_fc = nn.Linear(central_dim, motor_dim)
        self.central_weight = nn.Parameter(torch.tensor(1.0))   # central synaptic strength

    def forward(self, sensory_input, central_cmd):
        reflex_term = self.reflex_weight * self.reflex_arc(sensory_input)
        central_term = self.central_weight * torch.tanh(self.central_fc(central_cmd))
        return reflex_term + central_term


# ---- 共通ユーティリティ ----
class EIBlock(nn.Module):
    """Minimal cortical column: one discrete step of excitatory/inhibitory interaction.

        e' = relu(LN_e(alpha*e + W_e(e) - U_e(i) + inp))
        i' = relu(LN_i(beta*i + W_i(e)))
    """
    def __init__(self, d):
        super().__init__()
        self.W_e = nn.Linear(d, d, bias=False)  # E <- E
        self.W_i = nn.Linear(d, d, bias=False)  # I <- E
        self.U_e = nn.Linear(d, d, bias=False)  # E <- I (inhibition)
        self.alpha = nn.Parameter(torch.tensor(0.7))  # persistence of E
        self.beta  = nn.Parameter(torch.tensor(0.5))  # persistence of I
        self.ln_e = nn.LayerNorm(d)
        self.ln_i = nn.LayerNorm(d)

    def forward(self, e, i, inp=None):
        # inp: optional external drive added to E (e.g. a projected sensory input).
        drive = 0.0 if inp is None else inp
        e_pre = self.alpha * e + self.W_e(e) - self.U_e(i) + drive
        i_pre = self.beta * i + self.W_i(e)
        return torch.relu(self.ln_e(e_pre)), torch.relu(self.ln_i(i_pre))

class ThalamicRelay(nn.Module):
    """Thalamic relay: bias-free linear re-projection of cortical E activity,
    scaled by a learnable gain (initialised to 0.8)."""
    def __init__(self, d):
        super().__init__()
        self.relay = nn.Linear(d, d, bias=False)
        self.gain  = nn.Parameter(torch.tensor(0.8))

    def forward(self, cortical_e):
        relayed = self.relay(cortical_e)
        return self.gain * relayed

class BasalGangliaGate(nn.Module):
    """Context-dependent gate on thalamic output.

    A Linear-Tanh-Linear-Sigmoid network maps ``ctx`` to a gain map in [0, 1]
    of width ``d``. forward returns ``(gain * thalamus_out, gain)``.
    """
    def __init__(self, d_ctx, d):
        super().__init__()
        self.policy = nn.Sequential(
            nn.Linear(d_ctx, d),
            nn.Tanh(),
            nn.Linear(d, d),
            nn.Sigmoid(),  # gain in [0, 1]
        )

    def forward(self, ctx, thalamus_out):
        gain = self.policy(ctx)  # [B, d]
        gated = gain * thalamus_out
        return gated, gain


class CephalopodBrainV3(nn.Module):
    """
    Cephalopod-like model with a re-entrant optic <-> vertical lobe loop.

    - optic lobe: projects vision and is re-driven by the vertical-lobe state
      on every loop pass after the first.
    - vertical lobe: associates optic output with somatosensory input. It uses
      self-recurrence, a subtractive inhibition term, a multiplicative
      "Hebbian-style" activity product (an activation term, not a weight
      update) and LayerNorm.
    - peduncle lobe: Linear + tanh on the final vertical-lobe state, used as
      the central command (the author labels it cerebellum-like).
    - brachial ganglia: one per arm. Each combines the full somatosensory
      vector (reflex pathway) with the shared central command.

    Output motor width is ``arms * motor_dim``.
    """
    def __init__(self, arms=8, central_dim=64, motor_dim=4, re_loops=3):
        super().__init__()
        self.arms = arms
        self.central_dim = central_dim
        self.motor_dim = motor_dim
        self.re_loops = re_loops
        self.sensory_dim = D_SOMATO
        self.visual_dim = D_VISION

        # --- optic lobe ---
        self.optic_in = nn.Linear(self.visual_dim, central_dim)
        self.optic_rec = nn.Linear(central_dim, central_dim, bias=False)  # re-entrant input
        self.optic_act = nn.Tanh()

        # --- vertical lobe ---
        self.vert_in = nn.Linear(central_dim + self.sensory_dim, central_dim)
        self.vert_self = nn.Linear(central_dim, central_dim, bias=False)   # recurrence
        self.vert_inhib = nn.Linear(central_dim, central_dim, bias=False)  # inhibition term
        self.beta = nn.Parameter(torch.tensor(0.3))   # strength of the activity-product term
        self.gamma = nn.Parameter(torch.tensor(0.5))  # inhibition strength
        self.vert_norm = nn.LayerNorm(central_dim)

        # --- peduncle lobe (cerebellum-like) ---
        self.peduncle_lobe = nn.Sequential(
            nn.Linear(central_dim, central_dim),
            nn.Tanh(),
        )

        # --- brachial ganglia ---
        self.brachial_ganglia = nn.ModuleList([
            BrachialGanglion(self.sensory_dim, central_dim, motor_dim)
            for _ in range(self.arms)
        ])

        self.out_act = nn.Tanh()

    def forward(self, x):
        """
        x: dict with at least "vision" [B, D_VISION] and "somatosensory" [B, D_SOMATO].
        return: {"motor": [B, arms*motor_dim] in [-1, 1],
                 "central_command": [B, central_dim]}
        """
        somato = x["somatosensory"]
        # The feed-forward optic drive is the same on every loop pass; compute it once.
        optic_drive = self.optic_in(x["vision"])
        optic = torch.tanh(optic_drive)
        vert = torch.zeros_like(optic)

        # --- re-entrant loop (optic <-> vertical) ---
        for step in range(self.re_loops):
            if step > 0:
                # Re-entry: the previous vertical-lobe state feeds back into the
                # optic lobe. This is skipped after the last pass, where nothing
                # would read the updated optic state.
                optic = self.optic_act(optic_drive + self.optic_rec(vert))
            # Vertical-lobe update: vision + somatosensory + recurrence - inhibition.
            h = torch.tanh(self.vert_in(torch.cat([optic, somato], dim=-1)))
            inhib = self.vert_inhib(vert)
            vert = torch.tanh(self.vert_self(vert) + h - self.gamma * inhib)
            # Multiplicative co-activation term (emphasises units active in both).
            vert = vert + self.beta * (vert * h)
            vert = self.vert_norm(vert)

        # Peduncle lobe -> central command shared by all arms.
        central_cmd = self.peduncle_lobe(vert)

        # Per-arm integration of the reflex and central pathways.
        motors = [ganglion(somato, central_cmd) for ganglion in self.brachial_ganglia]
        motor = self.out_act(torch.cat(motors, dim=-1))
        return {"motor": motor, "central_command": central_cmd}


# ===============================================================
# Fish (統一フロー版): optic tectum → ventral telencephalon (BG-like)
#                      → cerebellum → spinal cord (integration) → motor
#   ＋ 反射系: somatosensory/proprioception → spinal interneurons → motor
#   両経路は脊髄で学習的に調停（ゲーティング）される
# ===============================================================
class SpinalMixer(nn.Module):
    """
    Abstract spinal arbitration between reflex and volitional drives.

    A per-channel gate in [0, 1]^motor_dim is predicted from ``ctx`` alone
    (callers visible in this file pass somatosensory + proprioception). Then
        motor = tanh(gate * reflex + (1 - gate) * volitional)
    The reflex and volitional drives do not influence the gate.
    forward returns ``(motor, gate)``.
    """
    def __init__(self, motor_dim: int, ctx_dim: int):
        super().__init__()
        hidden = 128
        self.ctx_proj = MLP(ctx_dim, hidden, hidden, depth=2)
        self.g_head   = nn.Linear(hidden, motor_dim)
        # Gate bias starts at zero. A negative bias would make the mixer favour
        # the volitional drive (smaller gate) at initialisation.
        nn.init.constant_(self.g_head.bias, 0.0)

    def forward(self, reflex_drive, volitional_drive, ctx):
        gate = torch.sigmoid(self.g_head(self.ctx_proj(ctx)))  # [B, motor_dim]
        blended = gate * reflex_drive + (1.0 - gate) * volitional_drive
        return torch.tanh(blended), gate


class FishBrainV3(nn.Module):
    """
    Anatomical mapping:
      - optic_tectum: visual integration, re-entered by the ventral-telencephalon drive
      - ventral_telencephalon (BG-like): action selection from tectum + olfaction
        + proprioception
      - cerebellum-like: bounded correction added to the ventral-telencephalon drive
      - spinal_interneurons: reflex drive from somatosensory + proprioception
      - spinal_mixer: learned arbitration between reflex and volitional drives
    Revisions:
      - re-entrant loop between optic tectum and ventral telencephalon
      - ventral-telencephalon output projected back into the tectum (context bias)
      - reflex pathway unchanged
    """
    def __init__(self, motor_dim=12, loops=2, hidden_dim=64):
        super().__init__()
        self.loops = loops
        self.hidden_dim = hidden_dim

        # --- optic tectum ---
        self.optic_in = nn.Linear(D_VISION, hidden_dim)
        self.optic_rec = nn.Linear(motor_dim, hidden_dim, bias=False)  # re-entrant projection (BG drive)
        self.optic_act = nn.Tanh()

        # --- ventral telencephalon (BG-like) ---
        self.bg_in = nn.Linear(hidden_dim + D_OLFACT + D_PROP, motor_dim)
        # NOTE(review): bg_gate is not used by forward(). It is kept so that the
        # parameter-initialisation order and state_dict keys stay unchanged.
        self.bg_gate = nn.Linear(motor_dim, motor_dim, bias=False)
        self.bg_act = nn.Tanh()

        # --- cerebellum-like correction ---
        self.cerebellum = Cerebellum(D_SOMATO + D_PROP, motor_dim)

        # --- reflex pathway ---
        self.spinal_interneurons = MLP(D_SOMATO + D_PROP, 128, motor_dim, depth=2)

        # --- spinal integration ---
        ctx_dim = D_SOMATO + D_PROP
        self.spinal_mixer = SpinalMixer(motor_dim=motor_dim, ctx_dim=ctx_dim)

    def forward(self, x):
        """
        x: dict with "vision", "olfaction", "somatosensory", "proprioception" ([B, D_*]).
        return: {"motor": [B, motor_dim] in [-1, 1], "gate": [B, motor_dim]} where
                gate is the SpinalMixer weight on the reflex drive.

        ``self.loops`` tectum -> BG passes are run. Values below 1 behave like 1;
        previously they raised UnboundLocalError.
        """
        vision_drive = self.optic_in(x["vision"])  # loop-invariant
        optic = self.optic_act(vision_drive)
        bg_ctx = torch.cat([x["olfaction"], x["proprioception"]], dim=-1)

        def _select(tectum):
            # Ventral-telencephalon drive from tectum state + olfaction + proprioception.
            return self.bg_act(self.bg_in(torch.cat([tectum, bg_ctx], dim=-1)))

        # --- optic <-> BG re-entrant loop ---
        # The tectum update after the final BG pass was never read, so it is skipped.
        bg_drive = _select(optic)
        for _ in range(self.loops - 1):
            optic = self.optic_act(vision_drive + self.optic_rec(bg_drive))
            bg_drive = _select(optic)

        # Somatosensory + proprioception context shared by cerebellum, reflex and mixer.
        body = torch.cat([x["somatosensory"], x["proprioception"]], -1)

        # --- cerebellar pathway: volitional = BG drive + bounded correction ---
        volitional = bg_drive + self.cerebellum(body, bg_drive)

        # --- reflex pathway ---
        reflex = self.spinal_interneurons(body)

        # --- spinal integration ---
        motor, gate = self.spinal_mixer(reflex, volitional, body)

        return {"motor": motor, "gate": gate}


class HumanCortexV4(nn.Module):
    """
    Anatomical mapping:
      - sensory cortices: per-modality MLP embeddings of vision, somatosensory,
        proprioception and audition (olfaction is not used)
      - PFC: recurrent working representation with thalamic re-entry, gated by
        a BG-like gate computed from somatosensory + proprioception
      - hippocampus-like stage: one TransformerEncoder layer + LayerNorm
      - cerebellum: bounded correction of the descending command
      - spinal reflex + SpinalMixer: arbitration between reflex and volitional drives
    Revisions:
      - PFC <-> thalamus re-entrant loop
      - thalamic output re-projected into PFC through the BG gate
      - the descending command (desc_gain) forms the volitional drive and the
        cerebellum adds a correction to it, as in FishBrainV3. Previously only
        the cerebellar correction (0.2 * tanh(...)) was used, which capped the
        volitional drive at +/-0.2.
    """
    def __init__(self, motor_dim=20, d_emb=64, wm_dim=128, loops=2):
        super().__init__()
        self.motor_dim = motor_dim
        self.wm_dim = wm_dim
        self.d_emb = d_emb
        self.loops = loops

        # Sensory embeddings
        self.tv = MLP(D_VISION, 128, d_emb, depth=2)
        self.ts = MLP(D_SOMATO, 128, d_emb, depth=2)
        self.tp = MLP(D_PROP,   64,  d_emb, depth=2)
        self.ta = MLP(D_AUDIT,  64,  d_emb, depth=2)

        # PFC (target of re-entry)
        self.pfc_in = nn.Linear(d_emb * 4, wm_dim)
        self.pfc_rec = nn.Linear(wm_dim, wm_dim, bias=False)
        self.pfc_norm = nn.LayerNorm(wm_dim)

        # Thalamus + BG gate
        self.thalamus = nn.Linear(wm_dim, wm_dim, bias=False)
        self.bg_gate = nn.Sequential(
            nn.Linear(D_SOMATO + D_PROP, wm_dim),
            nn.Tanh(),
            nn.Linear(wm_dim, wm_dim),
            nn.Sigmoid()
        )

        # Hippocampus-like integration
        enc_layer = nn.TransformerEncoderLayer(d_model=wm_dim, nhead=4, dim_feedforward=256)
        self.hippocampal = nn.TransformerEncoder(enc_layer, num_layers=1)
        self.h_norm = nn.LayerNorm(wm_dim)

        # Cerebellum-like correction
        self.cerebellum = Cerebellum(D_SOMATO + D_PROP, motor_dim)

        # Reflex system
        self.spinal_reflex = MLP(D_SOMATO + D_PROP, 128, motor_dim, depth=2)
        self.spinal_mixer = SpinalMixer(motor_dim, D_SOMATO + D_PROP)

        # Descending command
        self.desc_gain = nn.Sequential(nn.Linear(wm_dim, motor_dim), nn.Tanh())

    def forward(self, x):
        """
        x: dict with "vision", "somatosensory", "proprioception", "auditory" ([B, D_*]).
        return: {"motor": [B, motor_dim], "gate": [B, motor_dim], "wm": [B, wm_dim]}
        """
        # Sensory integration
        s = torch.cat([
            self.tv(x["vision"]),
            self.ts(x["somatosensory"]),
            self.tp(x["proprioception"]),
            self.ta(x["auditory"])
        ], dim=-1)
        # Somatosensory + proprioception context shared by BG gate, cerebellum, reflex, mixer.
        body = torch.cat([x["somatosensory"], x["proprioception"]], -1)

        # PFC initialisation. The sensory drive and the BG gate are loop-invariant.
        sensory_drive = self.pfc_in(s)
        pfc = torch.tanh(sensory_drive)
        g = self.bg_gate(body)

        # --- re-entrant loop (PFC <-> thalamus, gated by BG) ---
        for _ in range(self.loops):
            th = self.thalamus(pfc)
            pfc = torch.tanh(sensory_drive + self.pfc_rec(th * g))
            pfc = self.pfc_norm(pfc)

        # Hippocampus-like integration.
        # NOTE(review): with the default batch_first=False, [1, B, wm] is read as
        # sequence length 1, so attention only sees each sample's own token.
        mem = self.hippocampal(pfc.unsqueeze(0)).squeeze(0)
        mem = self.h_norm(mem)

        # Descending command plus cerebellar predictive correction.
        descending = self.desc_gain(mem)
        volitional = descending + self.cerebellum(body, descending)

        # Reflex route
        reflex = self.spinal_reflex(body)

        # Spinal integration
        motor, gate = self.spinal_mixer(reflex, volitional, body)

        return {"motor": motor, "gate": gate, "wm": mem}


# ===============================================================
# Task: Peristalsis (蠕動波出力) — 回帰（MSE）
# 目標：各体節のモータ出力が位相差φのある正弦波を形成
# 文献背景：体節CPG/体節間結合（例：Friesen & Pearce 2007, leech locomotor circuits）
# ===============================================================
class PeristalsisDataset(torch.utils.data.Dataset):
    """
    Regression task: a peristaltic (travelling) wave over the motor channels.

        target[i] = A * sin(omega * t + phi * i),  i = 0..motor_dim-1,  t ~ U(0, 2*pi)

    Every modality is N(0, noise^2) noise, except proprioception[:3], which is
    overwritten with (t, omega, phi). omega and phi are identical for all
    samples, so t is the only informative input. All tensors live on ``device``.
    """
    def __init__(self, motor_dim, n_samples=6000, A=1.0, omega=0.4, phi=0.8, noise=0.05, device="cpu"):
        self.device = device
        self.motor_dim = motor_dim
        self.A, self.omega, self.phi = A, omega, phi
        self.noise = noise
        self.T = []  # NOTE(review): never populated by this class
        self.targets = []
        # (column name, width) in the exact order the noise vectors are drawn.
        layout = (("V", D_VISION), ("O", D_OLFACT), ("Au", D_AUDIT), ("P", D_PROP), ("S", D_SOMATO))
        columns = {name: [] for name, _ in layout}
        for _ in range(n_samples):
            t = random.uniform(0, 2*math.pi)
            self.targets.append(torch.tensor(
                [A*math.sin(omega*t + phi*i) for i in range(motor_dim)],
                device=device, dtype=torch.float32,
            ))
            draws = {name: torch.randn(width, device=device) * self.noise for name, width in layout}
            # Write the phase variable (and the constant wave parameters) into proprioception.
            draws["P"][:3] = torch.tensor([t, omega, phi], device=device)
            for name, vec in draws.items():
                columns[name].append(vec)
        self.V = torch.stack(columns["V"])
        self.O = torch.stack(columns["O"])
        self.Au = torch.stack(columns["Au"])
        self.P = torch.stack(columns["P"])
        self.S = torch.stack(columns["S"])
        self.Y = torch.stack(self.targets)

    def __len__(self):
        return self.V.size(0)

    def __getitem__(self, idx):
        return {
            "vision": self.V[idx],
            "olfaction": self.O[idx],
            "auditory": self.Au[idx],
            "proprioception": self.P[idx],
            "somatosensory": self.S[idx],
        }, self.Y[idx]

@torch.no_grad()
def phase_corr(y_pred, y_true):
    """Mean cosine similarity between prediction and target along the last dim.

    This is uncentred (cosine similarity), not a Pearson correlation. Each row
    scores in [-1, 1], and all leading dims are averaged into a Python float.
    A 1e-9 term in the denominator guards against division by zero.
    """
    dot = (y_pred * y_true).sum(-1)
    norm_pred = y_pred.pow(2).sum(-1).sqrt()
    norm_true = y_true.pow(2).sum(-1).sqrt()
    cosine = dot / (norm_pred * norm_true + 1e-9)
    return cosine.mean().item()

def train_peristalsis(base_model: nn.Module, device="cpu", epochs=5, batch_size=128, lr=1e-3):
    """Train ``base_model`` on PeristalsisDataset with MSE loss, then evaluate it.

    The motor width is inferred from one dummy forward pass. A fresh
    6000-sample training set and 800-sample evaluation set are generated on
    every call.

    Args:
        base_model: model whose forward returns {"motor": [B, motor_dim], ...}.
        device: device for the generated data.
        epochs, batch_size, lr: Adam training hyper-parameters.

    Returns:
        (mse_eval, corr_eval, curve):
        - mse_eval: mean per-element evaluation MSE.
        - corr_eval: mean evaluation phase_corr (cosine similarity).
        - curve: list of (samples_seen, batch phase_corr), one per optimiser step.

    The model is put in train mode for optimisation and eval mode for
    evaluation, then restored to the mode it was in on entry.
    """
    was_training = base_model.training

    # Probe the motor width; no autograd graph is needed for this.
    with torch.no_grad():
        dummy = make_obs(batch=2, device=device)
        motor_dim = base_model(dummy)["motor"].size(-1)
    ds_tr = PeristalsisDataset(motor_dim, n_samples=6000, device=device)
    ds_ev = PeristalsisDataset(motor_dim, n_samples=800, device=device)
    dl = torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)

    opt = torch.optim.Adam(base_model.parameters(), lr=lr)
    curve = []  # (samples seen so far, batch phase_corr used as an "accuracy" proxy)
    trial_counter = 0

    # Set train mode explicitly. A model left in eval mode by an earlier call
    # would otherwise train with dropout disabled (e.g. HumanCortexV4's
    # TransformerEncoder layer).
    base_model.train()
    for ep in range(1, epochs+1):
        last_loss = float("nan")  # stays NaN if the loader yields no batch
        for obs, target in dl:
            trial_counter += len(target)
            obs = {k: v.to(device) for k, v in obs.items()}
            target = target.to(device)
            pred = base_model(obs)["motor"]
            loss = F.mse_loss(pred, target)
            opt.zero_grad(); loss.backward(); opt.step()
            last_loss = loss.detach()
            # phase_corr is a cosine similarity in [-1, 1].
            curve.append((trial_counter, phase_corr(pred, target)))
        print(f"[Peristalsis] epoch={ep} loss={float(last_loss):.4f}")

    # Batched evaluation. Summing squared errors and per-row scores, then
    # dividing by the element / row counts, gives the same averages as a
    # per-sample loop.
    base_model.eval()
    dl_ev = torch.utils.data.DataLoader(ds_ev, batch_size=batch_size, shuffle=False)
    sq_err_sum, corr_sum, n_rows, n_elems = 0.0, 0.0, 0, 0
    with torch.no_grad():
        for obs, target in dl_ev:
            obs = {k: v.to(device) for k, v in obs.items()}
            target = target.to(device)
            pred = base_model(obs)["motor"]
            sq_err_sum += F.mse_loss(pred, target, reduction="sum").item()
            # phase_corr averages over rows; re-weight by batch size to get a per-row sum.
            corr_sum += phase_corr(pred, target) * target.size(0)
            n_rows += target.size(0)
            n_elems += target.numel()
    mse_eval = sq_err_sum / max(1, n_elems)
    corr_eval = corr_sum / max(1, n_rows)
    print(f"[Peristalsis] eval MSE={mse_eval:.4f}, phaseCorr={corr_eval:.3f}")

    base_model.train(was_training)
    return mse_eval, corr_eval, curve

# ===============================================================
# Task: Local Reflex（体節接触反射）— 3クラス分類
# 文献背景：Bässler (1986) 他、stick insect などの局所反射制御
# ラベル: 0=左回避, 1=直進維持, 2=右回避（簡略化）
# ===============================================================
class LocalReflexDataset(torch.utils.data.Dataset):
    """
    Local contact-reflex task (3-way classification).

    Each of ``segments`` body segments is touched independently with
    probability ``contact_p``. A touch sets the first somatosensory dimension
    of that segment's block to 1.0; all other dims are N(0, noise^2).

    Labels (the first half of the segments is the "left" group):
      - more left-group touches   -> 2 (avoid to the right)
      - more right-group touches  -> 0 (avoid to the left)
      - tie, including no touches -> 1 (keep straight)

    Non-somatosensory modalities are pure noise.
    """
    def __init__(self, n_samples=6000, segments=6, contact_p=0.4, noise=0.05, device="cpu"):
        self.device = device
        self.segments = segments
        assert D_SOMATO % segments == 0
        self.local_dim = D_SOMATO // segments
        half = segments // 2
        columns = {"V": [], "O": [], "Au": [], "P": [], "S": []}
        labels = []
        for _ in range(n_samples):
            touched = [random.random() < contact_p for _ in range(segments)]
            n_left, n_right = sum(touched[:half]), sum(touched[half:])
            if n_left > n_right:
                label = 2      # contact biased to the left group -> avoid to the right
            elif n_right > n_left:
                label = 0      # contact biased to the right group -> avoid to the left
            else:
                label = 1      # balanced or no contact -> keep straight

            somato = torch.randn(D_SOMATO, device=device) * noise
            for seg, hit in enumerate(touched):
                if hit:
                    somato[seg * self.local_dim] = 1.0  # contact flag

            columns["V"].append(torch.randn(D_VISION, device=device) * noise)  # unused by the task
            columns["O"].append(torch.randn(D_OLFACT, device=device) * noise)
            columns["Au"].append(torch.randn(D_AUDIT, device=device) * noise)
            columns["P"].append(torch.randn(D_PROP, device=device) * noise)
            columns["S"].append(somato)
            labels.append(label)
        self.V = torch.stack(columns["V"])
        self.O = torch.stack(columns["O"])
        self.Au = torch.stack(columns["Au"])
        self.P = torch.stack(columns["P"])
        self.S = torch.stack(columns["S"])
        self.Y = torch.tensor(labels, dtype=torch.long, device=device)

    def __len__(self):
        return self.V.size(0)

    def __getitem__(self, idx):
        return {
            "vision": self.V[idx],
            "olfaction": self.O[idx],
            "auditory": self.Au[idx],
            "proprioception": self.P[idx],
            "somatosensory": self.S[idx],
        }, self.Y[idx]




# ===============================================================
# Adapter: auto head to n_actions; optional tiny memory (off by default)
# ===============================================================
class ModelAdapter(nn.Module):
    """Wrap a base model and map its ``motor`` output to ``n_actions`` logits.

    The output head is created lazily on the first forward call, once the
    motor width is known, on the motor tensor's device. With
    ``use_memory=True``, a GRUCell of width ``mem_dim`` is placed in front of
    the head. These parameters do not exist before the first forward, so an
    optimizer built earlier will not include them.

    forward returns ``(logits, base_output_dict)``.
    """
    def __init__(self, base: nn.Module, n_actions: int, use_memory: bool=False, mem_dim: int=64):
        super().__init__()
        self.base = base
        self.n_actions = n_actions
        self.use_memory = use_memory
        self.mem_dim = mem_dim
        self._head = None
        self._mem = None
        self._h = None  # recurrent hidden state carried between forward calls

    def _lazy_init(self, motor_dim, device):
        head_in = motor_dim
        if self.use_memory:
            self._mem = nn.GRUCell(motor_dim, self.mem_dim).to(device)
            head_in = self.mem_dim
        self._head = nn.Linear(head_in, self.n_actions).to(device)

    def reset_memory(self, B=1, device="cpu"):
        """Zero the hidden state. No-op unless memory is enabled and the GRU exists."""
        if not (self.use_memory and self._mem is not None):
            return
        self._h = torch.zeros(B, self.mem_dim, device=device)

    def forward(self, obs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        out = self.base(obs)
        motor = out["motor"]
        if self._head is None:
            self._lazy_init(motor.size(-1), motor.device)
        if not self.use_memory:
            return self._head(motor), out
        batch = motor.size(0)
        if self._h is None or self._h.size(0) != batch:
            self.reset_memory(B=batch, device=motor.device)
        # NOTE(review): self._h is not detached, so it keeps autograd history
        # across calls. Callers doing a backward per step probably need to
        # reset or detach it between steps — confirm against the training loop.
        self._h = self._mem(motor, self._h)
        return self._head(self._h), out

# ===============================================================
# Task A: HD-Jellyfish (high-dim cue; supervised 3-way)
# ===============================================================
# Widths of the low-dimensional vision / somatosensory cue bases that
# FrozenObsMixer lifts to D_VISION / D_SOMATO.
V_BASE, S_BASE = 8, 16  # biologically plausible low-dim bases

class FrozenObsMixer(nn.Module):
    """Fixed random linear lift from low-dim cue bases to full sensory widths.

    Maps ``[B, V_BASE]`` -> ``[B, D_VISION]`` and ``[B, S_BASE]`` -> ``[B, D_SOMATO]``.
    The matrices are buffers created on the CPU; they are never trained.
    """
    def __init__(self):
        super().__init__()
        # Random lifts scaled by 1/sqrt(fan_in).
        with torch.no_grad():
            lift_vision = torch.randn(V_BASE, D_VISION) / math.sqrt(V_BASE)
            lift_somato = torch.randn(S_BASE, D_SOMATO) / math.sqrt(S_BASE)
        self.register_buffer("Wv", lift_vision)
        self.register_buffer("Ws", lift_somato)

    @torch.no_grad()
    def forward(self, v_base, s_base):
        vision_full = v_base @ self.Wv
        somato_full = s_base @ self.Ws
        return vision_full, somato_full

class HDJellyfishDataset(torch.utils.data.Dataset):
    """
    label: 0 = avoid to the left, 1 = go straight, 2 = avoid to the right

    Cue generation:
    - With probability ``p_none`` there is no cue (label 1).
    - Otherwise a cue of amplitude U(amp) appears with equal probability
      either on the left (first half of the bases; label 2) or on the right
      (second half; label 0).
    - The cue is placed on the V_BASE vision bases and, at 0.3x strength, on
      the S_BASE somatosensory bases.

    Lifting and noise:
    - The low-dim bases are lifted to full width by ``mixer``, then
      N(0, noise^2) noise is added.
    - Other modalities are pure noise, the same for all models.

    Mixer handling:
    - If ``mixer`` is None, a NEW random FrozenObsMixer is created. Datasets
      that must share a feature space (e.g. train/eval) should be given the
      same mixer explicitly.
    - An nn.Module mixer is moved (in place) to ``device`` so its buffers
      match the generated cues. FrozenObsMixer builds its buffers on the CPU,
      so a non-CPU ``device`` used to fail with a device mismatch.
    """
    def __init__(self, n_samples=8000, p_none=0.3, amp=(0.6,1.3), noise=0.1, device="cpu", mixer=None):
        self.device = device
        self.noise = noise
        # Use `is None`, not `or`: Module truthiness is not a reliable
        # "was a mixer supplied?" test (e.g. an empty nn.Sequential is falsy).
        if mixer is None:
            mixer = FrozenObsMixer()
        if isinstance(mixer, nn.Module):
            mixer = mixer.to(device)
        self.mixer = mixer
        half = V_BASE // 2
        halfs = S_BASE // 2
        Xv, Xs, Xo, Xa, Xp, Y = [], [], [], [], [], []
        for _ in range(n_samples):
            left, right = 0.0, 0.0
            if random.random() < p_none:
                y = 1
            elif random.random() < 0.5:
                left = random.uniform(*amp)
                y = 2   # cue on the left -> avoid to the right
            else:
                right = random.uniform(*amp)
                y = 0   # cue on the right -> avoid to the left
            v = torch.randn(V_BASE, device=device) * (noise*0.3)
            if left > 0:  v[:half] += left
            if right > 0: v[half:] += right
            s = torch.randn(S_BASE, device=device) * (noise*0.3)
            if left > 0:  s[:halfs] += 0.3*left
            if right > 0: s[halfs:] += 0.3*right
            V_full, S_full = self.mixer(v.unsqueeze(0), s.unsqueeze(0))
            V_full = V_full.squeeze(0) + noise*torch.randn(D_VISION, device=device)
            S_full = S_full.squeeze(0) + noise*torch.randn(D_SOMATO, device=device)
            O = noise*torch.randn(D_OLFACT, device=device)
            A = noise*torch.randn(D_AUDIT,  device=device)
            P = noise*torch.randn(D_PROP,   device=device)
            Xv.append(V_full); Xs.append(S_full); Xo.append(O); Xa.append(A); Xp.append(P); Y.append(y)
        self.V = torch.stack(Xv); self.S = torch.stack(Xs)
        self.O = torch.stack(Xo); self.A = torch.stack(Xa); self.P = torch.stack(Xp)
        self.Y = torch.tensor(Y, dtype=torch.long, device=device)

    def __len__(self): return self.V.size(0)

    def __getitem__(self, idx):
        obs = {
            "vision": self.V[idx],
            "olfaction": self.O[idx],
            "somatosensory": self.S[idx],
            "auditory": self.A[idx],
            "proprioception": self.P[idx],
        }
        return obs, self.Y[idx]

# ===============================================================
# Task B: Reversal Learning (2AFC RL)
# ===============================================================
class ReversalEnv:
    """Two-alternative reversal-learning environment.

    Two fixed random class prototypes ``mu[0]`` and ``mu[1]`` (length ``dim``)
    are drawn at construction. Each ``step`` picks a class uniformly and emits
    a noisy prototype as the vision input; the other modalities are make_obs
    noise. It returns ``(obs, correct)``, where ``correct`` is the class itself
    for the first ``rev_at`` steps and the opposite class afterwards.
    """
    def __init__(self, dim=D_VISION, noise=0.8, rev_at=1500, device="cpu"):
        self.dim = dim
        self.noise = noise
        self.rev_at = rev_at
        self.device = device
        prototypes = [torch.randn(dim) for _ in range(2)]
        self.mu = torch.stack(prototypes, 0).to(device)
        self.t = 0

    def step(self):
        cls = random.randint(0, 1)
        sample = self.mu[cls] + self.noise * torch.randn(self.dim, device=self.device)
        obs = make_obs(vision=sample.unsqueeze(0), device=self.device)
        reversed_phase = self.t >= self.rev_at
        correct = 1 - cls if reversed_phase else cls
        self.t += 1
        return obs, correct

# ===============================================================
# Task: Detour (魚類空間認知タスクの再現)
# ===============================================================
# ===============================================================
# Hard Detour 課題（Cnidarianでは解けない設計）
# ===============================================================
class HardDetourDataset(torch.utils.data.Dataset):
    """Hard detour task: choose left / straight / right from goal and obstacle positions.

    The goal sits on the far row (``gy = field_size``) at a random column ``gx``;
    an obstacle is placed at a random ``(ox, oy)`` with ``1 <= oy <= field_size``.
    Label (0 = left, 1 = straight, 2 = right):
      * obstacle within one column of the goal and in front of it
        (``|ox - gx| < 2`` and ``oy < gy``): detour right if ``ox <= 0``, else left;
      * otherwise head for the goal: left if ``gx < -1``, right if ``gx > 1``,
        straight otherwise.
    The rule is nonlinear in the positions, so a purely linear model (e.g. the
    Cnidarian net) is expected to fail.

    Vision encoding (before additive Gaussian noise), with ``W = 2*field_size + 1``:
      * ``+1`` at ``gx + field_size``      (goal column,     slots [0, W))
      * ``-1`` at ``W + ox + field_size``  (obstacle column, slots [W, 2W))
      * ``-1`` at ``2*W + oy - 1``         (obstacle row,    slots [2W, 2W + field_size))
    Indices wrap modulo ``D_VISION`` if ``field_size`` is too large to fit.

    The earlier single ``ox + oy`` obstacle slot mapped different obstacles to
    the same index (e.g. (0, 2) and (1, 1) with gx = 0, which need labels 2 and
    0). It could also overwrite the goal slot, so labels were not a function of
    the input.
    """
    def __init__(self, n_samples=5000, field_size=5, noise=0.1, device="cpu"):
        self.device = device
        self.noise = noise
        self.samples = []
        width = 2 * field_size + 1  # number of possible x positions
        for _ in range(n_samples):
            # Goal on the far row.
            gx, gy = random.randint(-field_size, field_size), field_size
            # Obstacle somewhere between the agent and the far row.
            ox, oy = random.randint(-field_size, field_size), random.randint(1, field_size)

            v = torch.zeros(D_VISION, device=device)
            # Disjoint one-hot slots so goal and obstacle never alias each other.
            v[(gx + field_size) % D_VISION] = 1.0
            v[(width + ox + field_size) % D_VISION] = -1.0
            v[(2 * width + oy - 1) % D_VISION] = -1.0

            if abs(ox - gx) < 2 and oy < gy:
                # Obstacle blocks the way to the goal -> detour around it.
                y = 2 if ox <= 0 else 0  # obstacle on the left -> go right, else go left
            elif gx < -1:
                y = 0
            elif gx > 1:
                y = 2
            else:
                y = 1

            v += noise * torch.randn_like(v)

            obs = {
                "vision": v,
                "olfaction": torch.randn(D_OLFACT, device=device) * noise,
                "auditory": torch.randn(D_AUDIT, device=device) * noise,
                "proprioception": torch.randn(D_PROP, device=device) * noise,
                "somatosensory": torch.randn(D_SOMATO, device=device) * noise,
            }
            self.samples.append((obs, y))

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

# ===============================================================
# Jellyfish (HD-Jellyfish 回避タスク)
# ===============================================================
from collections import deque

def train_hd_jellyfish(adapter: ModelAdapter, device="cpu", epochs=3, batch_size=128,
                       p_none=0.3, noise=0.1):
    """Train ``adapter`` on the HD-Jellyfish 3-way avoidance task, then evaluate it.

    Args:
        adapter: model wrapper returning ``(logits, _)`` for an obs dict.
        device: device for data and model inputs.
        epochs: passes over the 8000-sample training set.
        batch_size: training mini-batch size (incomplete last batch dropped).
        p_none, noise: forwarded to ``HDJellyfishDataset``.

    Returns:
        ``(acc, curve)``: accuracy on 1000 held-out samples, and a learning curve
        of ``(cumulative_trials, batch_accuracy)`` pairs. This is the same format
        the other supervised tasks produce and that ``compute_tal_metrics``
        consumes. Previously the curve held a moving-average *loss*, which the
        benchmark then scored against an accuracy criterion.
    """
    # One mixer shared by both splits so train and eval use the same projection.
    mixer = FrozenObsMixer().to(device)
    ds_tr = HDJellyfishDataset(8000, p_none=p_none, noise=noise, device=device, mixer=mixer)
    ds_ev = HDJellyfishDataset(1000, p_none=p_none, noise=noise, device=device, mixer=mixer)
    dl = torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)

    opt = torch.optim.Adam(adapter.parameters(), lr=1e-3)
    adapter.train()

    curve = []  # (cumulative_trials, batch_accuracy)
    trials = 0

    for ep in range(1, epochs+1):
        total, cnt = 0.0, 0
        for obs, y in dl:
            trials += y.size(0)
            for k in obs: obs[k] = obs[k].to(device)
            y = y.to(device)

            logits, _ = adapter(obs)
            loss = F.cross_entropy(logits, y)

            opt.zero_grad(); loss.backward(); opt.step()

            with torch.no_grad():
                batch_acc = (logits.argmax(-1) == y).float().mean().item()
            curve.append((trials, batch_acc))

            total += loss.item()*y.size(0); cnt += y.size(0)

        print(f"[HD-Jellyfish] epoch={ep} loss={total/max(1,cnt):.3f}")

    # Evaluation, one sample at a time.
    adapter.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for i in range(len(ds_ev)):
            obs, y = ds_ev[i]
            for k in obs: obs[k] = obs[k].unsqueeze(0).to(device)
            y = int(y.item())
            logits, _ = adapter(obs)
            pred = logits.argmax(-1).item()
            correct += int(pred == y); total += 1
    acc = correct/max(1,total)
    print(f"[HD-Jellyfish] eval accuracy = {acc:.3f}")

    return acc, curve


# ===============================================================
# Reversal Learning (逐次強化学習タスク)
# ===============================================================
def train_reversal(adapter: ModelAdapter, steps=2000, lr=1e-3, device="cpu", rev_at=None):
    """Train ``adapter`` online with REINFORCE on the 2AFC reversal task.

    Each step samples an action from the softmax policy. The reward is 1 if the
    action matches the currently rewarded class, and the update uses advantage
    ``reward - baseline``, where the baseline is an exponential moving average
    of the reward.

    Args:
        adapter: model wrapper returning ``(logits, _)``; must provide ``reset_memory``.
        steps: number of trials.
        lr: Adam learning rate.
        device: device for the environment and model inputs.
        rev_at: 0-based trial index at which the reward mapping flips. ``None``
            keeps the environment default, unless that default is at or beyond
            ``steps``. In that case no reversal would ever happen (e.g. the
            benchmark's ``rev_steps=1500``), so the flip is placed at
            ``steps // 2`` instead.

    Returns:
        ``(pre_acc, post_acc, curve)``:
        - ``pre_acc`` / ``post_acc``: mean accuracy over the last 500
          pre-reversal / post-reversal trials.
        - ``curve``: ``(step, moving_avg_acc)`` pairs over a 100-trial window.
    """
    env = ReversalEnv(device=device)
    if rev_at is not None:
        env.rev_at = rev_at
    elif env.rev_at >= steps:
        # The default flip point would never be reached; put it mid-run instead.
        env.rev_at = steps // 2
    opt = torch.optim.Adam(adapter.parameters(), lr=lr)
    baseline = 0.0; beta = 0.02  # EMA reward baseline and its update rate

    pre_acc, post_acc = [], []
    curve = []  # (step, moving_avg_acc)

    adapter.train(); adapter.reset_memory(1, device)

    window = 100  # moving-average width for `curve`
    acc_window = deque(maxlen=window)

    for i in range(1, steps+1):
        # Phase of this trial, read before env.step() advances env.t: the env
        # swaps the mapping once env.t >= env.rev_at. (The old `i < rev_at`
        # test mislabelled trial i == rev_at as post-reversal.)
        reversed_phase = env.t >= env.rev_at
        obs, correct = env.step()
        logits, _ = adapter(obs)
        dist = Categorical(logits=logits)  # divide logits by a temperature here to tune exploration
        act = dist.sample()
        rew = 1.0 if int(act) == correct else 0.0
        adv = rew - baseline

        # REINFORCE with baseline; .sum() reduces the per-sample term (batch of 1) to a scalar.
        loss = -(dist.log_prob(act) * adv).sum()
        opt.zero_grad(); loss.backward(); opt.step()

        baseline = (1 - beta) * baseline + beta * rew

        acc = rew  # a trial is correct exactly when it is rewarded
        (post_acc if reversed_phase else pre_acc).append(acc)
        acc_window.append(acc)
        avg_acc = sum(acc_window) / len(acc_window)
        curve.append((i, avg_acc))

        if i % 500 == 0:
            print(f"[Reversal] step={i:4d} "
                  f"pre_acc={sum(pre_acc[-500:])/max(1,len(pre_acc[-500:])):.2f} "
                  f"post_acc={sum(post_acc[-500:])/max(1,len(post_acc[-500:])):.2f}")

    pa = sum(pre_acc[-500:]) / max(1, len(pre_acc[-500:]))
    po = sum(post_acc[-500:]) / max(1, len(post_acc[-500:]))

    return float(pa), float(po), curve


def train_detour(adapter: ModelAdapter, device="cpu", epochs=5, batch_size=128):
    """Fit ``adapter`` on the HardDetour task with cross-entropy, then score it.

    Returns:
        ``(accuracy, curve)``: accuracy over 1000 held-out samples, and
        ``(cumulative_trials, batch_accuracy)`` pairs recorded after every
        optimisation step.
    """
    train_set = HardDetourDataset(4000, device=device)
    eval_set = HardDetourDataset(1000, device=device)
    loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

    optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-3)
    adapter.train()

    curve = []
    seen = 0

    for epoch in range(1, epochs + 1):
        loss_sum, n_seen = 0.0, 0
        for batch_obs, labels in loader:
            seen += len(labels)
            batch_obs = {key: tensor.to(device) for key, tensor in batch_obs.items()}
            labels = labels.to(device)

            logits, _ = adapter(batch_obs)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                hit_rate = (logits.argmax(-1) == labels).float().mean().item()
                curve.append((seen, hit_rate))

            batch_n = labels.size(0)
            loss_sum += loss.item() * batch_n
            n_seen += batch_n

        epoch_loss = loss_sum / max(1, n_seen)
        print(f"[HardDetour] epoch={epoch} loss={epoch_loss:.3f}")

    # Held-out evaluation, one sample at a time.
    adapter.eval()
    hits = 0
    with torch.no_grad():
        for sample_obs, target in eval_set:
            single = {key: tensor.unsqueeze(0).to(device) for key, tensor in sample_obs.items()}
            logits, _ = adapter(single)
            hits += int(logits.argmax(-1).item() == target)

    acc = hits / max(1, len(eval_set))
    print(f"[HardDetour] eval accuracy = {acc:.3f}")
    return acc, curve


def train_local_reflex(adapter: ModelAdapter, device="cpu", epochs=4, batch_size=128, lr=1e-3):
    """Train ``adapter`` on ``LocalReflexDataset`` (3-way) with cross-entropy, then evaluate.

    Returns:
        ``(acc, curve)``: accuracy on 1000 held-out samples, and
        ``(cumulative_trials, batch_accuracy)`` pairs recorded per training step.
    """
    ds_tr = LocalReflexDataset(n_samples=6000, device=device)
    ds_ev = LocalReflexDataset(n_samples=1000, device=device)
    dl = torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)

    opt = torch.optim.Adam(adapter.parameters(), lr=lr)
    # Enter training mode explicitly, like the other train_* functions; otherwise
    # an adapter left in eval() mode would be trained with eval-mode layers.
    adapter.train()
    curve = []
    trial_counter = 0

    for ep in range(1, epochs+1):
        total, cnt = 0.0, 0
        for obs, y in dl:
            trial_counter += len(y)
            for k in obs: obs[k] = obs[k].to(device)
            y = y.to(device)
            logits, _ = adapter(obs)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad(); loss.backward(); opt.step()

            with torch.no_grad():
                pred = logits.argmax(-1)
                acc_batch = (pred == y).float().mean().item()
                curve.append((trial_counter, acc_batch))

            total += loss.item() * y.size(0); cnt += y.size(0)

        # Epoch-mean loss. The previous print showed only the last batch, and hit
        # UnboundLocalError when the loader produced no batches.
        print(f"[LocalReflex] epoch={ep} loss={total/max(1,cnt):.3f}")

    # Evaluation, one sample at a time.
    adapter.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for obs, y in ds_ev:
            for k in obs: obs[k] = obs[k].unsqueeze(0).to(device)
            y = int(y)  # accepts a 0-d tensor or a plain int label
            logits, _ = adapter(obs)
            pred = logits.argmax(-1).item()
            correct += int(pred == y); total += 1
    acc = correct/max(1,total)
    print(f"[LocalReflex] eval accuracy = {acc:.3f}")
    return acc, curve

# ================== RPM-Mini (2x2) ==================
class RPMMiniDataset(torch.utils.data.Dataset):
    """Minimal 2x2 Raven-style matrix task (rule -> composition) with 3 answer choices.

    Three binary ``grid x grid`` tiles (T00, T01, T10) are generated under one rule
    (XNOR / XOR / COUNT). The model must pick the missing tile T11 among three
    candidates: correct, row-rolled, and inverted.

    Encoding: each tile keeps its first ``V_BASE`` cells, which are projected
    through ``self.mixer`` to a ``D_VISION`` vector plus noise. The three context
    tiles are averaged into ``vision``. The shuffled candidates are superimposed
    on ``vision`` with a slot-dependent roll (``_slot_keyed_mean``). Label ``y``
    is the shuffled position of the correct candidate.

    NOTE: the constructor reseeds the global ``torch`` and ``random`` RNGs with ``seed``.
    """
    def __init__(self, n=6000, grid=6, device="cpu", noise=0.05, mixer=None, seed=0):
        super().__init__()
        self.device=device; self.noise=noise; self.grid=grid
        torch.manual_seed(seed); random.seed(seed)
        # The default mixer is moved to `device` (as train_hd_jellyfish does), so
        # CUDA inputs never reach a CPU mixer. `is not None` avoids relying on
        # nn.Module truthiness.
        self.mixer = mixer if mixer is not None else FrozenObsMixer().to(device)
        self.samples=[]
        for _ in range(n):
            rule = random.choice(["XNOR","XOR","COUNT"])
            # Three binary tiles (T00, T01, T10); T11 is the missing one.
            T00 = (torch.rand(grid,grid,device=device)>0.65).float()
            # Per-cell toggle/shift mask.
            mask = (torch.rand_like(T00)>0.5).float()
            T01 = (T00 if rule=="COUNT" else torch.clamp(T00 + (1-mask),0,1))
            T10 = torch.clamp(T00*mask,0,1) if rule!="COUNT" else (torch.rand_like(T00)>0.65).float()

            if rule=="XNOR":
                T11 = 1.0 - torch.logical_xor(T01.bool(), T10.bool()).float()  # 1 where T01 and T10 agree
            elif rule=="XOR":
                T11 = torch.logical_xor(T01.bool(), T10.bool()).float()
            else:  # COUNT: T11 has as many 1s as T01, at random positions
                ones_target = int(T01.sum().item())
                flat = torch.zeros(grid*grid, device=device)
                idx = torch.randperm(grid*grid, device=device)[:ones_target]
                flat[idx]=1.0
                T11 = flat.view(grid,grid)

            def vec(img):
                base = img.flatten().float()
                # Keep the first V_BASE cells, then project to D_VISION with the
                # frozen mixer and add noise for sample diversity.
                v = torch.zeros(V_BASE, device=device)
                take = min(V_BASE, base.numel())
                v[:take] = base[:take]
                V_full, _ = self.mixer(v.unsqueeze(0), torch.zeros(1,S_BASE,device=device))
                return (V_full.squeeze(0) + self.noise*torch.randn(D_VISION,device=device))

            # Context panel: T00, T01 / T10, ?(T11)
            panel = torch.stack([vec(T00), vec(T01), vec(T10)],0)  # (3, D_VISION)
            vision = panel.mean(0)  # simplification: fold the panel into one vector

            # Candidates: the correct tile plus two distractors.
            correct_vec = vec(T11)
            wrong1 = vec(T11.roll(shifts=1, dims=0))  # row-rolled
            wrong2 = vec(1.0 - T11)                   # inverted
            choices = torch.stack([correct_vec, wrong1, wrong2],0)  # (3, D_VISION)
            somato = torch.randn(D_SOMATO, device=device)*self.noise
            # Shuffle candidate positions.
            order = torch.randperm(3)
            y = int((order==0).nonzero()[0])  # where the original index 0 (correct_vec) ended up
            choices = choices[order]

            obs = {
                "vision": vision,
                "somatosensory": somato,
                "olfaction": torch.randn(D_OLFACT,device=device)*self.noise,
                "auditory":  torch.randn(D_AUDIT, device=device)*self.noise,
                "proprioception": torch.randn(D_PROP,device=device)*self.noise,
            }
            # Superimpose the candidates in an order-dependent way. A plain
            # choices.mean(0) is identical for every permutation, which made y
            # independent of the observation (chance-level ceiling).
            obs["vision"] = obs["vision"] + 0.25*self._slot_keyed_mean(choices)

            self.samples.append((obs, y))

    @staticmethod
    def _slot_keyed_mean(choices):
        """Average the rows of ``choices`` after rolling row ``k`` by ``k * (dim // n_rows)``.

        The per-slot roll acts as a position key, so the result depends on
        which candidate sits in which slot.
        """
        n_slots, dim = choices.shape
        shift = dim // n_slots
        keyed = [torch.roll(row, shifts=k * shift, dims=0) for k, row in enumerate(choices)]
        return torch.stack(keyed, 0).mean(0)

    def __len__(self): return len(self.samples)
    def __getitem__(self, i): return self.samples[i]


def train_rpm_mini(adapter: ModelAdapter, device="cpu", epochs=3, batch_size=128, noise=0.05, mixer=None):
    """Train ``adapter`` on RPM-Mini with cross-entropy, then evaluate it.

    Args:
        mixer: observation mixer shared by the train and eval datasets. ``None``
            builds one ``FrozenObsMixer`` on ``device``. It is shared because
            each dataset otherwise builds its own mixer right after reseeding
            with a different seed. If ``FrozenObsMixer`` is randomly initialised
            (not visible here), train and eval would then see different
            projections.

    Returns:
        ``(acc, curve)``: held-out accuracy and ``(cumulative_trials, batch_accuracy)`` pairs.
    """
    if mixer is None:
        mixer = FrozenObsMixer().to(device)
    ds_tr = RPMMiniDataset(6000, device=device, noise=noise, mixer=mixer)
    ds_ev = RPMMiniDataset(1000, device=device, noise=noise, mixer=mixer, seed=7)
    dl = torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)
    opt = torch.optim.Adam(adapter.parameters(), lr=1e-3)
    curve=[]; trial=0
    adapter.train()
    for ep in range(1, epochs+1):
        total=cnt=0
        for obs,y in dl:
            trial+=len(y)
            for k in obs: obs[k]=obs[k].to(device)
            y=y.to(device)
            logits,_=adapter(obs)
            loss=F.cross_entropy(logits,y)
            opt.zero_grad(); loss.backward(); opt.step()
            with torch.no_grad():
                acc=(logits.argmax(-1)==y).float().mean().item()
                curve.append((trial,acc))
            total+=loss.item()*y.size(0); cnt+=y.size(0)
        print(f"[RPM-Mini] epoch={ep} loss={total/max(1,cnt):.3f}")
    # Evaluation, one sample at a time.
    adapter.eval()
    correct=total=0
    with torch.no_grad():
        for obs,y in ds_ev:
            for k in obs: obs[k]=obs[k].unsqueeze(0).to(device)
            y=int(y)
            logits,_=adapter(obs)
            pred=logits.argmax(-1).item()
            correct+=int(pred==y); total+=1
    acc=correct/max(1,total)
    print(f"[RPM-Mini] eval accuracy = {acc:.3f}")
    return acc, curve

# ================== ARC-Mini ==================
class ARCMiniDataset(torch.utils.data.Dataset):
    """Simplified ARC-style rule induction with 3 answer choices.

    Rules:
    - ``MAJ``: fill the grid with its most frequent colour.
    - ``CCNT``: keep the number of colour-1 cells, scattered to random positions
      (a rough stand-in for "preserve the connected-component count").
    - ``PARITY``: recolour odd-indexed rows as ``(k-1-c) % k``.

    Two input/output examples plus the query input are summed into ``vision``.
    The model must pick the query's output among three shuffled candidates:
    correct, colour-shifted, and row-rolled.

    Each grid is encoded only through its colour histogram (one-hot mean),
    projected by ``self.mixer``.

    NOTE(review): a row roll leaves the colour histogram unchanged, so the
    row-rolled distractor encodes like the correct answer up to noise. This
    caps attainable accuracy at roughly 50% as designed.

    NOTE: the constructor reseeds the global ``random`` and ``torch`` RNGs with ``seed``.
    """
    def __init__(self, n=6000, grid=6, k_colors=3, device="cpu", noise=0.05, mixer=None, seed=1):
        super().__init__()
        self.device=device; self.noise=noise; self.grid=grid; self.k=k_colors
        random.seed(seed); torch.manual_seed(seed)
        # The default mixer is moved to `device` so CUDA inputs never reach a CPU mixer.
        self.mixer = mixer if mixer is not None else FrozenObsMixer().to(device)
        self.samples=[]
        for _ in range(n):
            rule = random.choice(["MAJ","CCNT","PARITY"])
            def rand_grid():
                return torch.randint(0,self.k,(grid,grid),device=device)

            def apply(g):
                if rule=="MAJ":
                    # Fill everything with the most frequent colour.
                    vals,counts = torch.unique(g, return_counts=True)
                    c = vals[torch.argmax(counts)].item()
                    return torch.full_like(g, int(c))
                elif rule=="CCNT":
                    # Keep the count of colour-1 cells (all other colours become 0),
                    # scattered to random positions.
                    ones_mask = (g==1).float()  # renamed from `bin`, which shadowed the builtin
                    ones=int(ones_mask.sum().item())
                    out=torch.zeros_like(g)
                    idx=torch.randperm(g.numel(), device=device)[:ones]
                    out.view(-1)[idx]=1
                    return out.long()
                else:  # PARITY
                    # Even rows unchanged, odd rows colour-inverted (mod k).
                    out=g.clone()
                    out[1::2]=(self.k-1 - out[1::2])%self.k
                    return out

            # Two examples (in_i -> out_i)
            in1,in2 = rand_grid(), rand_grid()
            out1,out2 = apply(in1), apply(in2)
            # Query input
            in3 = rand_grid(); out3 = apply(in3)

            def enc(g):
                # Colour one-hot -> mean over cells -> first V_BASE slots -> mixer -> D_VISION.
                oh = F.one_hot(g, num_classes=self.k).float().mean(dim=(0,1))  # (k,)
                v = torch.zeros(V_BASE, device=device); take=min(V_BASE, oh.numel())
                v[:take]=oh[:take]
                V_full,_=self.mixer(v.unsqueeze(0), torch.zeros(1,S_BASE,device=device))
                return V_full.squeeze(0)+self.noise*torch.randn(D_VISION,device=device)

            # Sum the examples (inputs and outputs) into vision.
            vision = enc(in1)+enc(out1)+enc(in2)+enc(out2)+0.5*enc(in3)

            # Candidates: correct + 2 distractors.
            correct=enc(out3)
            wrong1 = enc((out3+1)%self.k)             # colour shift
            wrong2 = enc(out3.roll(shifts=1,dims=0))  # row roll (see class NOTE)
            choices=torch.stack([correct,wrong1,wrong2],0)
            order=torch.randperm(3); y=int((order==0).nonzero()[0])
            choices=choices[order]
            # Order-dependent superposition; a plain mean over the shuffled rows
            # is permutation-invariant and made y unpredictable from the input.
            vision = vision + 0.25*self._slot_keyed_mean(choices)

            obs={
                "vision": vision,
                "somatosensory": torch.randn(D_SOMATO,device=device)*self.noise,
                "olfaction": torch.randn(D_OLFACT,device=device)*self.noise,
                "auditory":  torch.randn(D_AUDIT, device=device)*self.noise,
                "proprioception": torch.randn(D_PROP, device=device)*self.noise,
            }
            self.samples.append((obs, y))

    @staticmethod
    def _slot_keyed_mean(choices):
        """Average the rows of ``choices`` after rolling row ``k`` by ``k * (dim // n_rows)``.

        The per-slot roll acts as a position key, so the result depends on
        which candidate sits in which slot.
        """
        n_slots, dim = choices.shape
        shift = dim // n_slots
        keyed = [torch.roll(row, shifts=k * shift, dims=0) for k, row in enumerate(choices)]
        return torch.stack(keyed, 0).mean(0)

    def __len__(self): return len(self.samples)
    def __getitem__(self,i): return self.samples[i]


def train_arc_mini(adapter: ModelAdapter, device="cpu", epochs=3, batch_size=128, noise=0.05, mixer=None):
    """Train ``adapter`` on ARC-Mini with cross-entropy, then evaluate it.

    Args:
        mixer: observation mixer shared by the train and eval datasets. ``None``
            builds one ``FrozenObsMixer`` on ``device``. It is shared because
            each dataset would otherwise build its own mixer after reseeding
            with a different seed.

    Returns:
        ``(acc, curve)``: held-out accuracy and ``(cumulative_trials, batch_accuracy)`` pairs.
    """
    if mixer is None:
        mixer = FrozenObsMixer().to(device)
    ds_tr=ARCMiniDataset(6000, device=device, noise=noise, mixer=mixer)
    ds_ev=ARCMiniDataset(1000, device=device, noise=noise, mixer=mixer, seed=13)
    dl=torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)
    opt=torch.optim.Adam(adapter.parameters(), lr=1e-3)
    curve=[]; trials=0
    adapter.train()
    for ep in range(1,epochs+1):
        total=cnt=0
        for obs,y in dl:
            trials+=len(y)
            for k in obs: obs[k]=obs[k].to(device)
            y=y.to(device)
            logits,_=adapter(obs)
            loss=F.cross_entropy(logits,y)
            opt.zero_grad(); loss.backward(); opt.step()
            with torch.no_grad():
                acc=(logits.argmax(-1)==y).float().mean().item()
                curve.append((trials,acc))
            total+=loss.item()*y.size(0); cnt+=y.size(0)
        print(f"[ARC-Mini] epoch={ep} loss={total/max(1,cnt):.3f}")
    # Evaluation, one sample at a time.
    adapter.eval()
    correct=total=0
    with torch.no_grad():
        for obs,y in ds_ev:
            for k in obs: obs[k]=obs[k].unsqueeze(0).to(device)
            y=int(y)
            logits,_=adapter(obs)
            pred=logits.argmax(-1).item()
            correct+=int(pred==y); total+=1
    acc=correct/max(1,total)
    print(f"[ARC-Mini] eval accuracy = {acc:.3f}")
    return acc, curve

# ================== Grid-Path-FirstStep ==================
class GridPathFirstStep(torch.utils.data.Dataset):
    """First move of a BFS shortest path on a ``size x size`` grid, as a 3-way choice.

    Cell codes: 0 = empty, 1 = start S, 2 = goal G, 3 = obstacle '#'.
    The agent is taken to face east. The label of the first BFS step is:
    - 0 (left) for a move north (y - 1),
    - 1 (forward) for a move east,
    - 2 (right) for a move south (y + 1).

    A westward first move gets a random label in {0, 2}; an unreachable goal
    gets a random label in {0, 1, 2}.

    Observation:
    - ``vision``: cell-type histogram (one-hot mean) projected through ``self.mixer``.
    - ``somatosensory``: row-major obstacle mask followed by ``(sx, sy, gx, gy)``
      scaled to [0, 1], plus noise. The mask is truncated if ``size*size + 4``
      exceeds ``D_SOMATO``.

    The histogram alone carries no position information, so without this
    layout the label could not be predicted from the input.

    NOTE: the constructor reseeds the global ``random`` and ``torch`` RNGs with ``seed``.
    """
    def __init__(self, n=8000, size=7, device="cpu", noise=0.05, mixer=None, seed=2):
        super().__init__()
        self.device=device; self.noise=noise; self.N=size
        random.seed(seed); torch.manual_seed(seed)
        # The default mixer is moved to `device` so CUDA inputs never reach a CPU mixer.
        self.mixer = mixer if mixer is not None else FrozenObsMixer().to(device)
        self.samples=[]
        mask_len = max(0, min(size * size, D_SOMATO - 4))  # obstacle-mask slots in somato
        coord_scale = float(max(1, size - 1))
        for _ in range(n):
            # Start and goal at distinct cells.
            sx,sy = random.randint(0,size-1), random.randint(0,size-1)
            gx,gy = random.randint(0,size-1), random.randint(0,size-1)
            while (gx,gy)==(sx,sy):
                gx,gy = random.randint(0,size-1), random.randint(0,size-1)
            # Obstacles (never on S or G).
            obstacles = set()
            for _o in range(random.randint(size//2, size)):
                ox,oy = random.randint(0,size-1), random.randint(0,size-1)
                if (ox,oy) not in ((sx,sy),(gx,gy)):
                    obstacles.add((ox,oy))

            step_label = self._first_step_label(size, (sx,sy), (gx,gy), obstacles)

            # Build the cell-code grid on the host, then move it once.
            cells = [[0] * size for _ in range(size)]
            cells[sy][sx] = 1; cells[gy][gx] = 2
            for ox, oy in obstacles:
                cells[oy][ox] = 3
            grid = torch.tensor(cells, dtype=torch.long, device=device)

            # Vision: cell-type histogram -> first V_BASE slots -> mixer -> D_VISION.
            oh=F.one_hot(grid, num_classes=4).float().mean(dim=(0,1))  # (4,)
            v=torch.zeros(V_BASE,device=device); take=min(V_BASE,oh.numel())
            v[:take]=oh[:take]
            V_full,_=self.mixer(v.unsqueeze(0), torch.zeros(1,S_BASE,device=device))
            vision=V_full.squeeze(0)+self.noise*torch.randn(D_VISION,device=device)

            # Somatosensory: spatial layout (obstacle mask + S/G coordinates) + noise.
            layout = [0.0] * D_SOMATO
            for ox, oy in obstacles:
                cell = oy * size + ox
                if cell < mask_len:
                    layout[cell] = 1.0
            layout[mask_len:mask_len + 4] = [c / coord_scale for c in (sx, sy, gx, gy)]
            somato = torch.tensor(layout, device=device) + torch.randn(D_SOMATO,device=device)*self.noise

            obs={
                "vision": vision,
                "somatosensory": somato,
                "olfaction": torch.randn(D_OLFACT,device=device)*self.noise,
                "auditory":  torch.randn(D_AUDIT, device=device)*self.noise,
                "proprioception": torch.randn(D_PROP, device=device)*self.noise,
            }
            self.samples.append((obs, step_label))

    @staticmethod
    def _first_step_label(size, start, goal, obstacles):
        """Return the L/F/R label of the first BFS step from ``start`` to ``goal``.

        Neighbours are expanded in E, S, W, N order, which fixes the tie-break
        between equally short paths. Draws from ``random`` when no path exists
        or when the first move is westward.
        """
        prev = {}
        visited = {start}
        queue = deque([start])
        found = False
        while queue:
            x, y = queue.popleft()
            if (x, y) == goal:
                found = True
                break
            for dx, dy in ((1, 0), (0, 1), (-1, 0), (0, -1)):  # E, S, W, N
                nxt = (x + dx, y + dy)
                if (0 <= nxt[0] < size and 0 <= nxt[1] < size
                        and nxt not in visited and nxt not in obstacles):
                    visited.add(nxt); prev[nxt] = (x, y); queue.append(nxt)
        if not found:
            # No path: random label (label noise / robustness probe).
            return random.randint(0, 2)
        # Walk back from the goal to the cell entered right after the start.
        first = goal
        while prev[first] != start:
            first = prev[first]
        move = (first[0] - start[0], first[1] - start[1])
        # Facing east: north is left (0), east is forward (1), south is right (2).
        labels = {(1, 0): 1, (0, -1): 0, (0, 1): 2}
        if move in labels:
            return labels[move]
        return 0 if random.random() < 0.5 else 2  # westward: turn either way

    def __len__(self): return len(self.samples)
    def __getitem__(self,i): return self.samples[i]


def train_grid_firststep(adapter: ModelAdapter, device="cpu", epochs=3, batch_size=128, noise=0.05, mixer=None):
    """Train ``adapter`` on Grid-Path-FirstStep with cross-entropy, then evaluate it.

    Args:
        mixer: observation mixer shared by the train and eval datasets. ``None``
            builds one ``FrozenObsMixer`` on ``device``. It is shared because
            each dataset would otherwise build its own mixer after reseeding
            with a different seed.

    Returns:
        ``(acc, curve)``: held-out accuracy and ``(cumulative_trials, batch_accuracy)`` pairs.
    """
    if mixer is None:
        mixer = FrozenObsMixer().to(device)
    ds_tr=GridPathFirstStep(8000, device=device, noise=noise, mixer=mixer)
    ds_ev=GridPathFirstStep(1200, device=device, noise=noise, mixer=mixer, seed=11)
    dl=torch.utils.data.DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True)
    opt=torch.optim.Adam(adapter.parameters(), lr=1e-3)
    curve=[]; trials=0
    adapter.train()
    for ep in range(1,epochs+1):
        total=cnt=0
        for obs,y in dl:
            trials+=len(y)
            for k in obs: obs[k]=obs[k].to(device)
            y=y.to(device)
            logits,_=adapter(obs)
            loss=F.cross_entropy(logits,y)
            opt.zero_grad(); loss.backward(); opt.step()
            with torch.no_grad():
                acc=(logits.argmax(-1)==y).float().mean().item()
                curve.append((trials,acc))
            total+=loss.item()*y.size(0); cnt+=y.size(0)
        print(f"[Grid-FirstStep] epoch={ep} loss={total/max(1,cnt):.3f}")
    # Evaluation, one sample at a time.
    adapter.eval()
    correct=total=0
    with torch.no_grad():
        for obs,y in ds_ev:
            for k in obs: obs[k]=obs[k].unsqueeze(0).to(device)
            y=int(y)
            logits,_=adapter(obs)
            pred=logits.argmax(-1).item()
            correct+=int(pred==y); total+=1
    acc=correct/max(1,total)
    print(f"[Grid-FirstStep] eval accuracy = {acc:.3f}")
    return acc, curve


def compute_tal_metrics(curve, criterion=0.85, window=100, budget=2000):
    """Summarise a learning curve as trials-to-criterion, AUC@B and asymptote.

    Args:
        curve: list of ``(trial_idx, acc)`` pairs in chronological order.
        criterion: moving-average accuracy that counts as "learned".
        window: number of consecutive curve entries in the TTC moving average
            and in the asymptote estimate.
        budget: number of leading curve *entries* averaged for AUC@B.
            NOTE(review): entries, not trials. Curves that log one entry per
            batch therefore cover far more than ``budget`` trials.

    Returns:
        ``(ttc, auc, asy)``:
        - ``ttc``: ``trial_idx`` of the last entry of the first full window whose
          mean reaches ``criterion`` (``None`` if never reached).
        - ``auc``: mean accuracy over the first ``budget`` entries.
        - ``asy``: mean over the last ``window`` entries.

        An empty curve gives ``(None, 0.0, 0.0)``.
    """
    accs = [a for _, a in curve]
    trials = [t for t, _ in curve]
    if not accs:
        return None, 0.0, 0.0

    # TTC: windows accs[end-window:end] for every end up to and including len(accs);
    # report the trial of the window's last entry (previously the next entry's
    # trial was reported and the final window was never checked).
    ttc = None
    for end in range(window, len(accs) + 1):
        if sum(accs[end - window:end]) / window >= criterion:
            ttc = trials[end - 1]
            break

    # AUC@B
    auc = sum(accs[:budget]) / min(budget, len(accs))

    # Asymptote
    asy = sum(accs[-window:]) / min(window, len(accs))

    return ttc, auc, asy


def compute_efficiency(auc, params, trials):
    """Return Eff@B = AUC@B / (params * trials), with the denominator floored at 1."""
    cost = params * trials
    denominator = cost if cost > 1 else 1
    return auc / denominator


# ===============================================================
# Build & Run
# ===============================================================
def build_models(device="cpu") -> Dict[str, nn.Module]:
    """Instantiate one model per tier, each moved to ``device``.

    Models are constructed in the listed order, and the returned dict preserves
    that order (keys carry a numeric prefix).
    """
    factories = (
        ("1_Cnidarian", lambda: CnidarianNerveNet(motor_dim=8)),
        ("2_SegmentedRestricted", lambda: SegmentedGangliaRestricted(segments=6, motor_per_seg=2)),
        ("3_Cephalopod", lambda: CephalopodBrainV3(motor_dim=4)),
        ("4_Fish", lambda: FishBrainV3(motor_dim=12)),
        ("5_Human", lambda: HumanCortexV4(motor_dim=20)),
    )
    return {label: make().to(device) for label, make in factories}

def _fresh_agent(base, init_state, n_actions, device):
    """Restore ``base`` to ``init_state`` and wrap it in a new memory-less ``ModelAdapter``."""
    base.load_state_dict(init_state, strict=True)
    return ModelAdapter(base, n_actions=n_actions, use_memory=False).to(device)


def run_benchmark(device=None, jelly_epochs=3, rev_steps=1500, batch_size=128, detour_epochs=15,
                  reflex_epochs=15, rpm_epochs=80, arc_epochs=80, grid_epochs=50):
    """Train every model from ``build_models`` on every task and collect metrics.

    Each task starts from the same initial base weights: they are snapshotted
    once per model and restored before each task, and a fresh ``ModelAdapter``
    is built per task.

    Args:
        device: torch device string; defaults to CUDA when available, else CPU.
        jelly_epochs, detour_epochs, reflex_epochs, rpm_epochs, arc_epochs,
        grid_epochs: training epochs per supervised task.
        rev_steps: number of trials for the reversal task.
        batch_size: mini-batch size for the supervised tasks.

    Returns:
        dict mapping model name -> dict of per-task metrics.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    set_seed(0)
    results = {}
    models = build_models(device=device)
    for name, base in models.items():
        params = sum(p.numel() for p in base.parameters())
        init_state = {k: v.detach().clone() for k, v in base.state_dict().items()}

        print("\n" + "="*60)
        print(f"Model: {name}")

        # HD-Jellyfish
        agent = _fresh_agent(base, init_state, 3, device)
        jelly_acc, jelly_curve = train_hd_jellyfish(agent, device=device, epochs=jelly_epochs, batch_size=batch_size)
        j_ttc, j_auc, j_asy = compute_tal_metrics(jelly_curve, criterion=0.90, window=100, budget=2000)
        j_eff = compute_efficiency(j_auc, params, trials=2000)

        # Reversal
        agent = _fresh_agent(base, init_state, 2, device)
        pre, post, rev_curve = train_reversal(agent, device=device, steps=rev_steps)
        r_ttc, r_auc, r_asy = compute_tal_metrics(rev_curve, criterion=0.80, window=200, budget=rev_steps)
        r_eff = compute_efficiency(r_auc, params, trials=rev_steps)

        # Detour
        agent = _fresh_agent(base, init_state, 3, device)
        detour_acc, detour_curve = train_detour(agent, device=device, epochs=detour_epochs, batch_size=batch_size)
        d_ttc, d_auc, d_asy = compute_tal_metrics(detour_curve, criterion=0.85, window=200, budget=3000)
        d_eff = compute_efficiency(d_auc, params, trials=3000)

        # Local reflex: separate names. Previously these overwrote the reversal
        # metrics (r_*) and reported the reversal efficiency as reflex_Eff.
        agent = _fresh_agent(base, init_state, 3, device)
        reflex_acc, reflex_curve = train_local_reflex(agent, device=device, epochs=reflex_epochs,
                                                      batch_size=batch_size)
        rf_ttc, rf_auc, rf_asy = compute_tal_metrics(reflex_curve, criterion=0.85, window=200, budget=3000)
        rf_eff = compute_efficiency(rf_auc, params, trials=3000)

        results[name] = {
            "jelly_acc": jelly_acc, "jelly_TTC": j_ttc, "jelly_AUC": j_auc, "jelly_Asy": j_asy, "jelly_Eff": j_eff,
            "rev_pre": pre, "rev_post": post, "rev_TTC": r_ttc, "rev_AUC": r_auc, "rev_Asy": r_asy, "rev_Eff": r_eff,
            "detour_acc": detour_acc, "detour_TTC": d_ttc, "detour_AUC": d_auc, "detour_Asy": d_asy, "detour_Eff": d_eff,
            "reflex_acc": reflex_acc, "reflex_TTC": rf_ttc, "reflex_AUC": rf_auc, "reflex_Asy": rf_asy, "reflex_Eff": rf_eff,
        }

        # RPM-Mini
        agent = _fresh_agent(base, init_state, 3, device)
        rpm_acc, rpm_curve = train_rpm_mini(agent, device=device, epochs=rpm_epochs, batch_size=batch_size)
        rpm_ttc, rpm_auc, rpm_asy = compute_tal_metrics(rpm_curve, criterion=0.75, window=200, budget=3000)

        # ARC-Mini
        agent = _fresh_agent(base, init_state, 3, device)
        arc_acc, arc_curve = train_arc_mini(agent, device=device, epochs=arc_epochs, batch_size=batch_size)
        arc_ttc, arc_auc, arc_asy = compute_tal_metrics(arc_curve, criterion=0.75, window=200, budget=3000)

        # Grid-FirstStep
        agent = _fresh_agent(base, init_state, 3, device)
        gpf_acc, gpf_curve = train_grid_firststep(agent, device=device, epochs=grid_epochs, batch_size=batch_size)
        gpf_ttc, gpf_auc, gpf_asy = compute_tal_metrics(gpf_curve, criterion=0.75, window=200, budget=3000)

        results[name].update({
          "rpm_acc": rpm_acc, "rpm_TTC": rpm_ttc, "rpm_AUC": rpm_auc, "rpm_Asy": rpm_asy,
          "arc_acc": arc_acc, "arc_TTC": arc_ttc, "arc_AUC": arc_auc, "arc_Asy": arc_asy,
          "gpf_acc": gpf_acc, "gpf_TTC": gpf_ttc, "gpf_AUC": gpf_auc, "gpf_Asy": gpf_asy,
        })

    # Summary: the full metric dict per model.
    print("\n" + "#"*60)
    print("# Summary: per-model metrics")
    print("#"*60)
    for model_name, metrics in results.items():
        print(model_name)
        print(metrics)
    return results

# ===================== Run =====================
if __name__ == "__main__":
    # Use the GPU when one is visible, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    bench_cfg = dict(jelly_epochs=5, rev_steps=1500, batch_size=128)
    results = run_benchmark(device=device, **bench_cfg)
