# In [1]:
from __future__ import annotations
import numpy as np
import gymnasium as gym
from gymnasium import spaces

class ACCEnv(gym.Env):
    """
    1D ACC environment with a headway-based CBF-style safety filter.

    Observation s = [Δx, Δv, v]
      Δx: lead - ego headway (m)
      Δv: lead - ego relative speed (m/s)
      v : ego speed (m/s)

    Action: the requested ego acceleration. It is clipped to [a_min, a_max]
    and then upper-bounded by the safety filter before being applied.

    If normalize_obs=True, observations are mapped to [-1, 1] with fixed ranges:
      Δx ∈ [0, 200], Δv ∈ [-20, 20], v ∈ [0, 30]
    and clipped to [-obs_clip, obs_clip].
    Attack budgets (epsilon) are interpreted in this normalized space.
    """
    # Gymnasium reads the "render_modes" key; this environment has no renderer.
    metadata = {"render_modes": []}

    def __init__(
        self,
        dt: float = 0.1,
        v_ref: float = 15.0,
        a_min: float = -3.5,
        a_max: float = 2.0,
        Th: float = 1.5,
        d0: float = 5.0,
        w_v: float = 0.5,
        w_s: float = 2.0,
        w_a: float = 0.01,
        lead_v0: float = 15.0,
        brake_profile: bool = False,
        brake_start_s: float = 5.0,
        brake_dur_s: float = 3.0,
        lead_decel: float = -2.0,
        episode_seconds: float = 20.0,
        seed: int | None = None,
        normalize_obs: bool = True,
        obs_clip: float = 1.0,
    ) -> None:
        """
        Build the environment and put it in a freshly reset state.

        Args:
            dt: Integration time step.
            v_ref: Reference ego speed used by the speed-tracking reward term.
            a_min: Lower acceleration limit (also the action-space lower bound).
            a_max: Upper acceleration limit (also the action-space upper bound).
            Th: Time headway used by the safety filter and by the desired
                gap ``d0 + Th * v`` in the reward.
            d0: Constant offset of the desired gap in the reward.
            w_v: Weight of the squared speed-tracking error in the reward.
            w_s: Weight of the squared desired-gap violation in the reward.
            w_a: Weight of the squared applied acceleration in the reward.
            lead_v0: Lead-vehicle speed at reset.
            brake_profile: If True, the lead vehicle accelerates by
                ``lead_decel`` while ``brake_start_s <= t < brake_start_s + brake_dur_s``.
            brake_start_s: Start time of the lead braking window.
            brake_dur_s: Duration of the lead braking window.
            lead_decel: Lead acceleration applied inside the braking window.
            episode_seconds: Episode length; the episode is truncated after
                ``episode_seconds / dt`` steps.
            seed: Optional seed for the environment's random generator.
            normalize_obs: If True, observations are normalized and clipped.
            obs_clip: Symmetric clipping bound for normalized observations.
        """
        super().__init__()
        self.dt = dt
        self.v_ref = v_ref
        self.a_min = a_min
        self.a_max = a_max
        self.Th = Th
        self.d0 = d0
        self.w_v = w_v
        self.w_s = w_s
        self.w_a = w_a
        self.lead_v0 = lead_v0
        self.brake_profile = brake_profile
        self.brake_start_s = brake_start_s
        self.brake_dur_s = brake_dur_s
        self.lead_decel = lead_decel
        # The small tolerance keeps a ratio that lands just below an integer
        # because of float division (e.g. 0.3 / 0.1) from losing a whole step.
        self.episode_steps = int(episode_seconds / dt + 1e-9)
        self.normalize_obs = normalize_obs
        self.obs_clip = obs_clip

        # Normalization ranges
        self._x_range = (0.0, 200.0)
        self._dv_range = (-20.0, 20.0)
        self._v_range = (0.0, 30.0)

        # Observation space: normalized observations are clipped to ±obs_clip
        # in _norm, so the declared bounds must use the same value; raw
        # observations are unbounded.
        bound_value = self.obs_clip if normalize_obs else np.inf
        bound = np.full(3, bound_value, dtype=np.float32)
        self.observation_space = spaces.Box(low=-bound, high=bound, dtype=np.float32)
        self.action_space = spaces.Box(
            low=np.array([self.a_min], dtype=np.float32),
            high=np.array([self.a_max], dtype=np.float32),
            dtype=np.float32,
        )

        # Safety-filter observation override (consumed by the next step() call)
        self._safety_obs_override: np.ndarray | None = None

        # Seeds the generator (when a seed is given) and draws the initial state.
        self.reset(seed=seed)

    # ---------- helper: normalization ----------
    def _norm(self, s_raw: np.ndarray) -> np.ndarray:
        """Map a raw state [Δx, Δv, v] to the normalized, clipped observation."""
        Δx, Δv, v = s_raw

        def _scale(val, lo, hi):
            # Affine map [lo, hi] -> [-1, 1]; the epsilon guards a zero-width range.
            z = (val - lo) / (hi - lo + 1e-8)
            return np.clip(2.0 * z - 1.0, -self.obs_clip, self.obs_clip)

        return np.array([
            _scale(Δx, *self._x_range),
            _scale(Δv, *self._dv_range),
            _scale(v,  *self._v_range),
        ], dtype=np.float32)

    def _denorm(self, s_norm: np.ndarray) -> np.ndarray:
        """Map a normalized observation back to raw units (inverse of the affine part of _norm)."""
        def _inv(vn, lo, hi):
            return (vn + 1.0) * 0.5 * (hi - lo) + lo

        Δx = _inv(s_norm[0], *self._x_range)
        Δv = _inv(s_norm[1], *self._dv_range)
        v  = _inv(s_norm[2], *self._v_range)
        return np.array([Δx, Δv, v], dtype=np.float32)

    # ---------- public hook: tell safety filter which observation to use ----------
    def set_safety_obs_for_filter(self, obs_norm_or_raw: np.ndarray) -> None:
        """
        Provide the observation (normalized if self.normalize_obs=True, else raw)
        that the safety filter should use in the next call to step().

        The value is copied, consumed by that single step() call and then
        cleared; reset() also clears it.
        """
        self._safety_obs_override = np.array(obs_norm_or_raw, dtype=np.float32)

    # ---------- safety filter ----------
    def _amax_safe(self, Δx: float, Δv: float, v: float) -> float:
        """Return the largest acceleration the headway constraint allows for this state."""
        # Eq.(5): a_max_safe = (Δx - Th*v + Δv*dt) / (Th*dt)
        # The epsilon keeps the division defined if Th*dt is zero.
        return (Δx - self.Th * v + Δv * self.dt) / (self.Th * self.dt + 1e-8)

    def _apply_safety(self, a_rl: float, Δx: float, Δv: float, v: float) -> float:
        """Cap a_rl at the safe maximum, then clip the result to [a_min, a_max]."""
        a_safe_max = self._amax_safe(Δx, Δv, v)
        a_clamped = min(a_rl, a_safe_max)
        return float(np.clip(a_clamped, self.a_min, self.a_max))

    # Optional alias to match paper naming
    def safety_filter(self, state_raw: np.ndarray, action: float) -> float:
        """Apply the safety filter to `action` for a raw (unnormalized) state [Δx, Δv, v]."""
        Δx, Δv, v = state_raw
        return self._apply_safety(float(action), float(Δx), float(Δv), float(v))

    # ---------- lead car ----------
    def _lead_step(self):
        """Advance the lead vehicle by one time step."""
        t = self._t * self.dt
        if self.brake_profile and (self.brake_start_s <= t < self.brake_start_s + self.brake_dur_s):
            # Inside the braking window; the lead speed never goes below zero.
            self.v_l = max(self.v_l + self.lead_decel * self.dt, 0.0)
        self.x_l = self.x_l + self.v_l * self.dt

    # ---------- obs / reward ----------
    def _get_obs(self) -> np.ndarray:
        """Return the current observation (normalized when normalize_obs is set)."""
        Δx = self.x_l - self.x_e
        Δv = self.v_l - self.v_e
        s_raw = np.array([Δx, Δv, self.v_e], dtype=np.float32)
        return self._norm(s_raw) if self.normalize_obs else s_raw

    def _reward(self) -> float:
        """
        Return the reward for the current state.

        Sum of three non-positive terms: squared speed-tracking error,
        squared shortfall of the headway below d0 + Th*v, and squared
        last applied acceleration, weighted by w_v, w_s and w_a.
        """
        Δx = self.x_l - self.x_e
        v  = self.v_e
        d_safe = self.d0 + self.Th * v
        r_speed = - self.w_v * (v - self.v_ref)**2
        r_safe  = - self.w_s * max(0.0, d_safe - Δx)**2
        r_act   = - self.w_a * (self.a_prev**2)
        return float(r_speed + r_safe + r_act)

    # ---------- Gym API ----------
    def reset(self, *, seed: int | None = None, options=None):
        """
        Start a new episode.

        Args:
            seed: If given, reseeds the environment's random generator.
            options: Accepted for API compatibility; not used.

        Returns:
            (obs, info): the initial observation and an empty info dict.
        """
        # The base class reseeds self.np_random when a seed is supplied.
        super().reset(seed=seed)
        # Ego near target speed; headway ~30-50 m
        self.x_e = 0.0
        self.v_e = float(np.clip(self.v_ref + self.np_random.normal(0, 0.5), 0.0, self._v_range[1]))
        self.a_prev = 0.0
        self.x_l = float(self.np_random.uniform(30.0, 50.0))
        self.v_l = self.lead_v0
        self._t = 0
        self._collision = False
        self._safety_obs_override = None
        obs = self._get_obs()
        info = {}
        return obs, info

    def step(self, action):
        """
        Advance the simulation by one time step.

        Args:
            action: Requested ego acceleration; a scalar or an array-like
                whose first element is used.

        Returns:
            (obs, reward, terminated, truncated, info). `terminated` is True
            once the headway is non-positive (collision); `truncated` is True
            once `episode_steps` steps have elapsed.
        """
        # Proposed acceleration from RL policy (raw units). Flattening first
        # lets plain scalars and 0-d arrays through as well as 1-element arrays.
        a_req = float(np.asarray(action, dtype=np.float64).reshape(-1)[0])
        a_rl = float(np.clip(a_req, self.a_min, self.a_max))

        # Use the attacked/observed state for safety if provided; else regular obs.
        s_used = self._safety_obs_override if self._safety_obs_override is not None else self._get_obs()
        self._safety_obs_override = None  # consume once

        s_used_raw = self._denorm(s_used) if self.normalize_obs else s_used
        Δx_used, Δv_used, v_used = float(s_used_raw[0]), float(s_used_raw[1]), float(s_used_raw[2])

        # Safety clamp
        a = self._apply_safety(a_rl, Δx_used, Δv_used, v_used)

        # Ego dynamics: position advances with the pre-update speed; speed is
        # floored at zero (no reversing).
        self.x_e = self.x_e + self.v_e * self.dt
        self.v_e = max(self.v_e + a * self.dt, 0.0)
        self.a_prev = a

        # Lead dynamics
        self._lead_step()

        # Collision check
        if (self.x_l - self.x_e) <= 0.0:
            self._collision = True

        obs = self._get_obs()
        reward = self._reward()
        self._t += 1
        terminated = self._collision
        truncated = self._t >= self.episode_steps
        info = {"collision": self._collision, "a": a, "v": self.v_e, "Δx": self.x_l - self.x_e}
        return obs, reward, terminated, truncated, info
