In [None]:
from __future__ import annotations
import torch
import numpy as np

def _to_tensor(x: np.ndarray) -> torch.Tensor:
    return torch.as_tensor(x, dtype=torch.float32)

class AttackWrapper:
    """Base class: wraps an SB3-style model to perturb observations before acting.

    Subclasses override :meth:`perturb`; this base implementation is the
    identity (no attack). Observations are assumed to be already normalized
    to [-1, 1].

    Args:
        model: Object exposing ``predict(obs, deterministic=...)`` returning
            ``(action, state)``; gradient-based subclasses also read
            ``model.policy``.
        epsilon: Per-element perturbation budget (stored as ``self.eps``).
        device: Torch device on which the perturbation is computed.
    """

    def __init__(self, model, epsilon: float = 0.01, device: str = "cpu") -> None:
        self.model = model
        self.eps = epsilon
        self.device = device

    def perturb(self, obs: np.ndarray) -> np.ndarray:
        """Return the observation to feed the policy (identity in the base class)."""
        return obs

    def act(self, obs: np.ndarray):
        """Perturb ``obs`` and pick a deterministic action on the perturbed input.

        Returns:
            ``(action, obs_adv)``: the model's deterministic action for the
            perturbed observation, and the perturbed observation itself.
        """
        # perturb() must run with autograd enabled: gradient-based attacks
        # differentiate the policy w.r.t. the observation. Wrapping the whole
        # method in torch.no_grad() (as before) made their backward() fail.
        obs_adv = self.perturb(obs)
        # Only the action query itself needs no gradient tracking.
        with torch.no_grad():
            action, _ = self.model.predict(obs_adv, deterministic=True)
        return action, obs_adv

class FGSMAttack(AttackWrapper):
    """FGSM attack on the policy's mean action output.

    Takes one signed-gradient step of size ``eps`` on the observation.

    - Objective: increase ``sum(mean_actions ** 2)``, the squared magnitude
      of the pre-squashing mean action.
    - Output: the stepped observation, clamped to [-1, 1].
    - Forward path: ``model.policy.extract_features`` ->
      ``mlp_extractor`` -> ``action_net``.
    """

    def perturb(self, obs: np.ndarray) -> np.ndarray:
        """Return an FGSM-perturbed copy of ``obs``.

        Args:
            obs: A single observation (1-D) or a batch of observations.

        Returns:
            Float32 array with the same shape as ``obs``. Each element is
            shifted by ``eps * sign(grad)`` and then clamped to [-1, 1].
        """
        policy = self.model.policy
        was_training = policy.training
        # Gradients flow in eval mode too; training mode would only switch on
        # dropout / batch-norm statistic updates, which must not leak here.
        policy.set_training_mode(False)
        try:
            # enable_grad: callers may invoke this under torch.no_grad().
            with torch.enable_grad():
                obs_t = _to_tensor(obs).to(self.device)
                if obs_t.ndim == 1:
                    obs_t = obs_t.unsqueeze(0)
                obs_t.requires_grad_(True)

                # Forward to get mean action (pre squashing).
                # NOTE(review): assumes extract_features returns a single tensor
                # (shared features extractor) — confirm for the SB3 version used.
                features = policy.extract_features(obs_t)
                latent_pi, _ = policy.mlp_extractor(features)
                mean_actions = policy.action_net(latent_pi)  # shape [B, act_dim]
                # Scalar objective: squared magnitude of the mean action.
                obj = (mean_actions ** 2).sum()
                # autograd.grad computes only d(obj)/d(obs); unlike backward()
                # it leaves the policy parameters' .grad buffers untouched.
                (grad,) = torch.autograd.grad(obj, obs_t)
        finally:
            policy.set_training_mode(was_training)

        adv = torch.clamp(obs_t.detach() + self.eps * torch.sign(grad), -1.0, 1.0)
        # Undo the batch dimension added above so the output matches the input.
        return adv.cpu().numpy().reshape(np.shape(obs))

class OIAttack(AttackWrapper):
    """Optimism Induction Attack: pushes obs to increase the critic's V(s).

    Takes one signed-gradient ascent step of size ``eps`` on the observation.

    - Objective: the summed critic output ``value_net(latent_vf)``.
    - Output: the stepped observation, clamped to [-1, 1].
    """

    def perturb(self, obs: np.ndarray) -> np.ndarray:
        """Return a copy of ``obs`` nudged towards a higher predicted value.

        Args:
            obs: A single observation (1-D) or a batch of observations.

        Returns:
            Float32 array with the same shape as ``obs``. Each element is
            shifted by ``eps * sign(grad)`` and then clamped to [-1, 1].
        """
        policy = self.model.policy
        was_training = policy.training
        # Gradients flow in eval mode too; training mode would only switch on
        # dropout / batch-norm statistic updates, which must not leak here.
        policy.set_training_mode(False)
        try:
            # enable_grad: callers may invoke this under torch.no_grad().
            with torch.enable_grad():
                obs_t = _to_tensor(obs).to(self.device)
                if obs_t.ndim == 1:
                    obs_t = obs_t.unsqueeze(0)
                obs_t.requires_grad_(True)

                # NOTE(review): assumes extract_features returns a single tensor
                # (shared features extractor) — confirm for the SB3 version used.
                features = policy.extract_features(obs_t)
                _, latent_vf = policy.mlp_extractor(features)
                values = policy.value_net(latent_vf)  # shape [B,1]
                obj = values.sum()  # maximize value
                # autograd.grad computes only d(obj)/d(obs); unlike backward()
                # it leaves the policy parameters' .grad buffers untouched.
                (grad,) = torch.autograd.grad(obj, obs_t)
        finally:
            policy.set_training_mode(was_training)

        adv = torch.clamp(obs_t.detach() + self.eps * torch.sign(grad), -1.0, 1.0)
        # Undo the batch dimension added above so the output matches the input.
        return adv.cpu().numpy().reshape(np.shape(obs))
