In [1]:
# %% [markdown]
# %%

In [2]:
#@title === 0) Install ===
# - MuJoCo official python api, Ray RLlib, Gymnasium, SpikingJelly
#   참고: MjModel.from_xml_string / XML integrator 옵션 / 마찰계수(접선/비틀림/굴림) / Gymnasium 연동
#   Docs: mujoco.readthedocs.io, Ray RLlib rllib-env & examples, SpikingJelly encoding
!pip -q install "mujoco>=3.1" "ray[rllib]>=2.7" "gymnasium>=0.29" "spikingjelly>=0.0.0.0.14" --upgrade

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.5/43.5 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m73.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.1/70.1 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m33.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m965.4/965.4 kB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m220.9/220.9 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# 새 셀에서 실행
!pip -q uninstall -y ray || true
# 최신 안정판(예: 2.49.x 계열)으로 고정 설치 – RLlib 포함
!pip -q install -U "ray[rllib]==2.49.2"

# 재시작 없이 바로 import하면 꼬일 수 있으니, 세션 강제 재시작
import os, signal, time
print("Restarting runtime to finalize Ray install ...")
os.kill(os.getpid(), 9)


In [1]:
!pip install gputil

Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: gputil
  Building wheel for gputil (setup.py) ... [?25l[?25hdone
  Created wheel for gputil: filename=GPUtil-1.4.0-py3-none-any.whl size=7392 sha256=4f7381a66fd6583b312cb3a1f458a3e6da952366a921a7b80b82d8737ff0d849
  Stored in directory: /root/.cache/pip/wheels/92/a8/b7/d8a067c31a74de9ca252bbe53dea5f896faabd25d55f541037
Successfully built gputil
Installing collected packages: gputil
Successfully installed gputil-1.4.0


In [2]:
#@title === 1) Imports & Utils ===
import math, textwrap, numpy as np
import mujoco as mj
from gymnasium import Env, spaces
from dataclasses import dataclass
import ray
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.env.env_context import EnvContext

# (Optional) SpikingJelly: 간단 레이트→포아송 인코더 예시용
try:
    from spikingjelly.activation_based import functional as sjF
except Exception:
    sjF = None

DT = 1e-4   # 0.1 ms
V  = 50.0   # m/s (signal propagation speed)
PLATE_HALF = 0.15     # 30 cm half-size in meters
SITES_N = 10          # 10 x 10 = 100 slots
BALL_R  = 0.015       # 1.5 cm
BALL_M  = 0.03        # 30 g

# ---- 상단에 추가/수정 ----
N_BINS = 2
BINS   = np.array([0.0, 1.0], dtype=np.float32)  # {0, 1}

def idx_to_u4(idx, n_bins=N_BINS, bins=BINS):
    u = np.empty(4, dtype=np.float32)
    for k in range(4):
        u[k] = bins[idx % n_bins]
        idx //= n_bins
    # [U, R, D, L] 순서 유지
    return u


# ---- Delay line for sensor pipeline ----
class DelayLine:
    def __init__(self, max_steps, n_channels):
        self.buf = np.zeros((max_steps+1, n_channels), dtype=np.float32)
        self.ptr = 0
        self.max_steps = max_steps
        self.n = n_channels
    def push(self, x_t):  # x_t shape [n_channels]
        self.buf[self.ptr, :] = x_t
        self.ptr = (self.ptr + 1) % (self.max_steps+1)
    def read_delayed(self, ks):  # ks shape [n_channels], per-channel delay steps
        # gather diagonal indices with wrap-around
        idx = (self.ptr - 1 - ks) % (self.max_steps+1)
        return self.buf[idx, np.arange(self.n)]

def manhattan_delay_steps(x, y, dt=DT, v=V):
    d_man = abs(x) + abs(y)
    return int(np.ceil(d_man / (v * dt)))

def make_site_grid(n=SITES_N, half=PLATE_HALF, z=0.006):
    xs = np.linspace(-half, half, n)
    ys = np.linspace(-half, half, n)
    pts = []
    for i, y in enumerate(ys):
        for j, x in enumerate(xs):
            pts.append((x, y, z, f"sn_i{i}_j{j}"))
    return pts  # [(x,y,z,name) * 100]

def site_xml_lines(pts):
    lines = []
    for (x,y,z,name) in pts:
        lines.append(
            f'<site name="{name}" pos="{x:.4f} {y:.4f} {z:.4f}" '
            f'size="0.002" type="sphere" rgba="0.2 0.8 0.2 0.5"/>'
        )
    return "\n      ".join(lines)

# 4 edge actuators -> 2 hinge torques
def edge4_to_torque2(u4, w=(1.,1.,1.,1.), max_tau=(0.2, 0.2)):
    uU, uR, uD, uL = u4
    wU, wR, wD, wL = w
    tau_y = (wU*uU - wD*uD) * max_tau[1]
    tau_x = (wR*uR - wL*uL) * max_tau[0]
    return float(np.clip(tau_x, -max_tau[0], max_tau[0])), float(np.clip(tau_y, -max_tau[1], max_tau[1]))

# === 2) MJCF(XML) build (30cm square plate + ball, 2 hinge motors) ===
# - timestep=0.1 ms, integrator=RK4
# - friction triplet(tangential, torsional, rolling) tuned to avoid infinite rolling
#   (굴림마찰 항목은 문서에 설명됨)
sites = make_site_grid()
xml = f"""
<mujoco model="tilt_plate">
  <compiler angle="degree" inertiafromgeom="true"/>
  <option timestep="{DT:.7f}" gravity="0 0 -9.81" integrator="RK4"/>
  <default>
    <geom  condim="6" margin="0.001" solimp="0.9 0.95 0.001" solref="0.002 1"/>
    <default class="plate">
      <geom type="box" friction="0.8 0.003 0.001" rgba="0.8 0.8 0.85 1"/>
    </default>
    <default class="ball">
      <geom type="sphere" friction="0.9 0.005 0.002" rgba="0.9 0.3 0.3 1"/>
    </default>
    <joint armature="0.002" damping="0.1" limited="true"/>
    <motor gear="1.0" ctrllimited="true" ctrlrange="-1.0 1.0"/>
  </default>
  <worldbody>
    <body name="plate_base" pos="0 0 0">
      <joint name="hinge_x" type="hinge" axis="1 0 0" range="-5 5"/>
      <joint name="hinge_y" type="hinge" axis="0 1 0" range="-5 5"/>
      <geom name="plate_geom" class="plate" size="{PLATE_HALF} {PLATE_HALF} 0.005" mass="1.0"/>
      {site_xml_lines(sites)}
    </body>
    <body name="ball" pos="0 0 {BALL_R+0.01:.4f}">
      <freejoint name="ball_free"/>
      <geom name="ball_geom" class="ball" size="{BALL_R}" mass="{BALL_M}"/>
    </body>
  </worldbody>
  <actuator>
    <motor name="mx" joint="hinge_x" gear="1"/>
    <motor name="my" joint="hinge_y" gear="1"/>
  </actuator>
</mujoco>
""".strip()

# Build model from XML string (official API)
model = mj.MjModel.from_xml_string(xml)  # docs show this factory method
data  = mj.MjData(model)

# Precompute per-site manhattan delay (receiver @ center (0,0))
site_xy = [(x, y) for (x,y,_,_) in sites]
ks = np.array([manhattan_delay_steps(x, y) for (x,y) in site_xy], dtype=np.int32)
delay_buf = DelayLine(max_steps=int(ks.max()), n_channels=len(sites))

print("Model OK. Sites:", len(sites), "Max delay steps:", ks.max())

# === 3) Minimal Gymnasium-style Env wrapping MuJoCo ===
# - obs: [ball_x, ball_y, ball_vx, ball_vy, hinge_x, hinge_y, hinge_xvel, hinge_yvel] + delayed 100-d sensor
# - act: 4-d edge actuator in [-1,1], mapped to 2 hinge torques
# - reward: -distance - 0.1*speed - 0.001*||u||^2
@dataclass
class PlateConfig:
    dt: float = DT
    substeps: int = 1
    max_steps: int = 3000   # 0.3 s per episode at 0.1 ms
    tau_max_x: float = 0.2
    tau_max_y: float = 0.2

# 맨 위에(선택): from ray.rllib.env.env_context import EnvContext

class TiltPlateEnv(Env):
    metadata = {"render_modes": []}

    def __init__(self, env_config=None):
        # 1) EnvContext/dict → 평범한 dict로 변환
        try:
            from ray.rllib.env.env_context import EnvContext  # optional import
            if isinstance(env_config, EnvContext):
                env_config = dict(env_config)
        except Exception:
            pass
        if env_config is None:
            env_config = {}

        # 2) 기본값과 merge
        self.cfg = PlateConfig(
            dt       = env_config.get("dt", DT),        # 0.1 ms
            substeps = env_config.get("substeps", 1),
            max_steps= env_config.get("max_steps", 3000),
            tau_max_x= env_config.get("tau_max_x", 0.2),
            tau_max_y= env_config.get("tau_max_y", 0.2),
        )

        # ---- 기존 초기화 그대로 (model/data/delay 등) ----
        self.model = mj.MjModel.from_xml_string(xml)
        self.data  = mj.MjData(self.model)
        self.n_sites = len(sites)
        self.ks = np.array([manhattan_delay_steps(x, y) for (x,y) in site_xy], dtype=np.int32)
        self.delay = DelayLine(int(self.ks.max()), self.n_sites)

        # ✅ DQN용 이산 액션(비음수 {0,1}^4 → 16개)
        self.action_space = spaces.Discrete(N_BINS ** 4)  # N_BINS=2, BINS=[0.,1.]
        high = np.full((8 + self.n_sites,),  np.inf, dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        self.step_count = 0
        self._reset_ball()

    def _reset_ball(self):
        # randomize ball initial pos near center
        self.data.qpos[:] = 0.0
        self.data.qvel[:] = 0.0
        # freejoint qpos layout: [x,y,z, qw,qx,qy,qz]
        self.data.qpos[0] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[1] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[2] = BALL_R + 0.005
        # clear delays
        self.delay = DelayLine(int(self.ks.max()), self.n_sites)

    def _sense_now(self):
        # simple radial basis around ball -> site signal in [0,1]
        bx, by = self.data.qpos[0], self.data.qpos[1]
        sig = []
        for (x,y) in site_xy:
            d = abs(bx - x) + abs(by - y)  # Manhattan for locality
            val = math.exp(-d / 0.05)      # 5 cm falloff
            sig.append(val)
        return np.array(sig, dtype=np.float32)

    def _obs(self):
        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        # hinges are last 2 joints (x,y)
        # joint positions follow hinge_x, hinge_y after freejoint (7 dofs)
        # mj identifies by name safer:
        jx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_x")
        jy = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_y")
        # qpos/qvel indexing known from model.jnt_qposadr
        hx = self.data.qpos[self.model.jnt_qposadr[jx]]
        hy = self.data.qpos[self.model.jnt_qposadr[jy]]
        hvx = self.data.qvel[self.model.jnt_dofadr[jx]]
        hvy = self.data.qvel[self.model.jnt_dofadr[jy]]
        # delayed sensors:
        delayed = self.delay.read_delayed(self.ks)
        return np.concatenate([[bx,by,vx,vy,hx,hy,hvx,hvy], delayed]).astype(np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._reset_ball()
        self.step_count = 0
        # push a few zeros to delay line
        self.delay.push(np.zeros(self.n_sites, dtype=np.float32))
        return self._obs(), {}

    # step() 안
    def step(self, action):
        u4 = idx_to_u4(int(action))   # {0,1}^4
        tau_x, tau_y = edge4_to_torque2(
            u4, max_tau=(self.cfg.tau_max_x, self.cfg.tau_max_y)
        )
        self.data.ctrl[0] = tau_x
        self.data.ctrl[1] = tau_y
        # 이하 동일...


        # sense current & push into delay buffer
        self.delay.push(self._sense_now())

        # advance physics
        for _ in range(self.cfg.substeps):
            mj.mj_step(self.model, self.data)

        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        r = math.sqrt(bx*bx + by*by)
        speed = math.sqrt(vx*vx + vy*vy)

        # reward shaping: center & slow preferred; small action penalty
        reward = -r - 0.1*speed - 0.001*float(np.sum(np.square(u4)))

        self.step_count += 1
        terminated = self.step_count >= self.cfg.max_steps
        truncated  = False
        return self._obs(), reward, terminated, truncated, {}

# quick sanity: random rollout
env = TiltPlateEnv()
obs, _ = env.reset(seed=42)
for t in range(50):
    a = env.action_space.sample()
    obs, rew, term, trunc, _ = env.step(a)
    if term or trunc: break
print("Random rollout OK. step:", t+1, "obs_dim:", obs.shape)

# === 4) (Optional) SpikingJelly Rate/Poisson Encoding Hook ===
# - obs ∈ R^(108). 간단히 [0,1]로 정규화→포아송 스파이크로 변환 가능.
# - 아래는 자리만들기(실험 때 연결). 자세한 인코딩은 SpikingJelly 튜토리얼 참고.
def poisson_spike(obs, rate_scale=50.0, T=5, dt=1e-4):
    x = np.asarray(obs, dtype=np.float32)

    # peak-to-peak 범위 (NumPy 2.0+: ndarray.ptp() 대신 np.ptp(x) 사용)
    rng = float(np.ptp(x))  # == x.max() - x.min()

    # 모든 값이 같거나 비정상 값인 경우 대비
    if not np.isfinite(rng) or rng < 1e-12:
        x_norm = np.zeros_like(x, dtype=np.float32)
    else:
        x_norm = (x - float(x.min())) / (rng + 1e-6)

    lam = x_norm * rate_scale * dt  # 기대 스파이크율
    # T 스텝 동안 포아송 샘플링
    return (np.random.rand(T, x.shape[0]) < lam).astype(np.float32)


spk = poisson_spike(obs, T=5)
print("Poisson spikes shape:", spk.shape)

# === 5) RLlib DQN minimal run (tiny) ===
# - Gymnasium Env을 RLlib에 직접 넘김.
# - 빠른 데모: train_iters=1 (실전은 늘려야 함)
# %%
ray.shutdown()
ray.init(ignore_reinit_error=True, include_dashboard=False)


cfg = (
    DQNConfig()
    .environment(env=TiltPlateEnv, env_config={
            "tau_max_x": 0.2,
            "tau_max_y": 0.2,
            # 필요하면 dt/substeps 등도 여기서 조정
        },)  # ← 클래스로!
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=True)
    .env_runners(num_env_runners=0)               # ← 예전 rollouts 대체
)

algo = cfg.build()
res = algo.train()
print({
    "len_mean": res.get("env_runners", {}).get("episode_len_mean"),
    "ret_mean": res.get("env_runners", {}).get("episode_return_mean"),
})


print("DQN one-iter result keys:", {k: res[k] for k in ["episode_len_mean", "episode_reward_mean"] if k in res})
algo.stop()
ray.shutdown()


Model OK. Sites: 100 Max delay steps: 60
Random rollout OK. step: 50 obs_dim: (108,)
Poisson spikes shape: (5, 108)


2025-09-22 02:35:35,272	INFO worker.py:1951 -- Started a local Ray instance.


{'len_mean': None, 'ret_mean': None}
DQN one-iter result keys: {}


In [3]:
#@title @dataclass SimulationConfig

@dataclass
class SimulationConfig:
    # --- 기본 시뮬레이션 ---
    dt_ms: float = 0.1                              # [ms]
    dt_sim : float = dt_ms / 1000
    plane_size: float = 0.30                        # [m] 한 변 길이 (±0.15 m)
    goal_r: float = 0.005                           # [m] 목표 반경
    tau_pre_ms: float = 5.0
    tau_post_ms: float = 5.0
    tau_e_ms: float = 50.0
    eta: float = 1e-3
    max_dw_per_step: float = 1e-3
    scale_every: int = 100
    use_mf_collat: bool = True
    mf_collat_gain: float = 0.3
    seed: int = 42

    # --- SNN / 센서 ---
    n_mf: int = 100                                 # 센서(모스피버, etc.) 개수
    sensor_slots_xy: int = 10                       # 10x10 = 100 슬롯
    conduction_vel: float = 50.0                    # [m/s] 신경 신호 속도
    use_manhattan_delay: bool = True                # 맨해튼 거리를 쓸 건가요?

    # --- 뉴런 계층 ---
    n_mf: int = 100
    n_grc: int = 1024
    n_pkj: int = 64
    n_motor: int = 4                                # +X, -X, +Y, -Y

    # --- 물리/액추에이터(고정 파라미터) ---
    plate_mass: float = 0.50                        # [kg] 예: 0.5 kg
    actuator_torque_scale: float = 0.05             # [N·m] 스칼라 가중치
    friction_coeff: float = 0.05                    # 무차원/단순 감쇠항
    gravity: float = 9.8
    # 난수 고정(슬롯 샘플링 재현성)
    seed: int = 42

    #inputsimconfig

    # 시뮬레이션
    # 발화(rate-based Poisson)
    base_rate_hz: float = 400.0
    rate_scale_weight: float = 40.0      # 가중: lam = base + k*weight_g
    sigma_ratio: float = 1.0             # σ = ratio * ball_radius
    # 판 기울기 제한
    max_tilt_deg: float = 5.0
    # 비디오/시각화
    video_fps: int = 30

    # --- 4-엣지 액추에이터 매핑(가상) ---
    edge_ctrl_min: float = 0.0   # 각 채널 입력 하한(수축 0~1 가정)
    edge_ctrl_max: float = 1.0   # 각 채널 입력 상한
    edge_gain_deg: float = 10.0   # (우-좌) 또는 (상-하) 1.0 차이가 만드는 기울기[deg]
    edge_x_sign: float = +1.0    # 부호 교정(축 정의에 따라 필요시 -1)
    edge_y_sign: float = +1.0

    # 안정화
    settle_steps: int = 500

In [4]:
#@title Cerebellar Modification

# ---- 4) 소뇌형 SNN (SpikingJelly 캡슐화) -------------------------------------
from typing import Optional, Tuple, Protocol
import torch
import torch.nn as nn
from math import ceil
from spikingjelly.activation_based import layer, neuron, functional
from spikingjelly.activation_based import learning, surrogate

class ISNN(Protocol):
    def reset(self) -> None: ...
    def forward(self, mf_spikes: np.ndarray) -> np.ndarray: ... # -> motor spikes (4,)
    def learn(self, cf_signal: float) -> None: ...

class CerebellarNet(ISNN):
    """
    내부 구현은 SpikingJelly activation_based(IFNode, Linear) 사용.  :contentReference[oaicite:8]{index=8}
    forward:  MF -> GrC -> PkG -> Motor(4) 스파이크
    learn:    PF(GrC)-Motor 가중치에 3-요소 규칙(eligibility × CF) 적용
    """
    def __init__(self, cfg: SimulationConfig, device: Optional[str] = None):
        self.cfg = cfg

        self.torch = torch; self.nn = nn;
        self.sj_func = functional       #spikingjelly functional
        self.dev = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # ---------- Fixed expansion: MF -> GrC ----------
        self.mf2grc = layer.Linear(cfg.n_mf, cfg.n_grc, bias=False).to(self.dev)
        with self.torch.no_grad():
            W = self.torch.zeros(cfg.n_grc, cfg.n_mf, device=self.dev)
            mask = (self.torch.rand_like(W) < 0.1).float()
            W += mask
            W /= (W.sum(dim=1, keepdim=True) + 1e-6)
            self.mf2grc.weight.copy_(W)
        for p in self.mf2grc.parameters(): p.requires_grad_(False)
        self.grc = neuron.IFNode(v_threshold=1.0, v_reset=0.0, detach_reset=True).to(self.dev)

        # ---------- Trainable plastic site: GrC(PF) -> PKJ ----------
        self.grc2pkj = layer.Linear(cfg.n_grc, cfg.n_pkj, bias=False).to(self.dev)
        self.nn.init.zeros_(self.grc2pkj.weight)  # start neutral; learn via CF-gated 3-factor
        self.pkj = neuron.IFNode(v_threshold=1.0, v_reset=0.0, detach_reset=True).to(self.dev)

        tau_pre_steps  = float(self.cfg.tau_pre_ms / self.cfg.dt_ms)
        tau_post_steps = float(self.cfg.tau_post_ms / self.cfg.dt_ms)
        def _f_w(x: torch.Tensor) -> torch.Tensor:
            return torch.clamp(x, -1.5, 1.5)
        self.stdp_pkj = learning.STDPLearner(
            step_mode='s',
            synapse=self.grc2pkj,    # 학습할 Linear
            sn=self.pkj,             # post 뉴런층
            tau_pre=tau_pre_steps,
            tau_post=tau_post_steps,
            f_pre=_f_w,
            f_post=_f_w
        )
        self.opt_stdp = torch.optim.SGD(self.grc2pkj.parameters(), lr=self.cfg.eta, momentum=0.0)

        # ---------- CF input: strong drive to PKJ (one-to-one by default) ----------
        self.cf2pkj = layer.Linear(cfg.n_pkj, cfg.n_pkj, bias=False).to(self.dev)
        with self.torch.no_grad():
            self.cf2pkj.weight.copy_(self.torch.eye(cfg.n_pkj, device=self.dev) * 3.0)  # strong drive
        for p in self.cf2pkj.parameters(): p.requires_grad_(False)

        # ---------- DN/motor: receives (-) PKJ, (+) MF collaterals (optional) ----------
        self.pkj2mo = layer.Linear(cfg.n_pkj, cfg.n_motor, bias=False).to(self.dev)
        with self.torch.no_grad():
            # inhibitory mapping: negative weights, broadly tuned
            Wp = - self.torch.randn(cfg.n_motor, cfg.n_pkj, device=self.dev)
            Wp /= (Wp.std(dim=1, keepdim=True) + 1e-6)
            self.pkj2mo.weight.copy_(Wp * 0.2)
        for p in self.pkj2mo.parameters(): p.requires_grad_(False)

        if cfg.use_mf_collat:
            self.mf2mo = layer.Linear(cfg.n_mf, cfg.n_motor, bias=False).to(self.dev)
            with self.torch.no_grad():
                Wc = self.torch.zeros(cfg.n_motor, cfg.n_mf, device=self.dev)
                mask = (self.torch.rand_like(Wc) < 0.2).float()
                Wc += mask
                Wc /= (Wc.sum(dim=1, keepdim=True) + 1e-6)
                self.mf2mo.weight.copy_(Wc * cfg.mf_collat_gain)
            for p in self.mf2mo.parameters(): p.requires_grad_(False)
        else:
            self.mf2mo = None

        self.mo  = neuron.IFNode(v_threshold=1.0, v_reset=0.0, detach_reset=True).to(self.dev)

        # ---------- Eligibility traces for PF->PKJ ----------
        self.pre_tr  = np.zeros(cfg.n_grc, dtype=np.float32)     # PF (pre)
        self.post_tr = np.zeros(cfg.n_pkj, dtype=np.float32)     # PKJ (post)
        self.E = np.zeros((cfg.n_grc, cfg.n_pkj), dtype=np.float32)

        self.pre_decay  = math.exp(-cfg.dt_ms/cfg.tau_pre_ms)
        self.post_decay = math.exp(-cfg.dt_ms/cfg.tau_post_ms)
        self.e_decay    = math.exp(-cfg.dt_ms/cfg.tau_e_ms)

        # cache last spikes for learning step
        self._last = {}

    def reset(self) -> None:
        self.sj_func.reset_net(self.mf2grc); self.sj_func.reset_net(self.grc)
        self.sj_func.reset_net(self.grc2pkj); self.sj_func.reset_net(self.pkj)
        self.sj_func.reset_net(self.pkj2mo);  self.sj_func.reset_net(self.mo)
        if self.mf2mo is not None: self.sj_func.reset_net(self.mf2mo)
        self.pre_tr[:] = 0; self.post_tr[:] = 0; self.E[:] = 0
        self._last.clear()

    def forward(self, mf_spikes: np.ndarray, cf_spikes: Optional[np.ndarray] = None) -> np.ndarray:
        """
        mf_spikes : (N_MF,)  binary
        cf_spikes : (N_PKJ,) binary or rate-like; if None, zeros
        returns   : (N_MOTOR,) binary spikes
        """
        t = self.torch
        x_mf = t.from_numpy(mf_spikes)[None,:].to(self.dev)
        x_cf = None
        if cf_spikes is not None:
            x_cf = t.from_numpy(cf_spikes)[None,:].to(self.dev)

        grc_spk = self.grc(self.mf2grc(x_mf))
        pkj_in  = self.grc2pkj(grc_spk)
        if x_cf is not None:
            pkj_in = pkj_in + self.cf2pkj(x_cf)   # CF drives PKJ strongly
        pkj_spk = self.pkj(pkj_in)

        mo_in = self.pkj2mo(pkj_spk)              # inhibition by PKJ
        if (self.mf2mo is not None):
            mo_in = mo_in + self.mf2mo(x_mf)      # MF collateral excitation
        mo_spk = self.mo(mo_in)

        # cache last spikes for learning
        self._last['grc'] = grc_spk.detach().cpu().numpy()[0]
        self._last['pkj'] = pkj_spk.detach().cpu().numpy()[0]
        self._last['mo']  = mo_spk.detach().cpu().numpy()[0]
        self._last['cf']  = (cf_spikes.copy() if cf_spikes is not None else None)
        return self._last['mo']

    def learn(self, cf_signal: Optional[np.ndarray] = None) -> None:
        """
        CF-gated three-factor at PF->PKJ:
          ΔW_ij ∝ η * (CF_j) * E_ij
        cf_signal:
          - None or scalar → uses scalar for all PKJ units
          - (N_PKJ,) vector → per-PKJ gating (recommended if you pass cf_spikes)
        """
        grc_out = self._last['grc']; pkj_out = self._last['pkj']
        # 1) update traces
        self.pre_tr  = self.pre_tr  * self.pre_decay  + grc_out
        self.post_tr = self.post_tr * self.post_decay + pkj_out
        self.E = self.E * self.e_decay + (np.outer(self.pre_tr, pkj_out) - np.outer(grc_out, self.post_tr))

        # 2) CF gating (vector per PKJ if available)
        if cf_signal is None:
            if self._last.get('cf') is not None:
                cf = self._last['cf']           # (N_PKJ,)
            else:
                cf = 1.0                        # scalar gate
        else:
            cf = cf_signal                      # scalar or (N_PKJ,)

        # PATCH: in CerebellarNet.learn(), just before applying to weight:
        with self.torch.no_grad():
            E_t = self.torch.from_numpy(self.E).to(self.dev)     # [N_GRC, N_PKJ]
            if np.isscalar(cf):
                dW = self.cfg.eta * float(cf) * E_t
            else:
                cf_vec = self.torch.from_numpy(np.asarray(cf, dtype=np.float32)).to(self.dev)  # [N_PKJ]
                dW = self.cfg.eta * (E_t * cf_vec)   # broadcast on last dim

            # --- NEW: per-step update clamp (stability) ---
            dW.clamp_(-self.cfg.max_dw_per_step, self.cfg.max_dw_per_step)

            self.grc2pkj.weight += dW.T              # [N_PKJ, N_GRC]
            self.grc2pkj.weight.clamp_(-1.5, 1.5)

            # --- NEW: periodic synaptic scaling (row L2 ≤ 1) ---
            if not hasattr(self, "_scale_k"):
                self._scale_k = 0
            self._scale_k += 1
            if (self._scale_k % self.cfg.scale_every) == 0:
                W = self.grc2pkj.weight
                rownorm = W.norm(p=2, dim=1, keepdim=True).clamp_min(1.0)
                W.mul_(1.0 / rownorm)

In [5]:
# ==== SNN <-> Env 어댑터 ====

# 0) MF 인코더: obs -> Poisson spikes (한 스텝분)
def encode_mf_spikes(obs, rate_scale=50.0, dt=DT):
    x = np.asarray(obs, dtype=np.float32)
    # 간단 정규화: [0,1]
    rng = float(np.ptp(x))
    x = (x - float(x.min())) / (rng + 1e-6) if rng >= 1e-12 else np.zeros_like(x, dtype=np.float32)
    lam = x * rate_scale * dt
    return (np.random.rand(x.shape[0]) < lam).astype(np.float32)  # (N_MF,)

# 1) motor(4) 스파이크 -> 엣지 명령 u4 (비음수만: {0,1})
def motor_spikes_to_u4(mo_spk):
    # 이미 binary니까 그대로 uU,uR,uD,uL
    return mo_spk.astype(np.float32)  # shape (4,), 값 {0,1}

# 2) CF 게이트 설계: 거리 감소는 "좋음(+)", 증가 "나쁨(-)"
#    - 스칼라 CF로 간단히 시작 (PKJ 전체에 동일 게이트)
def cf_from_transition(prev_r, curr_r, k_pos=+1.0, k_neg=-1.0):
    dr = curr_r - prev_r
    if dr < 0:
        return k_pos  # 가까워지면 강화(+)
    elif dr > 0:
        return k_neg  # 멀어지면 약화(-)
    else:
        return 0.0

# 3) 한 에피소드 학습 루프 (SNN만)
def train_one_episode_snn(env, cereb, steps=2000, rate_scale=50.0):
    obs, _ = env.reset()
    cereb.reset()

    # prev_r 초기화
    bx, by = env.data.qpos[0], env.data.qpos[1]
    prev_r = float(np.hypot(bx, by))
    ep_ret = 0.0

    for t in range(steps):
        # --- 인코딩 & forward ---
        mf_spk = encode_mf_spikes(obs, rate_scale=rate_scale, dt=DT)               # (N_MF,)
        mo_spk = cereb.forward(mf_spk, cf_spikes=None)                             # (4,)

        # --- 행동 적용 ---
        u4 = motor_spikes_to_u4(mo_spk)                                           # {0,1}^4
        tau_x, tau_y = edge4_to_torque2(u4, max_tau=(env.cfg.tau_max_x, env.cfg.tau_max_y))
        env.data.ctrl[0] = tau_x; env.data.ctrl[1] = tau_y

        # --- 환경 스텝 ---
        obs, reward, term, trunc, _ = env.step(int(0))  # Gym checker를 통과하려면 더미 action 필요 없음
        # ↑ 우리 env.step(action) 인터페이스가 DQN/Discrete 기준이라면,
        #   여기선 이미 data.ctrl에 토크를 넣었으니, env.step은 "no-op" action으로 호출해 한 스텝 전진만 수행.
        #   (필요하다면 step_internal() 같은 헬퍼를 분리해도 됨)

        # --- CF 계산 & 학습 ---
        bx, by = env.data.qpos[0], env.data.qpos[1]
        curr_r = float(np.hypot(bx, by))
        cf = cf_from_transition(prev_r, curr_r, k_pos=+1.0, k_neg=-1.0)
        cereb.learn(cf_signal=cf)

        prev_r = curr_r
        ep_ret += reward
        if term or trunc:
            break
    return ep_ret, t+1


In [6]:
# ==== 0) Imports & constants ====
import math, numpy as np, textwrap
import mujoco as mj
from gymnasium import Env, spaces
from dataclasses import dataclass
from typing import Optional
# Ray/RLlib는 뒤의 가이드에서 사용 (여기선 임포트 지연 가능)

# 시뮬 타임/지연 설정
DT = 1e-4   # 0.1 ms
V  = 50.0   # m/s
PLATE_HALF = 0.15
SITES_N = 10
BALL_R  = 0.015
BALL_M  = 0.03

# ==== 1) Delay line + grid sites ====
class DelayLine:
    def __init__(self, max_steps, n_channels):
        self.buf = np.zeros((max_steps+1, n_channels), dtype=np.float32)
        self.ptr = 0
        self.max_steps = max_steps
        self.n = n_channels
    def push(self, x_t):
        self.buf[self.ptr, :] = x_t
        self.ptr = (self.ptr + 1) % (self.max_steps+1)
    def read_delayed(self, ks):
        idx = (self.ptr - 1 - ks) % (self.max_steps+1)
        return self.buf[idx, np.arange(self.n)]

def manhattan_delay_steps(x, y, dt=DT, v=V):
    return int(np.ceil((abs(x)+abs(y)) / (v*dt)))

def make_site_grid(n=SITES_N, half=PLATE_HALF, z=0.006):
    xs = np.linspace(-half, half, n)
    ys = np.linspace(-half, half, n)
    pts = []
    for i, y in enumerate(ys):
        for j, x in enumerate(xs):
            pts.append((x, y, z, f"sn_i{i}_j{j}"))
    return pts

def site_xml_lines(pts):
    lines = []
    for (x,y,z,name) in pts:
        lines.append(
            f'<site name="{name}" pos="{x:.4f} {y:.4f} {z:.4f}" '
            f'size="0.002" type="sphere" rgba="0.2 0.8 0.2 0.5"/>'
        )
    return "\n      ".join(lines)

# 4-edge → 2-hinge torque
def edge4_to_torque2(u4, w=(1.,1.,1.,1.), max_tau=(0.2,0.2)):
    uU, uR, uD, uL = u4
    wU, wR, wD, wL = w
    tau_y = (wU*uU - wD*uD) * max_tau[1]
    tau_x = (wR*uR - wL*uL) * max_tau[0]
    return float(np.clip(tau_x, -max_tau[0], max_tau[0])), float(np.clip(tau_y, -max_tau[1], max_tau[1]))

# ==== 2) MJCF XML build ====
sites = make_site_grid()
xml = f"""
<mujoco model="tilt_plate">
  <compiler angle="degree" inertiafromgeom="true"/>
  <option timestep="{DT:.7f}" gravity="0 0 -9.81" integrator="RK4"/>
  <default>
    <geom  condim="6" margin="0.001" solimp="0.9 0.95 0.001" solref="0.002 1"/>
    <default class="plate">
      <geom type="box" friction="0.8 0.003 0.001" rgba="0.8 0.8 0.85 1"/>
    </default>
    <default class="ball">
      <geom type="sphere" friction="0.9 0.005 0.002" rgba="0.9 0.3 0.3 1"/>
    </default>
    <joint armature="0.002" damping="0.1" limited="true"/>
    <motor gear="1.0" ctrllimited="true" ctrlrange="-1.0 1.0"/>
  </default>
  <worldbody>
    <body name="plate_base" pos="0 0 0">
      <joint name="hinge_x" type="hinge" axis="1 0 0" range="-5 5"/>
      <joint name="hinge_y" type="hinge" axis="0 1 0" range="-5 5"/>
      <geom name="plate_geom" class="plate" size="{PLATE_HALF} {PLATE_HALF} 0.005" mass="1.0"/>
      {site_xml_lines(sites)}
    </body>
    <body name="ball" pos="0 0 {BALL_R+0.01:.4f}">
      <freejoint name="ball_free"/>
      <geom name="ball_geom" class="ball" size="{BALL_R}" mass="{BALL_M}"/>
    </body>
  </worldbody>
  <actuator>
    <motor name="mx" joint="hinge_x" gear="1"/>
    <motor name="my" joint="hinge_y" gear="1"/>
  </actuator>
</mujoco>
""".strip()

# ==== 3) Env (EnvContext 호환) + 이산 액션 {0,1}^4 ====
N_BINS = 2
BINS   = np.array([0.0, 1.0], dtype=np.float32)
def idx_to_u4(idx, n_bins=N_BINS, bins=BINS):
    u = np.empty(4, dtype=np.float32)
    for k in range(4):
        u[k] = bins[idx % n_bins]
        idx //= n_bins
    return u  # [U,R,D,L]

@dataclass
class PlateConfig:
    dt: float = DT
    substeps: int = 1
    max_steps: int = 3000
    tau_max_x: float = 0.2
    tau_max_y: float = 0.2

class TiltPlateEnv(Env):
    metadata = {"render_modes": []}
    def __init__(self, env_config=None):
        # EnvContext → dict
        try:
            from ray.rllib.env.env_context import EnvContext
            if isinstance(env_config, EnvContext):
                env_config = dict(env_config)
        except Exception:
            pass
        if env_config is None:
            env_config = {}
        self.cfg = PlateConfig(
            dt       = env_config.get("dt", DT),
            substeps = env_config.get("substeps", 1),
            max_steps= env_config.get("max_steps", 3000),
            tau_max_x= env_config.get("tau_max_x", 0.2),
            tau_max_y= env_config.get("tau_max_y", 0.2),
        )
        # MuJoCo model/data
        self.model = mj.MjModel.from_xml_string(xml)
        self.data  = mj.MjData(self.model)
        # delays
        self.n_sites = len(sites)
        self.site_xy = [(x,y) for (x,y,_,_) in sites]
        self.ks = np.array([manhattan_delay_steps(x,y) for (x,y) in self.site_xy], dtype=np.int32)
        self.delay = DelayLine(int(self.ks.max()), self.n_sites)
        # action/obs
        self.action_space = spaces.Discrete(N_BINS**4)  # 16
        high = np.full((8 + self.n_sites,), np.inf, dtype=np.float32)
        self.observation_space = spaces.Box(-high, high, dtype=np.float32)
        self.step_count = 0
        self._reset_ball()

    def _reset_ball(self):
        self.data.qpos[:] = 0.0; self.data.qvel[:] = 0.0
        self.data.qpos[0] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[1] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[2] = BALL_R + 0.005
        self.delay = DelayLine(int(self.ks.max()), self.n_sites)

    def _sense_now(self):
        bx, by = self.data.qpos[0], self.data.qpos[1]
        sig = []
        for (x,y) in self.site_xy:
            d = abs(bx-x) + abs(by-y)
            sig.append(math.exp(-d/0.05))
        return np.array(sig, dtype=np.float32)

    def _obs(self):
        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        jx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_x")
        jy = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_y")
        hx = self.data.qpos[self.model.jnt_qposadr[jx]]
        hy = self.data.qpos[self.model.jnt_qposadr[jy]]
        hvx = self.data.qvel[self.model.jnt_dofadr[jx]]
        hvy = self.data.qvel[self.model.jnt_dofadr[jy]]
        delayed = self.delay.read_delayed(self.ks)
        return np.concatenate([[bx,by,vx,vy,hx,hy,hvx,hvy], delayed]).astype(np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._reset_ball()
        self.step_count = 0
        self.delay.push(np.zeros(self.n_sites, dtype=np.float32))
        return self._obs(), {}

    def step(self, action):
        u4 = idx_to_u4(int(action))  # {0,1}^4
        tau_x, tau_y = edge4_to_torque2(u4, max_tau=(self.cfg.tau_max_x, self.cfg.tau_max_y))
        self.data.ctrl[0] = tau_x; self.data.ctrl[1] = tau_y
        # push current sensors, step physics
        self.delay.push(self._sense_now())
        for _ in range(self.cfg.substeps):
            mj.mj_step(self.model, self.data)
        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        r = math.hypot(bx, by)
        speed = math.hypot(vx, vy)
        reward = -r - 0.1*speed - 0.001*float(np.sum(u4*u4))  # 이완 유도
        self.step_count += 1
        terminated = self.step_count >= self.cfg.max_steps
        truncated = False
        return self._obs(), reward, terminated, truncated, {}

# ==== 4) SNN ↔ Env 어댑터 (CerebellarNet 사용) ====
def encode_mf_spikes(obs, rate_scale=60.0, dt=DT):
    x = np.asarray(obs, dtype=np.float32)
    rng = float(np.ptp(x))
    x = (x - float(x.min())) / (rng + 1e-6) if rng >= 1e-12 else np.zeros_like(x, dtype=np.float32)
    lam = x * rate_scale * dt
    return (np.random.rand(x.shape[0]) < lam).astype(np.float32)

def motor_spikes_to_u4(mo_spk):
    return mo_spk.astype(np.float32)  # {0,1}^4

def cf_from_transition(prev_r, curr_r, k_pos=+1.0, k_neg=-1.0):
    dr = curr_r - prev_r
    if dr < 0:  return k_pos
    if dr > 0:  return k_neg
    return 0.0

def train_one_episode_snn(env: TiltPlateEnv, cereb, steps=2000, rate_scale=60.0):
    obs, _ = env.reset()
    cereb.reset()
    bx, by = env.data.qpos[0], env.data.qpos[1]
    prev_r = float(math.hypot(bx, by))
    ep_ret = 0.0
    for t in range(steps):
        # forward with MF spikes
        mf_spk = encode_mf_spikes(obs, rate_scale=rate_scale, dt=DT)
        mo_spk = cereb.forward(mf_spk, cf_spikes=None)   # (4,) binary
        # apply action
        u4 = motor_spikes_to_u4(mo_spk)
        tau_x, tau_y = edge4_to_torque2(u4, max_tau=(env.cfg.tau_max_x, env.cfg.tau_max_y))
        env.data.ctrl[0] = tau_x; env.data.ctrl[1] = tau_y
        # physics step (noop action since ctrl already set)
        obs, reward, term, trunc, _ = env.step(0)
        # CF and learning
        bx, by = env.data.qpos[0], env.data.qpos[1]
        curr_r = float(math.hypot(bx, by))
        cf = cf_from_transition(prev_r, curr_r, k_pos=+1.0, k_neg=-1.0)
        cereb.learn(cf_signal=cf)
        prev_r = curr_r
        ep_ret += reward
        if term or trunc:
            break
    return ep_ret, t+1

print("Model OK. Sites:", len(sites), "Max delay steps:", int(np.max([manhattan_delay_steps(x,y) for (x,y,_,_) in sites])))


Model OK. Sites: 100 Max delay steps: 60


In [7]:
# ==== SNN 단독 트레이닝 실행 ====
# (네 CerebellarNet 클래스 정의가 끝난 뒤)
cfg_sn = SimulationConfig(  # 네가 쓰는 dataclass에 맞춰 값 전달
    n_mf=108, n_grc=512, n_pkj=8, n_motor=4,
    tau_pre_ms=5.0, tau_post_ms=5.0, tau_e_ms=50.0,
    dt_ms=0.1, eta=1e-3, max_dw_per_step=1e-3, scale_every=100,
    use_mf_collat=True, mf_collat_gain=0.3
)
snn = CerebellarNet(cfg_sn, device=None)

env_sn = TiltPlateEnv(env_config={"tau_max_x":0.2, "tau_max_y":0.2})
ret, nstep = train_one_episode_snn(env_sn, snn, steps=2000, rate_scale=60.0)
print(f"SNN-only episode return={ret:.3f}, steps={nstep}")


SNN-only episode return=-88.428, steps=2000


In [8]:
!pip install gputil



In [9]:
import ray
from ray.rllib.algorithms.dqn import DQNConfig

ray.shutdown(); ray.init(ignore_reinit_error=True, include_dashboard=False)

cfg = (
    DQNConfig()
    .environment(
        env=TiltPlateEnv,            # 클래스로!
        env_config={"tau_max_x":0.2, "tau_max_y":0.2}
    )
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=True)
    .env_runners(num_env_runners=0) # Colab 단일 프로세스
)

algo = cfg.build()
res = algo.train()
print({
    "len_mean": res.get("env_runners", {}).get("episode_len_mean"),
    "ret_mean": res.get("env_runners", {}).get("episode_return_mean"),
})
algo.stop(); ray.shutdown()


2025-09-22 02:36:24,342	INFO worker.py:1951 -- Started a local Ray instance.
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


{'len_mean': None, 'ret_mean': None}


In [10]:
# ==== 0) Imports & constants ====
import math, numpy as np
import mujoco as mj
from gymnasium import Env, spaces
from dataclasses import dataclass
from typing import Optional

DT = 1e-4   # 0.1 ms
V  = 50.0   # m/s
PLATE_HALF = 0.15
SITES_N = 10
BALL_R  = 0.015
BALL_M  = 0.03

# ==== 1) Delay + grid ====
class DelayLine:
    def __init__(self, max_steps, n_channels):
        self.buf = np.zeros((max_steps+1, n_channels), dtype=np.float32)
        self.ptr = 0
        self.max_steps = max_steps
        self.n = n_channels
    def push(self, x_t):
        self.buf[self.ptr, :] = x_t
        self.ptr = (self.ptr + 1) % (self.max_steps+1)
    def read_delayed(self, ks):
        idx = (self.ptr - 1 - ks) % (self.max_steps+1)
        return self.buf[idx, np.arange(self.n)]

def manhattan_delay_steps(x, y, dt=DT, v=V):
    return int(np.ceil((abs(x)+abs(y)) / (v*dt)))

def make_site_grid(n=SITES_N, half=PLATE_HALF, z=0.006):
    xs = np.linspace(-half, half, n); ys = np.linspace(-half, half, n)
    return [(x, y, z, f"sn_i{i}_j{j}") for i,y in enumerate(ys) for j,x in enumerate(xs)]

def site_xml_lines(pts):
    return "\n      ".join(
        f'<site name="{name}" pos="{x:.4f} {y:.4f} {z:.4f}" size="0.002" type="sphere" rgba="0.2 0.8 0.2 0.5"/>'
        for (x,y,z,name) in pts
    )

# 4-edge -> 2-hinge torque
def edge4_to_torque2(u4, w=(1.,1.,1.,1.), max_tau=(0.2,0.2)):
    uU, uR, uD, uL = u4; wU, wR, wD, wL = w
    tau_y = (wU*uU - wD*uD) * max_tau[1]
    tau_x = (wR*uR - wL*uL) * max_tau[0]
    return float(np.clip(tau_x, -max_tau[0], max_tau[0])), float(np.clip(tau_y, -max_tau[1], max_tau[1]))

# ==== 2) MJCF XML ====
sites = make_site_grid()
xml = f"""
<mujoco model="tilt_plate">
  <compiler angle="degree" inertiafromgeom="true"/>
  <option timestep="{DT:.7f}" gravity="0 0 -9.81" integrator="RK4"/>
  <default>
    <geom  condim="6" margin="0.001" solimp="0.9 0.95 0.001" solref="0.002 1"/>
    <default class="plate">
      <geom type="box" friction="0.8 0.003 0.001" rgba="0.8 0.8 0.85 1"/>
    </default>
    <default class="ball">
      <geom type="sphere" friction="0.9 0.005 0.002" rgba="0.9 0.3 0.3 1"/>
    </default>
    <joint armature="0.002" damping="0.1" limited="true"/>
    <motor gear="1.0" ctrllimited="true" ctrlrange="-1.0 1.0"/>
  </default>
  <worldbody>
    <body name="plate_base" pos="0 0 0">
      <joint name="hinge_x" type="hinge" axis="1 0 0" range="-5 5"/>
      <joint name="hinge_y" type="hinge" axis="0 1 0" range="-5 5"/>
      <geom name="plate_geom" class="plate" size="{PLATE_HALF} {PLATE_HALF} 0.005" mass="1.0"/>
      {site_xml_lines(sites)}
    </body>
    <body name="ball" pos="0 0 {BALL_R+0.01:.4f}">
      <freejoint name="ball_free"/>
      <geom name="ball_geom" class="ball" size="{BALL_R}" mass="{BALL_M}"/>
    </body>
  </worldbody>
  <actuator>
    <motor name="mx" joint="hinge_x" gear="1"/>
    <motor name="my" joint="hinge_y" gear="1"/>
  </actuator>
</mujoco>
""".strip()

# ==== 3) Discrete actions {0,1}^4 ====
N_BINS = 2
BINS   = np.array([0.0, 1.0], dtype=np.float32)
def idx_to_u4(idx, n_bins=N_BINS, bins=BINS):
    u = np.empty(4, dtype=np.float32)
    for k in range(4):
        u[k] = bins[idx % n_bins]; idx //= n_bins
    return u  # [U,R,D,L]

@dataclass
class PlateConfig:
    dt: float = DT
    substeps: int = 1
    max_steps: int = 3000
    tau_max_x: float = 0.2
    tau_max_y: float = 0.2

# ==== 4) SNN 유틸 ====
def encode_mf_spikes(obs, rate_scale=60.0, dt=DT):
    x = np.asarray(obs, dtype=np.float32)
    rng = float(np.ptp(x))
    x = (x - float(x.min())) / (rng + 1e-6) if rng >= 1e-12 else np.zeros_like(x, dtype=np.float32)
    lam = x * rate_scale * dt
    return (np.random.rand(x.shape[0]) < lam).astype(np.float32)

def cf_from_transition(prev_r, curr_r, k_pos=+1.0, k_neg=-1.0):
    dr = curr_r - prev_r
    return k_pos if dr < 0 else (k_neg if dr > 0 else 0.0)

# ==== 5) Env with internal CerebellarNet (SNN 임베딩 반환) ====
class TiltPlateSNNDQNEnv(Env):
    metadata = {"render_modes": []}
    def __init__(self, env_config=None):
        # EnvContext -> dict (RLlib 호환)
        try:
            from ray.rllib.env.env_context import EnvContext
            if isinstance(env_config, EnvContext):
                env_config = dict(env_config)
        except Exception:
            pass
        env_config = env_config or {}
        self.cfg = PlateConfig(
            dt       = env_config.get("dt", DT),
            substeps = env_config.get("substeps", 1),
            max_steps= env_config.get("max_steps", 3000),
            tau_max_x= env_config.get("tau_max_x", 0.2),
            tau_max_y= env_config.get("tau_max_y", 0.2),
        )

        # MuJoCo
        self.model = mj.MjModel.from_xml_string(xml)
        self.data  = mj.MjData(self.model)
        self.site_xy = [(x,y) for (x,y,_,_) in sites]
        self.ks = np.array([manhattan_delay_steps(x,y) for (x,y) in self.site_xy], dtype=np.int32)
        self.delay = DelayLine(int(self.ks.max()), len(self.site_xy))

        # --- CerebellarNet 내부 생성 ---
        # SimulationConfig는 네가 정의한 dataclass 사용
        self.sn_cfg = SimulationConfig(
            n_mf=108, n_grc=512, n_pkj=8, n_motor=4,
            tau_pre_ms=5.0, tau_post_ms=5.0, tau_e_ms=50.0,
            dt_ms=0.1, eta=1e-3, max_dw_per_step=1e-3, scale_every=100,
            use_mf_collat=True, mf_collat_gain=0.3
        )
        self.snn = CerebellarNet(self.sn_cfg)
        self.prev_r = None

        # 관측을 SNN 임베딩 z로 축소 (여기선 간단히 motor 스파이크 4개 + PKJ 스파이크 8개 = 12차원)
        self.z_dim = 12
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(self.z_dim,), dtype=np.float32)

        # 액션은 DQN용 이산 16개
        self.action_space = spaces.Discrete(N_BINS**4)
        self.step_count = 0
        self._reset_ball()

    # ---- 내부: raw obs 생성 (기존과 동일) ----
    def _obs_raw(self):
        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        jx = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_x")
        jy = mj.mj_name2id(self.model, mj.mjtObj.mjOBJ_JOINT, "hinge_y")
        hx = self.data.qpos[self.model.jnt_qposadr[jx]]
        hy = self.data.qpos[self.model.jnt_qposadr[jy]]
        hvx = self.data.qvel[self.model.jnt_dofadr[jx]]
        hvy = self.data.qvel[self.model.jnt_dofadr[jy]]
        delayed = self.delay.read_delayed(self.ks)
        return np.concatenate([[bx,by,vx,vy,hx,hy,hvx,hvy], delayed]).astype(np.float32)

    # ---- SNN 임베딩: obs_raw -> spikes -> forward -> z ----
    def _embed_with_snn(self, obs_raw):
        mf_spk = encode_mf_spikes(obs_raw, rate_scale=60.0, dt=DT)
        mo_spk = self.snn.forward(mf_spk, cf_spikes=None)    # (4,) binary
        # PKJ 스파이크는 캐시에 있음
        pkj_spk = self.snn._last['pkj']
        z = np.zeros(self.z_dim, dtype=np.float32)
        z[:4] = mo_spk
        z[4:12] = pkj_spk[:8] if pkj_spk.shape[0] >= 8 else np.pad(pkj_spk, (0, 8-pkj_spk.shape[0]))
        return z, mo_spk  # z: 관측, mo_spk: 내부 행동 생성용

    def _reset_ball(self):
        self.data.qpos[:] = 0.0; self.data.qvel[:] = 0.0
        self.data.qpos[0] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[1] = np.random.uniform(-0.03, 0.03)
        self.data.qpos[2] = BALL_R + 0.005
        self.delay = DelayLine(int(self.ks.max()), len(self.site_xy))
        self.snn.reset()
        bx, by = self.data.qpos[0], self.data.qpos[1]
        self.prev_r = float(math.hypot(bx, by))

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._reset_ball()
        self.step_count = 0
        self.delay.push(np.zeros(len(self.site_xy), dtype=np.float32))
        obs_raw = self._obs_raw()
        z, _ = self._embed_with_snn(obs_raw)  # 초기 z
        return z, {}

    def step(self, action):
        # 1) SNN이 제안한 모터 스파이크 기반 u4를 기본으로 하되,
        #    DQN의 action(=16-way)을 "게이트"로 사용: action이 0이면 SNN u4 그대로,
        #    그 외에는 idx_to_u4(action)로 override (간단한 결합 전략; 필요 시 바꿔도 됨)
        obs_raw = self._obs_raw()
        z, mo_spk = self._embed_with_snn(obs_raw)
        u4_snn = mo_spk.astype(np.float32)
        u4_dqn = idx_to_u4(int(action))
        u4 = u4_snn if int(action) == 0 else u4_dqn  # policy-mix 예시

        tau_x, tau_y = edge4_to_torque2(u4, max_tau=(self.cfg.tau_max_x, self.cfg.tau_max_y))
        self.data.ctrl[0] = tau_x; self.data.ctrl[1] = tau_y

        # 센서 push + 물리 스텝
        self.delay.push(self._sense_now())
        for _ in range(self.cfg.substeps):
            mj.mj_step(self.model, self.data)

        # 보상/종료
        bx, by = self.data.qpos[0], self.data.qpos[1]
        vx, vy = self.data.qvel[0], self.data.qvel[1]
        r = math.hypot(bx, by)
        speed = math.hypot(vx, vy)
        reward = -r - 0.1*speed - 0.001*float(np.sum(u4*u4))

        # CF 학습
        cf = cf_from_transition(self.prev_r, r, k_pos=+1.0, k_neg=-1.0)
        self.snn.learn(cf_signal=cf)
        self.prev_r = r

        # 다음 관측(z) 업데이트
        obs_raw2 = self._obs_raw()
        z_next, _ = self._embed_with_snn(obs_raw2)

        self.step_count += 1
        terminated = self.step_count >= self.cfg.max_steps
        truncated = False
        return z_next, reward, terminated, truncated, {}

    def _sense_now(self):
        bx, by = self.data.qpos[0], self.data.qpos[1]
        return np.array([math.exp(- (abs(bx-x)+abs(by-y))/0.05) for (x,y) in self.site_xy], dtype=np.float32)

# ==== 6) 빠른 sanity check ====
env = TiltPlateSNNDQNEnv()
z, _ = env.reset(seed=42)
print("z_dim:", z.shape)
for _ in range(5):
    z, r, term, trunc, _ = env.step(0)  # action=0이면 SNN u4 사용
print("step OK; last reward:", r)


z_dim: (12,)
step OK; last reward: -0.04558527545711828


In [11]:
import ray
from ray.rllib.algorithms.dqn import DQNConfig

ray.shutdown(); ray.init(ignore_reinit_error=True, include_dashboard=False)

cfg = (
    DQNConfig()
    .environment(
        env=TiltPlateSNNDQNEnv,           # ← 클래스로!
        env_config={"tau_max_x":0.2, "tau_max_y":0.2}
    )
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=True)
    .env_runners(num_env_runners=4, num_gpus_per_env_runner=0.25)      # ← rollouts(...) 대체
)

algo = cfg.build()
res = algo.train()
print({
    "len_mean": res.get("env_runners", {}).get("episode_len_mean"),
    "ret_mean": res.get("env_runners", {}).get("episode_return_mean"),
})
algo.stop(); ray.shutdown()


2025-09-22 02:36:56,004	INFO worker.py:1951 -- Started a local Ray instance.
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


{'len_mean': None, 'ret_mean': None}


In [12]:
# === B. RLlib DQN 학습 ===
# 전제: TiltPlateSNNDQNEnv 가 정의되어 있음 (SNN 내장, 관측 z 반환, Discrete(16) 액션)
!pip -q install -U "ray[rllib]==2.49.2"

import ray
from ray.rllib.algorithms.dqn import DQNConfig

ray.shutdown()
ray.init(ignore_reinit_error=True, include_dashboard=False)

# 환경을 "클래스"로 넘기고 env_runners 사용 (최신 API)
cfg = (
    DQNConfig()
    .environment(
        env=TiltPlateSNNDQNEnv,                 # 클래스로 전달 (인스턴스 금지)  :contentReference[oaicite:4]{index=4}
        env_config={"tau_max_x":0.2, "tau_max_y":0.2}
    )
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=True)
    .env_runners(
        num_env_runners=0,                      # Colab 단일 프로세스
        # rollout_fragment_length 등 필요시 여기서 조정  :contentReference[oaicite:5]{index=5}
    )
    # (옵션) 학습 하이퍼파라미터 일부
    .training(
        gamma=0.99,
        lr=1e-3,
        train_batch_size=512,
        replay_buffer_config={"capacity": 50000},
        n_step=1,
        target_network_update_freq=500,
    )
)

algo = cfg.build()

# 짧게 몇 iteration만 데모 학습
for i in range(3):
    res = algo.train()
    print(f"[Iter {i+1}] len_mean={res.get('env_runners', {}).get('episode_len_mean')}, "
          f"ret_mean={res.get('env_runners', {}).get('episode_return_mean')}")

# 체크포인트 저장
ckpt = algo.save()
print("Saved checkpoint:", ckpt)
algo.stop()
ray.shutdown()


2025-09-22 02:37:31,472	INFO worker.py:1951 -- Started a local Ray instance.
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


[Iter 1] len_mean=None, ret_mean=None


KeyboardInterrupt: 

In [None]:
import torch, os
print("CUDA available:", torch.cuda.is_available())
print("Visible GPUs:", os.environ.get("CUDA_VISIBLE_DEVICES", "all"))


In [None]:
from ray.rllib.algorithms.dqn import DQNConfig

cfg = (
    DQNConfig()
    .environment(env=TiltPlateSNNDQNEnv,
    env_config={"tau_max_x":0.2, "tau_max_y":0.2, "max_steps":500})
    .framework("torch")
    .api_stack(enable_rl_module_and_learner=True)
    .resources(num_gpus=1)  # <- 핵심
    .env_runners(
        num_env_runners=0,                 # Colab 단일 프로세스
        rollout_fragment_length=2000,      # ← 조각 길이 상향
    )
    .training(
        train_batch_size=2000,             # 조각 합쳐 도달
    )
    .evaluation(
        evaluation_interval=1,             # ← 매 iter 평가
        evaluation_duration=3,             # 평가 에피소드 수
        evaluation_duration_unit="episodes"
    )
)
print("done config")
algo = cfg.build()
for i in range(100):
    res = algo.train()
    print(f"[Iter {i+1}] len_mean={res.get('env_runners', {}).get('episode_len_mean')}, "
          f"ret_mean={res.get('env_runners', {}).get('episode_return_mean')}")

    print("eval :", res.get("evaluation",{}).get("env_runners",{}).get("episode_len_mean"),
                res.get("evaluation",{}).get("env_runners",{}).get("episode_return_mean"))
# 체크포인트 저장
ckpt = algo.save()
print("Saved checkpoint:", ckpt)
algo.stop()
ray.shutdown()

done config


  logger.warn(f"Overriding environment {new_spec.id} already in registry.")
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


[Iter 1] len_mean=1000.0, ret_mean=-54.680862076513634
eval : 1000.0 -133.98082907302958
[Iter 2] len_mean=1000.0, ret_mean=-59.28181991454788
eval : 1000.0 -88.80728995569656
[Iter 3] len_mean=1000.0, ret_mean=-55.02706563014721
eval : 1000.0 -115.33785168735693
[Iter 4] len_mean=1000.0, ret_mean=-64.15851102717451
eval : 1000.0 -125.59594962859923
[Iter 5] len_mean=1000.0, ret_mean=-77.50992528969081
eval : 1000.0 -124.0207069650441
[Iter 6] len_mean=1000.0, ret_mean=-83.34871428214757
eval : 1000.0 -129.64167613052365
[Iter 7] len_mean=1000.0, ret_mean=-91.4293717612594
eval : 1000.0 -135.68638617273302
[Iter 8] len_mean=1000.0, ret_mean=-98.05232661974252
eval : 1000.0 -134.4524216608119
[Iter 9] len_mean=1000.0, ret_mean=-100.57023904651383
eval : 1000.0 -132.40050855945287
[Iter 10] len_mean=1000.0, ret_mean=-101.85493819882281
eval : 1000.0 -130.72013573421626
[Iter 11] len_mean=1000.0, ret_mean=-102.57004056947343
eval : 1000.0 -130.3519952872602
[Iter 12] len_mean=1000.0, ret_