# In [None]:
"""
Immune RL with PPO + Gated ReLU (Multistep-Only) — With Improvement Tracking & Failure Presets
-----------------------------------------------------------------------------------------------
- Two data modes:
    * peptide (default): peptide/TCR/allele inputs using ImmuneNet (sequence encoders)
    * pbmc: PBMC10k multiome fused vectors using VectorNet (optional muon/mudatasets, PyG for GAT)
- Trains base net once (fixed activation: convex or nonconvex) and reuses weights across PPO runs.
- PPO actor/critic activation varies per run: convex, nonconvex, or twostage (convex→nonconvex switch).
- Continuous cytokine control with safe random perturbations and multistep discounted returns.
- Oversampling on TRAIN subset (optional): 'target_bin' or 'allele' (peptide mode only).
- Tracks per-setup BEFORE vs AFTER metrics and writes deltas to performance_improvements.csv
- Saves raw results to gated_relu_multistep.csv and optional per-step reward curves to ppo_multistep_rewards.csv
- Optional presets to INDUCE local optima behavior for demonstration:
    * --failure_preset convex_flat       → convex-only gets stuck in a flat, underfit basin
    * --failure_preset nonconvex_sharp   → nonconvex-only gets trapped in a sharp local minimum

Usage examples
--------------
# Peptide mode (default)
python immune_rl_multistep_with_improvements.py \
  --train train_BA1.txt --test test_BA1.txt --alleles allelelist.txt \
  --immune_act_mode convex --ppo_modes convex,nonconvex,twostage \
  --epochs 8 --ppo_episodes 20 --log_step_rewards

# PBMC mode (requires muon+mudatasets; falls back to raw features if PyG not installed)
python immune_rl_multistep_with_improvements.py --data_mode pbmc --pbmc_sample 8000 \
  --immune_act_mode convex --ppo_modes twostage,nonconvex --epochs 5 --ppo_episodes 10

# Demonstrate local optima (convex flat or nonconvex sharp)
python immune_rl_multistep_with_improvements.py --failure_preset convex_flat --ppo_modes convex
python immune_rl_multistep_with_improvements.py --failure_preset nonconvex_sharp --ppo_modes nonconvex
"""

from __future__ import annotations
import os, csv, argparse, random, warnings
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from copy import deepcopy
from scipy.stats import pearsonr

warnings.filterwarnings("ignore", category=RuntimeWarning)

# =========================
# Optional deps
# =========================
HAS_TORCH = True
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
except Exception:
    HAS_TORCH = False

HAS_PYG = False
if HAS_TORCH:
    try:
        from torch_geometric.nn import GATConv
        HAS_PYG = True
    except Exception:
        HAS_PYG = False

HAS_MUON = False
try:
    import muon as mu
    import mudatasets as mds
    from scipy.sparse import csr_matrix
    from sklearn.neighbors import NearestNeighbors
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    HAS_MUON = True
except Exception:
    HAS_MUON = False

# =========================
# Constants & utils
# =========================
# One-letter amino-acid codes; token ids are assigned in this order.
AA_VOCAB = list("ACDEFGHIKLMNPQRSTVWY")
# Ids start at 1: SeqTokenizer uses 0 both for padding and for characters
# that are not in the vocabulary.
AA_TO_ID = {a: i + 1 for i, a in enumerate(AA_VOCAB)}
# Cytokine channels of the control vector; "NONE" (index 0) is the channel
# set to 1.0 when an initial cytokine vector is built.
CYTOKINES = ["NONE", "IL2", "IFNG", "IL10", "TNFA"]
# Channel name -> column index into a cytokine vector.
CYTOKINE_TO_ID = {c: i for i, c in enumerate(CYTOKINES)}


def seed_everything(seed=42):
    """Seed every random number generator the script relies on.

    Python's ``random`` and NumPy are always seeded; the PyTorch CPU and CUDA
    generators are seeded only when torch was importable (``HAS_TORCH``).

    Args:
        seed: Seed value applied to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    if not HAS_TORCH:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@dataclass
class Example:
    """One peptide/allele record with its regression target."""
    # Peptide sequence (first column of the input table in parse_examples).
    peptide: str
    # Allele name (third column of the input table in parse_examples).
    allele: str
    # Numeric target (second column of the input table in parse_examples).
    score: float
    # Optional motif annotation; not populated by parse_examples.
    motif: Optional[str] = None
    # Optional TCR sequence; PeptideDataset substitutes a fixed default
    # sequence when this is empty or None.
    tcr: Optional[str] = None


# =========================
# I/O + Tokenizer (peptide mode)
# =========================

def smart_read_table(path: str) -> List[List[str]]:
    """Read a delimited text file into a list of stripped string rows.

    The delimiter is sniffed from the first 2048 characters among tab, comma
    and semicolon. When the sniffer cannot decide (for example a single-column
    or empty file) the standard comma-separated ``csv.excel`` dialect is used.

    Args:
        path: Path of the text file to read.

    Returns:
        One list of whitespace-stripped cells per non-empty row.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, "r", newline="") as f:
        sample = f.read(2048)
        f.seek(0)
        try:
            dialect = csv.Sniffer().sniff(sample, delimiters="\t,;")
        except csv.Error:
            # Sniffer signals "could not determine delimiter" with csv.Error.
            dialect = csv.excel
        return [[cell.strip() for cell in row]
                for row in csv.reader(f, dialect) if row]


def load_alleles(path: str) -> Dict[str, int]:
    """Build an allele -> index vocabulary from an allele list file.

    Every cell of the table is split on commas and whitespace; the distinct
    tokens are sorted and numbered from 0.

    Args:
        path: Path of the allele list file (read with ``smart_read_table``).

    Returns:
        Mapping from allele name to its index in sorted order.
    """
    # A set gives O(1) de-duplication; insertion order is irrelevant because
    # the tokens are sorted before numbering.
    tokens = {
        token
        for row in smart_read_table(path)
        for cell in row
        for token in cell.replace(",", " ").split()
    }
    return {allele: idx for idx, allele in enumerate(sorted(tokens))}


def parse_examples(path: str, allele_to_id: Dict[str, int]) -> List[Example]:
    """Parse ``peptide, score, allele`` rows of a table into ``Example`` objects.

    Rows with fewer than three columns, or whose second column is not a
    number (e.g. a header line), are skipped.

    Args:
        path: Path of the table file (read with ``smart_read_table``).
        allele_to_id: Allele vocabulary. It is MUTATED in place: any allele
            not yet present is appended with the next free index
            (``len(allele_to_id)``).

    Returns:
        The parsed examples in file order.
    """
    examples: List[Example] = []
    for row in smart_read_table(path):
        if len(row) < 3:
            continue
        peptide, score_text, allele = row[:3]
        try:
            score = float(score_text)
        except ValueError:
            # Non-numeric score column: header or malformed row.
            continue
        # len() is evaluated before insertion, so a new allele gets the next index.
        allele_to_id.setdefault(allele, len(allele_to_id))
        examples.append(Example(peptide=peptide, allele=allele, score=score))
    return examples


class SeqTokenizer:
    """Fixed-length amino-acid tokenizer (id 0 = padding / unknown residue)."""

    def __init__(self, max_len=32):
        self.max_len = max_len

    def encode_ids(self, s: str):
        """Return exactly ``max_len`` integer ids for ``s`` (truncated or zero-padded)."""
        limit = self.max_len
        ids = [AA_TO_ID.get(residue, 0) for residue in s[:limit]]
        # Multiplying by a non-positive count yields [], so full-length input is untouched.
        ids.extend([0] * (limit - len(ids)))
        return ids

    def encode(self, s: str):
        """Encode ``s`` as a long tensor when torch is available, else an int64 array."""
        id_array = np.asarray(self.encode_ids(s), dtype=np.int64)
        if HAS_TORCH:
            return torch.tensor(id_array, dtype=torch.long)
        return id_array


# =========================
# Loss + Metrics
# =========================

def nonconvex_loss(pred, target, eps: float = 1e-6):
    """Mean of ``sqrt(|pred - target| + eps)`` over all elements.

    Args:
        pred: Predicted tensor.
        target: Target tensor, broadcastable against ``pred``.
        eps: Constant added under the square root.

    Returns:
        Scalar tensor with the averaged loss.
    """
    abs_residual = torch.abs(pred - target)
    return (abs_residual + eps).sqrt().mean()


def regression_metrics(preds: np.ndarray, targets: np.ndarray) -> Dict[str, float]:
    """Compute MSE, RMSE, MAE, R2 and Pearson correlation of predictions.

    Both inputs are flattened before comparison. Pearson is reported as 0.0
    when there are fewer than two points or either side has zero variance.

    Args:
        preds: Predicted values (any shape).
        targets: Reference values with the same number of elements.

    Returns:
        Dict with keys "MSE", "RMSE", "MAE", "R2" and "Pearson".
    """
    y_hat = np.asarray(preds).flatten()
    y_true = np.asarray(targets).flatten()

    diff = y_hat - y_true
    squared = diff ** 2
    mse = float(np.mean(squared))

    # Small constant keeps the division finite for constant targets.
    total_ss = float(np.sum((y_true - y_true.mean()) ** 2) + 1e-12)
    r2 = float(1.0 - np.sum(squared) / total_ss)

    if len(y_hat) > 1 and np.std(y_hat) > 0 and np.std(y_true) > 0:
        pearson = float(pearsonr(y_hat, y_true)[0])
    else:
        pearson = 0.0

    return {
        "MSE": mse,
        "RMSE": float(np.sqrt(mse)),
        "MAE": float(np.mean(np.abs(diff))),
        "R2": r2,
        "Pearson": pearson,
    }


# =========================
# Gated ReLU + Encoders
# =========================
if HAS_TORCH:
    class GatedReLU(nn.Module):
        """
        Learnable gated activation comparing x with an affine copy a*x + b.

        Modes:
          - convex:    max(a*x + b, x)
          - nonconvex: min(a*x + b, x)
          - twostage:  start convex, switch to nonconvex after switch_epoch
        """

        def __init__(self, dim: int, mode="convex", switch_epoch=5):
            super().__init__()
            self.mode = mode
            self.switch_epoch = switch_epoch
            # Plain Python flag: latched by trigger_switch, not part of the state dict.
            self.switched = False
            self.a = nn.Parameter(torch.ones(dim))
            self.b = nn.Parameter(torch.zeros(dim))
            # Non-persistent epoch counter consulted by the twostage mode.
            self.register_buffer("_epoch", torch.zeros(1, dtype=torch.long), persistent=False)

        def set_epoch(self, epoch: int):
            """Record the current epoch used by the twostage switching rule."""
            self._epoch[0] = epoch

        def trigger_switch(self):
            """Latch the module into its nonconvex branch (announced once)."""
            if self.switched:
                return
            print(f"⚡ GatedReLU switched to nonconvex at step {self._epoch.item()}")
            self.switched = True

        def _convex(self, x):
            affine = self.a * x + self.b
            return torch.max(affine, x)

        def _nonconvex(self, x):
            affine = self.a * x + self.b
            return torch.min(affine, x)

        def forward(self, x):
            mode = self.mode
            if mode == "convex":
                return self._convex(x)
            if mode == "nonconvex":
                return self._nonconvex(x)
            if mode != "twostage":
                raise ValueError(f"Unknown mode: {self.mode}")
            # twostage: nonconvex once latched or once the epoch threshold is reached.
            use_nonconvex = self.switched or (self._epoch.item() >= self.switch_epoch)
            return self._nonconvex(x) if use_nonconvex else self._convex(x)

    class MiniGAT(nn.Module):
        """Token + position embedding followed by a Transformer encoder.

        The output concatenates the mean over positions with the hidden state
        of position 0, giving a vector of size ``2 * dim`` per sequence.
        """

        def __init__(self, vocab_size, dim, max_len=32, heads=4, layers=2):
            super().__init__()
            # +1 row because id 0 is used in addition to the vocab_size tokens.
            self.emb = nn.Embedding(vocab_size + 1, dim)
            self.pos = nn.Embedding(max_len, dim)
            encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)
            self.max_len = max_len

        def forward(self, x):
            batch_size = x.size(0)
            # Inputs longer than max_len are truncated to the positions we can embed.
            seq_len = min(x.size(1), self.max_len)
            positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)
            hidden = self.encoder(self.emb(x[:, :seq_len]) + self.pos(positions))
            pooled_mean = hidden.mean(dim=1)
            first_token = hidden[:, 0, :]
            return torch.cat([pooled_mean, first_token], dim=-1)

    class ImmuneNet(nn.Module):
        """Sequence/TCR/Allele encoder (peptide mode).

        Produces three heads from a shared backbone: sigmoid binding, sigmoid
        recognition, and an unbounded response conditioned on a cytokine vector.
        """

        def __init__(self, vocab_size, allele_count, dim=128, pep_len=32, tcr_len=24,
                     act_mode="convex"):
            super().__init__()
            hidden = 256
            self.pep_enc = MiniGAT(vocab_size, dim, pep_len)
            self.tcr_enc = MiniGAT(vocab_size, dim, tcr_len)
            self.all_emb = nn.Embedding(allele_count + 1, dim)
            # Each MiniGAT emits 2*dim (mean + first position); the allele adds dim.
            fused_dim = 5 * dim
            self.backbone = nn.Sequential(
                nn.Linear(fused_dim, hidden),
                GatedReLU(hidden, act_mode),
                nn.Linear(hidden, hidden),
                GatedReLU(hidden, act_mode),
            )
            self.binding = nn.Linear(hidden, 1)
            self.recognition = nn.Linear(hidden, 1)
            self.cyt_fc = nn.Linear(len(CYTOKINES), 32)
            self.response = nn.Sequential(
                nn.Linear(hidden + 32, 128),
                nn.ReLU(),
                nn.Linear(128, 1),
            )

        def encode_backbone(self, pep, tcr, allele):
            """Return the shared hidden representation for a batch of inputs."""
            parts = [self.pep_enc(pep), self.tcr_enc(tcr), self.all_emb(allele)]
            return self.backbone(torch.cat(parts, dim=-1))

        def forward(self, pep, tcr, allele, cytok_onehot):
            """Return ``(bind, recog, resp)`` for the batch."""
            z = self.encode_backbone(pep, tcr, allele)
            bind = torch.sigmoid(self.binding(z))
            recog = torch.sigmoid(self.recognition(z))
            cytok_feat = F.relu(self.cyt_fc(cytok_onehot))
            resp = self.response(torch.cat([z, cytok_feat], dim=-1))
            return bind, recog, resp

    class VectorNet(nn.Module):
        """Vector feature encoder (PBMC mode).

        Same three heads as the sequence model, but the backbone consumes a
        dense feature vector directly.
        """

        def __init__(self, in_dim, act_mode="convex"):
            super().__init__()
            hidden = 256
            self.backbone = nn.Sequential(
                nn.Linear(in_dim, hidden),
                GatedReLU(hidden, act_mode),
                nn.Linear(hidden, hidden),
                GatedReLU(hidden, act_mode),
            )
            self.binding = nn.Linear(hidden, 1)
            self.recognition = nn.Linear(hidden, 1)
            self.cyt_fc = nn.Linear(len(CYTOKINES), 32)
            self.response = nn.Sequential(
                nn.Linear(hidden + 32, 128),
                nn.ReLU(),
                nn.Linear(128, 1),
            )

        def encode_backbone(self, x_vec):
            """Return the shared hidden representation of ``x_vec``."""
            return self.backbone(x_vec)

        def forward(self, x_vec, cytok_onehot):
            """Return ``(bind, recog, resp)`` for the batch."""
            z = self.encode_backbone(x_vec)
            cytok_feat = F.relu(self.cyt_fc(cytok_onehot))
            resp = self.response(torch.cat([z, cytok_feat], dim=-1))
            bind = torch.sigmoid(self.binding(z))
            recog = torch.sigmoid(self.recognition(z))
            return bind, recog, resp


# =========================
# PPO: continuous cytokine env (multistep-only)
# =========================
if HAS_TORCH:
    class PPOPolicy(nn.Module):
        """Actor/critic pair with one GatedReLU hidden layer each."""

        def __init__(self, input_dim, num_actions, hidden=256, act_mode="convex", switch_epoch=5):
            super().__init__()
            self.actor = nn.Sequential(
                nn.Linear(input_dim, hidden),
                GatedReLU(hidden, act_mode, switch_epoch),
                nn.Linear(hidden, num_actions),
            )
            self.critic = nn.Sequential(
                nn.Linear(input_dim, hidden),
                GatedReLU(hidden, act_mode, switch_epoch),
                nn.Linear(hidden, 1),
            )
            self.act_mode = act_mode
            self.switch_epoch = switch_epoch

        def set_epoch(self, epoch: int):
            """Propagate ``epoch`` to every GatedReLU and latch the twostage switch when due."""
            switch_due = self.act_mode == "twostage" and epoch >= self.switch_epoch
            for module in self.modules():
                if not isinstance(module, GatedReLU):
                    continue
                module.set_epoch(epoch)
                if switch_due and not module.switched:
                    module.trigger_switch()

        def act(self, x):
            """Sample an action for state ``x``.

            Returns:
                ``(tanh(sample), log_prob, value)``. Note that ``log_prob`` is
                the Normal log-density of the sample BEFORE the tanh squashing.
            """
            mean = self.actor(x)
            policy_dist = torch.distributions.Normal(mean, 0.1)
            raw_action = policy_dist.sample()
            log_prob = policy_dist.log_prob(raw_action).sum(-1)
            value = self.critic(x).squeeze(-1)
            return torch.tanh(raw_action), log_prob, value

    @torch.no_grad()
    def extract_state_seq(model: ImmuneNet, pep, tcr, allele):
        """Return the model's backbone features for a sequence batch, detached from autograd."""
        features = model.encode_backbone(pep, tcr, allele)
        return features.detach()

    @torch.no_grad()
    def extract_state_vec(model: VectorNet, x_vec):
        """Return the model's backbone features for a vector batch, detached from autograd."""
        features = model.encode_backbone(x_vec)
        return features.detach()

    def _next_cytokine_state(batch_size, action_cont, device,
                             perturb_prob, action_mag, max_spike, cytok_prev):
        """Apply random spikes and the policy action to the cytokine vector.

        Shared by ``env_step_seq`` and ``env_step_vec``.

        Args:
            batch_size: Number of rows in the batch.
            action_cont: Continuous action added (scaled by 0.1) to the cytokines.
            device: Device for a freshly created cytokine tensor.
            perturb_prob: Per-row probability of a random spike.
            action_mag: Magnitude of each random spike (sign is random).
            max_spike: Maximum number of channels spiked per row (capped by
                the number of cytokines); values < 1 disable spiking.
            cytok_prev: Previous cytokine tensor, or None to start from the
                "NONE" one-hot vector.

        Returns:
            Next cytokine tensor, clamped to [0, 1].
        """
        n_cyt = len(CYTOKINES)
        if cytok_prev is None:
            cytok = torch.zeros((batch_size, n_cyt), device=device)
            cytok[:, CYTOKINE_TO_ID["NONE"]] = 1.0
        else:
            # Clone so the caller's tensor is never modified in place.
            cytok = cytok_prev.clone()

        spike_cap = min(max_spike, n_cyt)
        for row in range(batch_size):
            # random.random() is drawn first to keep the RNG stream unchanged;
            # spike_cap < 1 would make np.random.randint(1, 1) raise.
            if random.random() < perturb_prob and spike_cap >= 1:
                k = np.random.randint(1, spike_cap + 1)
                idx = np.random.choice(n_cyt, size=k, replace=False)
                signs = np.random.choice([-1.0, 1.0], size=k)
                cytok[row, idx] += action_mag * torch.tensor(
                    signs, dtype=torch.float32, device=cytok.device)
        return torch.clamp(cytok + 0.1 * action_cont, 0.0, 1.0)

    @torch.no_grad()
    def env_step_seq(model, pep, tcr, allele, action_cont, device,
                     perturb_prob=0.2, action_mag=0.4, max_spike=3, cytok_prev: Optional[torch.Tensor]=None):
        """Advance the cytokine environment one step for a sequence batch.

        Args:
            model: Callable ``model(pep, tcr, allele, cytok)`` returning
                ``(bind, recog, resp)``.
            pep, tcr, allele: Batch inputs forwarded to ``model``.
            action_cont: Continuous action from the policy.
            device: Device for a freshly created cytokine tensor.
            perturb_prob, action_mag, max_spike: Random-spike parameters.
            cytok_prev: Previous cytokine tensor, or None for the initial state.

        Returns:
            ``(reward, cytok_next)`` where ``reward = 0.7*resp + 0.3*recog``
            with the trailing dimension squeezed.
        """
        cytok_next = _next_cytokine_state(pep.size(0), action_cont, device,
                                          perturb_prob, action_mag, max_spike, cytok_prev)
        _, recog, resp = model(pep, tcr, allele, cytok_next)
        reward = (0.7*resp + 0.3*recog).squeeze(-1).detach()
        return reward, cytok_next

    @torch.no_grad()
    def env_step_vec(model, x_vec, action_cont, device,
                     perturb_prob=0.2, action_mag=0.4, max_spike=3, cytok_prev: Optional[torch.Tensor]=None):
        """Advance the cytokine environment one step for a vector batch.

        Args:
            model: Callable ``model(x_vec, cytok)`` returning ``(bind, recog, resp)``.
            x_vec: Batch of feature vectors forwarded to ``model``.
            action_cont: Continuous action from the policy.
            device: Device for a freshly created cytokine tensor.
            perturb_prob, action_mag, max_spike: Random-spike parameters.
            cytok_prev: Previous cytokine tensor, or None for the initial state.

        Returns:
            ``(reward, cytok_next)`` where ``reward = 0.7*resp + 0.3*recog``
            with the trailing dimension squeezed.
        """
        cytok_next = _next_cytokine_state(x_vec.size(0), action_cont, device,
                                          perturb_prob, action_mag, max_spike, cytok_prev)
        _, recog, resp = model(x_vec, cytok_next)
        reward = (0.7*resp + 0.3*recog).squeeze(-1).detach()
        return reward, cytok_next

    def ppo_train_multistep(
        state_dim: int,
        make_state_batch,       # function() -> dict with 'state', 'cytok_init', ... or None when epoch ends
        env_step_fn,            # function(batch, action, device, cytok_prev=None) -> (reward, next_cytok)
        policy_act_mode="convex",
        switch_epoch=5,
        ppo_epochs=3, episodes=5, clip=0.2,
        discount_gamma=0.99, multistep_steps=50,
        lr=5e-4,
        log_step_rewards=False, step_log_writer: Optional[csv.writer]=None, setup_name: str=""
    ):
        """Train a fresh PPOPolicy on multistep discounted returns.

        For every batch the policy is rolled out for ``multistep_steps`` steps
        from the same state; the discounted reward sum is used as the return
        of the FIRST action of that rollout, which is the action stored for
        the PPO update together with its log-probability and value estimate.

        Actions are sampled here (rather than through ``PPOPolicy.act``)
        because the update needs the pre-tanh sample: the old log-probability
        and the re-evaluated one must refer to the same quantity, otherwise
        the probability ratio is not 1 before the first gradient step.

        Args:
            state_dim: Dimensionality of the state vectors.
            make_state_batch: Callable returning a dict with at least
                ``"state"`` and ``"cytok_init"`` tensors, or None when the
                episode's data is exhausted.
            env_step_fn: Callable ``(batch, action, device, cytok_prev=...)``
                returning ``(reward, next_cytok)``.
            policy_act_mode: Activation mode passed to PPOPolicy.
            switch_epoch: Episode index at which a twostage policy switches.
            ppo_epochs: Gradient passes over each episode's collected data.
            episodes: Number of data-collection episodes.
            clip: PPO ratio clipping range.
            discount_gamma: Per-step discount factor.
            multistep_steps: Rollout length per batch.
            lr: Adam learning rate.
            log_step_rewards: Whether to write per-step mean rewards.
            step_log_writer: CSV writer receiving
                ``[setup_name, episode, step, mean_reward]`` rows.
            setup_name: Label written in the first column of each log row.

        Returns:
            None. Progress is printed and optionally logged.
        """
        action_std = 0.1  # same fixed std PPOPolicy.act samples with
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        policy = PPOPolicy(state_dim, len(CYTOKINES), hidden=256,
                           act_mode=policy_act_mode, switch_epoch=switch_epoch).to(device)
        opt = torch.optim.Adam(policy.parameters(), lr=lr)

        def sample_raw(state):
            """Sample a pre-tanh action; return (raw, log_prob, value)."""
            dist = torch.distributions.Normal(policy.actor(state), action_std)
            raw = dist.sample()
            return raw, dist.log_prob(raw).sum(-1), policy.critic(state).squeeze(-1)

        for ep in range(episodes):
            policy.set_epoch(ep)
            S_list, A_list, oldlogp_list, G_list, Adv_list = [], [], [], [], []
            step_rewards_accum = np.zeros(multistep_steps, dtype=np.float64)
            n_batches = 0

            while True:
                batch = make_state_batch()
                if batch is None:
                    break
                z = batch["state"].to(device)
                cytok_state = batch["cytok_init"].to(device)

                with torch.no_grad():
                    # The first sample is the action credited with the return.
                    raw_first, logp_first, v_first = sample_raw(z)
                    raw = raw_first
                    discounted = torch.zeros(z.size(0), device=device)
                    per_step_rewards = []
                    for t in range(multistep_steps):
                        if t > 0:
                            raw, _, _ = sample_raw(z)
                        r_t, cytok_state = env_step_fn(batch, torch.tanh(raw), device,
                                                       cytok_prev=cytok_state)
                        discounted += (discount_gamma ** t) * r_t
                        per_step_rewards.append(r_t.mean().item())
                    G = discounted
                    A = (G - v_first).detach()

                S_list.append(z)
                A_list.append(raw_first)
                oldlogp_list.append(logp_first)
                G_list.append(G)
                Adv_list.append(A)
                step_rewards_accum += np.array(per_step_rewards, dtype=np.float64)
                n_batches += 1

            if not S_list:
                print("No PPO batches produced this episode.")
                continue

            S = torch.cat(S_list)
            RawA = torch.cat(A_list)
            OL = torch.cat(oldlogp_list)
            Gt = torch.cat(G_list)
            Adv = torch.cat(Adv_list)
            if Adv.numel() > 1:
                # std() of a single element is NaN and would poison the update.
                Adv = (Adv - Adv.mean()) / (Adv.std() + 1e-8)

            for _ in range(ppo_epochs):
                dist = torch.distributions.Normal(policy.actor(S), action_std)
                new_logp = dist.log_prob(RawA).sum(-1)
                ratio = torch.exp(new_logp - OL)
                surr1 = ratio * Adv
                surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * Adv
                actor_loss = -torch.min(surr1, surr2).mean()

                v_pred = policy.critic(S).squeeze(-1)
                critic_loss = nonconvex_loss(v_pred, Gt)

                loss = actor_loss + 0.5 * critic_loss
                opt.zero_grad()
                loss.backward()
                opt.step()

            print(f"[PPO] Episode {ep+1}/{episodes} (multistep={multistep_steps}): return_mean={Gt.mean().item():.4f}")
            if log_step_rewards and step_log_writer is not None and n_batches > 0:
                curve = (step_rewards_accum / max(n_batches,1)).tolist()
                for tstep, rmean in enumerate(curve):
                    step_log_writer.writerow([setup_name, ep+1, tstep+1, rmean])


# =========================
# Datasets + loaders
# =========================
if HAS_TORCH:
    # ----- Peptide dataset (original) -----
    class PeptideDataset(Dataset):
        """Dataset yielding ``(peptide_ids, tcr_ids, allele_index, score)`` per example."""

        # TCR sequence substituted when an example carries none.
        _DEFAULT_TCR = "CASSIRSSYEQYF"

        def __init__(self, examples, allele_to_id, pep_len=32, tcr_len=24):
            self.examples = examples
            self.allele_to_id = allele_to_id
            self.tok_p = SeqTokenizer(pep_len)
            self.tok_t = SeqTokenizer(tcr_len)

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            example = self.examples[idx]
            tcr_seq = example.tcr if example.tcr else self._DEFAULT_TCR
            return (
                self.tok_p.encode(example.peptide),
                self.tok_t.encode(tcr_seq),
                # Alleles missing from the vocabulary fall back to index 0.
                self.allele_to_id.get(example.allele, 0),
                float(example.score),
            )

    def collate_pep(batch):
        """Collate PeptideDataset items into batched tensors.

        Returns:
            ``(pep, tcr, all_idx, cytok, y)`` where ``cytok`` is the "NONE"
            one-hot cytokine vector repeated per row and ``y`` has a trailing
            singleton dimension.
        """
        pep_items, tcr_items, allele_items, score_items = zip(*batch)
        pep = torch.stack(pep_items)
        tcr = torch.stack(tcr_items)
        all_idx = torch.tensor(allele_items, dtype=torch.long)
        y = torch.tensor(score_items, dtype=torch.float32).unsqueeze(-1)
        n_rows = pep.size(0)
        cytok = torch.zeros((n_rows, len(CYTOKINES)), dtype=torch.float32)
        cytok[:, CYTOKINE_TO_ID["NONE"]] = 1.0
        return pep, tcr, all_idx, cytok, y

    @torch.no_grad()
    def eval_loader_metrics_seq(model: ImmuneNet, loader, device):
        """Evaluate the binding head over ``loader`` and return regression metrics.

        Puts ``model`` in eval mode (it is not switched back here).
        """
        model.eval()
        pred_rows, target_rows = [], []
        for pep, tcr, all_idx, cytok, y in loader:
            bind, _, _ = model(pep.to(device), tcr.to(device),
                               all_idx.to(device), cytok.to(device))
            pred_rows += list(bind.cpu().numpy())
            target_rows += list(y.numpy())
        return regression_metrics(np.array(pred_rows), np.array(target_rows))

    def train_supervised_seq(model: ImmuneNet, train_loader, val_loader, device, epochs=10, lr=1e-3, weight_decay=0.0, betas=(0.9,0.999)):
        """Fit the binding head of a sequence model with Adam and ``nonconvex_loss``.

        After every epoch the validation MSE and R2 are printed.

        Args:
            model: Model called as ``model(pep, tcr, all_idx, cytok)``.
            train_loader: Loader yielding ``(pep, tcr, all_idx, cytok, y)``.
            val_loader: Loader evaluated after each epoch.
            device: Device the batches are moved to.
            epochs: Number of passes over ``train_loader``.
            lr, weight_decay, betas: Adam hyper-parameters.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)
        for epoch_idx in range(epochs):
            # eval_loader_metrics_seq leaves the model in eval mode, so re-enable training.
            model.train()
            for batch in train_loader:
                pep, tcr, all_idx, cytok, y = (item.to(device) for item in batch)
                bind, _, _ = model(pep, tcr, all_idx, cytok)
                loss = nonconvex_loss(bind, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            val = eval_loader_metrics_seq(model, val_loader, device)
            print(f"[Epoch {epoch_idx+1}] Val MSE={val['MSE']:.4f}, R2={val['R2']:.4f}")

    # ----- PBMC vector dataset -----
    class VectorDataset(Dataset):
        """Dataset over a feature matrix and a target vector, both stored as float32."""

        def __init__(self, X: np.ndarray, y: np.ndarray):
            features = X.astype(np.float32)
            # Targets are kept as a column so each item has shape (1,).
            targets = y.astype(np.float32).reshape(-1, 1)
            self.X = features
            self.y = targets

        def __len__(self):
            return self.X.shape[0]

        def __getitem__(self, idx):
            row, target = self.X[idx], self.y[idx]
            return row, target

    def collate_vec(batch):
        """Collate VectorDataset items into ``(X, cytok, y)`` float32 tensors.

        ``cytok`` is the "NONE" one-hot cytokine vector repeated per row.
        """
        feature_rows, target_rows = zip(*batch)
        X = torch.tensor(np.stack(feature_rows), dtype=torch.float32)
        y = torch.tensor(np.stack(target_rows), dtype=torch.float32)
        n_rows = X.size(0)
        cytok = torch.zeros((n_rows, len(CYTOKINES)), dtype=torch.float32)
        cytok[:, CYTOKINE_TO_ID["NONE"]] = 1.0
        return X, cytok, y

    @torch.no_grad()
    def eval_loader_metrics_vec(model: VectorNet, loader, device):
        """Evaluate the binding head over ``loader`` and return regression metrics.

        Puts ``model`` in eval mode (it is not switched back here).
        """
        model.eval()
        pred_rows, target_rows = [], []
        for X, cytok, y in loader:
            bind, _, _ = model(X.to(device), cytok.to(device))
            pred_rows += list(bind.cpu().numpy())
            target_rows += list(y.numpy())
        return regression_metrics(np.array(pred_rows), np.array(target_rows))

    def train_supervised_vec(model: VectorNet, train_loader, val_loader, device, epochs=10, lr=1e-3, weight_decay=0.0, betas=(0.9,0.999)):
        """Fit the binding head of a vector model with Adam and ``nonconvex_loss``.

        After every epoch the validation MSE and R2 are printed.

        Args:
            model: Model called as ``model(X, cytok)``.
            train_loader: Loader yielding ``(X, cytok, y)``.
            val_loader: Loader evaluated after each epoch.
            device: Device the batches are moved to.
            epochs: Number of passes over ``train_loader``.
            lr, weight_decay, betas: Adam hyper-parameters.
        """
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=betas)
        for epoch_idx in range(epochs):
            # eval_loader_metrics_vec leaves the model in eval mode, so re-enable training.
            model.train()
            for batch in train_loader:
                X, cytok, y = (item.to(device) for item in batch)
                bind, _, _ = model(X, cytok)
                loss = nonconvex_loss(bind, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            val = eval_loader_metrics_vec(model, val_loader, device)
            print(f"[Epoch {epoch_idx+1}] Val MSE={val['MSE']:.4f}, R2={val['R2']:.4f}")


# =========================
# Oversampling helpers (peptide)
# =========================
if HAS_TORCH:
    def _extract_examples_from_subset(ds_or_subset):
        """Return the Example objects behind a PeptideDataset or a Subset of one.

        Raises:
            TypeError: If the argument is neither a PeptideDataset nor a Subset.
        """
        if isinstance(ds_or_subset, PeptideDataset):
            return ds_or_subset.examples
        if isinstance(ds_or_subset, Subset):
            # assumes the wrapped dataset exposes `.examples` like PeptideDataset
            source = ds_or_subset.dataset
            return [source.examples[i] for i in ds_or_subset.indices]
        raise TypeError("Unsupported dataset type for oversampling.")

    def build_weighted_sampler(
        ds_or_subset,
        allele_to_id: Dict[str, int],
        by: str = "target_bin",
        n_bins: int = 6,
        min_count_smoothing: float = 1.0
    ):
        """Build a WeightedRandomSampler that oversamples rare groups.

        Each example is weighted by ``1 / (group_count + min_count_smoothing)``
        where the group is the allele (``by == "allele"``) or, for any other
        ``by`` value, a quantile bin of the target score.

        Args:
            ds_or_subset: PeptideDataset or Subset of one.
            allele_to_id: Allele vocabulary (unknown alleles map to 0).
            by: Grouping criterion.
            n_bins: Number of quantile bins for target-based grouping.
            min_count_smoothing: Constant added to each group count.

        Returns:
            A sampler drawing ``len(dataset)`` items with replacement, or None
            for an empty dataset.
        """
        exs = _extract_examples_from_subset(ds_or_subset)
        n_examples = len(exs)
        if n_examples == 0:
            return None

        def inverse_frequency(labels):
            # Weight of each item = 1 / (size of its group + smoothing).
            _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
            per_item = counts[inverse.reshape(-1)]
            return (1.0 / (per_item + min_count_smoothing)).astype(np.float32)

        if by == "allele":
            allele_ids = np.array([allele_to_id.get(ex.allele, 0) for ex in exs])
            weights = inverse_frequency(allele_ids)
        else:
            targets = np.array([ex.score for ex in exs], dtype=np.float32)
            if np.allclose(targets.min(), targets.max()):
                # Constant targets: nothing to rebalance.
                weights = np.ones(n_examples, dtype=np.float32)
            else:
                edges = np.unique(np.quantile(targets, np.linspace(0, 1, num=n_bins+1)))
                if len(edges) < 3:
                    # Too few distinct quantiles; fall back to equal-width edges.
                    edges = np.linspace(targets.min(), targets.max(), num=max(3, n_bins+1))
                raw_bins = np.digitize(targets, edges[1:-1], right=False)
                bins = np.clip(raw_bins, 0, len(edges)-2)
                weights = inverse_frequency(bins)

        weights = weights / weights.sum()
        weight_tensor = torch.tensor(weights, dtype=torch.double)
        return WeightedRandomSampler(weight_tensor, num_samples=len(exs), replacement=True)


# =========================
# PBMC helpers (optional)
# =========================

def compute_knn_edges(X: np.ndarray, k=3):
    """Build a directed kNN edge list (cosine metric) over the rows of ``X``.

    ``k`` is passed to NearestNeighbors as ``n_neighbors``; self-matches
    returned by the query are removed, so a row can end up with fewer than
    ``k`` outgoing edges.

    Args:
        X: Feature matrix with one sample per row.
        k: Number of neighbours requested per row.

    Returns:
        Long tensor of shape ``(2, E)`` whose columns are ``(source, target)``
        pairs; shape ``(2, 0)`` when no edge survives the self-match filter.
    """
    nbrs = NearestNeighbors(n_neighbors=k, metric='cosine').fit(X)
    _, indices = nbrs.kneighbors(X)
    indices = np.asarray(indices)
    # Row i of `indices` lists the neighbours of sample i (row-major order kept).
    sources = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
    targets = indices.reshape(-1)
    keep = sources != targets
    edge_array = np.stack([sources[keep], targets[keep]])
    return torch.tensor(edge_array, dtype=torch.long).contiguous()


# Guarded like the other torch-dependent definitions so that importing this
# file without torch installed does not fail with a NameError.
if HAS_TORCH:
    class SimpleGAT(torch.nn.Module):
        """Two-layer graph attention network built from torch_geometric GATConv.

        Args:
            in_dim: Input feature size per node.
            hidden_dim: Per-head output size of the first layer.
            out_dim: Output feature size per node.
            heads: Number of attention heads in the first layer.
            dropout: Dropout passed to both GATConv layers.

        Raises:
            RuntimeError: If torch_geometric is not installed.
        """

        def __init__(self, in_dim, hidden_dim=64, out_dim=32, heads=2, dropout=0.2):
            super().__init__()
            if not HAS_PYG:
                raise RuntimeError("torch_geometric is required for SimpleGAT.")
            self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, dropout=dropout)
            # The second layer's input is sized hidden_dim * heads.
            self.gat2 = GATConv(hidden_dim*heads, out_dim, heads=1, dropout=dropout)

        def forward(self, x, edge_index):
            x = torch.relu(self.gat1(x, edge_index))
            return self.gat2(x, edge_index)


def pbmc_build_embeddings(sample_size=10000, seed=42):
    """Load PBMC10k multiome data and return fused train/test feature matrices.

    Steps: load the dataset, keep cells shared by all modalities, optionally
    subsample, embed each modality (an untrained SimpleGAT over a kNN graph
    when PyG is available, raw features otherwise), standardise each
    modality, split 80/20 and concatenate the modalities feature-wise.

    The target is a placeholder pseudotime: ``linspace(0, 1, n_cells)``
    indexed by the split. The split depends only on the cell count and
    ``seed``, so it is computed once and shared by all modalities.

    Args:
        sample_size: Maximum number of cells to keep.
        seed: Seed for NumPy's global RNG (set here) and for the split.

    Returns:
        ``(fused_train, pt_train, fused_test, pt_test)``.

    Raises:
        RuntimeError: If muon/mudatasets (and the sklearn helpers imported
            alongside them) are unavailable.
    """
    if not HAS_MUON:
        raise RuntimeError("muon/mudatasets is required for --data_mode pbmc.")
    np.random.seed(seed)
    mdata = mds.load("pbmc10k_multiome", full=True)
    mdata.var_names_make_unique()
    rna = mdata.mod['rna'].copy()
    atac = mdata.mod['atac'].copy()
    adt = mdata.mod['adt'].copy() if 'adt' in mdata.mod else None

    # Keep only the cells present in every available modality.
    common = rna.obs_names.intersection(atac.obs_names)
    if adt is not None:
        common = common.intersection(adt.obs_names)
    common = np.array(common)
    if sample_size < len(common):
        sel = np.random.choice(common, size=sample_size, replace=False)
    else:
        sel = common

    def to_dense(adata):
        # Sparse matrices expose toarray(); anything else is coerced to ndarray.
        X = adata.X
        return X.toarray() if hasattr(X, "toarray") else np.asarray(X)

    def make_emb(X):
        if HAS_PYG:
            X_t = torch.tensor(X, dtype=torch.float32)
            edge_index = compute_knn_edges(X, k=3)
            model = SimpleGAT(in_dim=X.shape[1])
            with torch.no_grad():
                return model(X_t, edge_index).cpu().numpy()
        return X  # fallback: raw features if PyG not available

    # Embedding order (rna, atac, adt) is kept so RNG consumption is unchanged.
    mods = {"rna": make_emb(to_dense(rna[sel])),
            "atac": make_emb(to_dense(atac[sel]))}
    if adt is not None:
        mods["adt"] = make_emb(to_dense(adt[sel]))

    n_cells = len(sel)
    tr, te = train_test_split(np.arange(n_cells), test_size=0.2, random_state=seed)
    pseudo = np.linspace(0.0, 1.0, n_cells).astype(np.float32)

    train_parts, test_parts = [], []
    for Z in mods.values():
        Zs = StandardScaler().fit_transform(Z)
        train_parts.append(Zs[tr])
        test_parts.append(Zs[te])

    fused_train = np.concatenate(train_parts, axis=1)
    fused_test = np.concatenate(test_parts, axis=1)
    return fused_train, pseudo[tr], fused_test, pseudo[te]


# =========================
# Adaptive per-gene threshold (optional hook)
# =========================
class PerGeneAdaptiveThreshold:
    """Per-modality, per-gene thresholds tracked as exponential moving averages.

    Args:
        modality_dims: Mapping modality name -> number of genes; every gene
            index in ``range(dim)`` starts at threshold 0.0.
        alpha: Weight of the newest reward in the moving average.
    """

    def __init__(self, modality_dims: Dict[str,int], alpha=0.1):
        self.thresholds = {mod: {i: 0.0 for i in range(dim)} for mod, dim in modality_dims.items()}
        self.alpha = alpha

    def update(self, gene_rewards: Dict[str, Dict[int, float]]):
        """Blend new rewards into the thresholds.

        ``None`` and NaN rewards (Python or NumPy scalars) are ignored.
        Modalities or genes not seen before start from 0.0, mirroring the
        default returned by ``get``.
        """
        for mod, rewards in gene_rewards.items():
            per_gene = self.thresholds.setdefault(mod, {})
            for gid, r in rewards.items():
                if r is None:
                    continue
                value = float(r)
                # Checked after float() so NaNs of any scalar type are caught.
                if np.isnan(value):
                    continue
                per_gene[gid] = self.alpha*value + (1-self.alpha)*per_gene.get(gid, 0.0)

    def get(self, mod, gene_id):
        """Return the threshold for ``(mod, gene_id)``, or 0.0 when unknown."""
        return float(self.thresholds.get(mod, {}).get(gene_id,0.0))


# =========================
# Main
# =========================

def _build_arg_parser():
    """Return the argparse parser holding every command-line option of the script."""
    p = argparse.ArgumentParser()
    # --- Data mode ---
    p.add_argument('--data_mode', type=str, default='peptide', choices=['peptide','pbmc'],
                   help='peptide: original peptide/TCR/allele; pbmc: PBMC10k multiome fused vectors')
    # --- Peptide paths ---
    p.add_argument('--train', default='train_BA1.txt')
    p.add_argument('--test', default='test_BA1.txt')
    p.add_argument('--alleles', default='allelelist.txt')
    # --- Common training ---
    p.add_argument('--epochs', type=int, default=8)
    p.add_argument('--batch_size', type=int, default=64)
    p.add_argument('--ppo_episodes', type=int, default=5)
    p.add_argument('--ppo_epochs', type=int, default=10)
    # --- Sequence lengths (peptide mode) ---
    p.add_argument('--pep_len', type=int, default=32)
    p.add_argument('--tcr_len', type=int, default=24)
    # --- Immune/VectorNet activation (fixed) ---
    p.add_argument('--immune_act_mode', type=str, default='convex', choices=['convex','nonconvex'],
                   help='Fixed activation for base net across all PPO runs.')
    # --- PPO activation modes (varies) ---
    p.add_argument('--ppo_switch_epoch', type=int, default=4,
                   help='Episode at which PPO two-stage switches to nonconvex.')
    p.add_argument('--ppo_modes', type=str, default='convex,nonconvex,twostage',
                   help='Comma-separated PPO activation modes to compare.')
    # --- Cytokine env params ---
    p.add_argument('--cytokine_perturb_prob', type=float, default=0.8,
                   help='Prob. to randomly spike cytokines in env_step.')
    p.add_argument('--cytokine_action_mag', type=float, default=0.8,
                   help='Magnitude of random spike added when perturbing.')
    p.add_argument('--cytokine_max_spike', type=int, default=5,
                   help='Max cytokines to spike per perturbation (capped by N).')
    # --- Oversampling (peptide only) ---
    p.add_argument('--oversample', action='store_true',
                   help='Enable inverse-frequency oversampling on the TRAIN subset (peptide mode).')
    p.add_argument('--oversample_bins', type=int, default=6,
                   help='Number of bins for target_bin oversampling.')
    p.add_argument('--oversample_by', type=str, default='allele', choices=['target_bin','allele'],
                   help='Oversample strategy: target_bin or allele.')
    # --- Multistep PPO params ---
    p.add_argument('--multistep_steps', type=int, default=5,
                   help='Number of steps per PPO rollout (multistep only).')
    p.add_argument('--discount_gamma', type=float, default=0.99,
                   help='Discount factor for multistep returns.')
    p.add_argument('--log_step_rewards', action='store_true',
                   help='Write per-episode per-step reward means to ppo_multistep_rewards.csv')
    # --- Optimizer knobs ---
    p.add_argument('--base_lr', type=float, default=1e-3)
    p.add_argument('--base_wd', type=float, default=0.0)
    p.add_argument('--base_b1', type=float, default=0.9)
    p.add_argument('--base_b2', type=float, default=0.999)
    # --- PBMC controls ---
    p.add_argument('--pbmc_sample', type=int, default=10000,
                   help='PBMC sample size (if available cells fewer, uses all).')
    p.add_argument('--val_ratio', type=float, default=0.2)
    # --- Failure presets to induce local optima ---
    p.add_argument('--failure_preset', type=str, default='none', choices=['none','convex_flat','nonconvex_sharp'],
                   help='Set hyperparameters to encourage specific local-optima failures.')
    return p


def _apply_failure_preset(a):
    """Overwrite batch size, env-noise and base-optimizer fields of `a` in place per --failure_preset."""
    if a.failure_preset == 'convex_flat':
        a.batch_size = max(128, a.batch_size)
        a.cytokine_perturb_prob = 0.1
        a.cytokine_action_mag = 0.2
        a.base_lr = min(a.base_lr, 1e-4)
        a.base_wd = max(a.base_wd, 1e-5)
        a.base_b1, a.base_b2 = 0.8, 0.99
        print("[Preset] convex_flat applied: large batch, low noise, low LR, higher WD")
    elif a.failure_preset == 'nonconvex_sharp':
        a.batch_size = min(32, a.batch_size)
        a.cytokine_perturb_prob = 0.9
        a.cytokine_action_mag = 0.9
        a.cytokine_max_spike = max(a.cytokine_max_spike, 10)
        a.base_lr = max(a.base_lr, 1e-4)
        a.base_wd = 0.0
        a.base_b1, a.base_b2 = 0.9, 0.999
        print("[Preset] nonconvex_sharp applied: small batch, high noise, higher LR, no WD")


def _format_metrics(metrics):
    """Render a metrics dict as 'name=value' pairs with four decimals, comma-separated."""
    return ", ".join([f"{k}={v:.4f}" for k, v in metrics.items()])


def _record_improvement(setup_name, before, after, results, improvements):
    """Print AFTER metrics and deltas for one setup and append them to both result lists."""
    print("[After  PPO] " + _format_metrics(after))
    results.append([setup_name, after["MSE"], after["RMSE"], after["MAE"], after["R2"], after["Pearson"]])

    # ΔMSE is before - after while ΔR2/ΔPearson are after - before, so a
    # positive value means "better" for all three.
    imp = {
        "Setup": setup_name,
        "ΔMSE": before["MSE"] - after["MSE"],
        "ΔR2":  after["R2"] - before["R2"],
        "ΔPearson": after["Pearson"] - before["Pearson"]
    }
    improvements.append(imp)
    print(f"[Improvement] {setup_name}: ΔMSE={imp['ΔMSE']:.4f}, ΔR2={imp['ΔR2']:.4f}, ΔPearson={imp['ΔPearson']:.4f}")


def _run_peptide_pipeline(a, d, step_log_writer, results, improvements):
    """Train ImmuneNet once on the peptide data, then run PPO once per requested mode."""
    allele_to_id = load_alleles(a.alleles)
    tr = parse_examples(a.train, allele_to_id)
    ts = parse_examples(a.test, allele_to_id)
    full_train_ds = PeptideDataset(tr, allele_to_id, a.pep_len, a.tcr_len)

    # Shuffled train/validation split; validation always gets at least one example.
    N = len(full_train_ds)
    idx = list(range(N))
    random.shuffle(idx)
    v = int(max(1, round(a.val_ratio * N)))
    tr_ds = Subset(full_train_ds, idx[v:])
    va_ds = Subset(full_train_ds, idx[:v])
    te_ds = PeptideDataset(ts, allele_to_id, a.pep_len, a.tcr_len)

    if a.oversample:
        sampler = build_weighted_sampler(tr_ds, allele_to_id=allele_to_id, by=a.oversample_by, n_bins=a.oversample_bins)
        trdl = DataLoader(tr_ds, batch_size=a.batch_size, sampler=sampler, collate_fn=collate_pep)
    else:
        trdl = DataLoader(tr_ds, batch_size=a.batch_size, shuffle=True, collate_fn=collate_pep)

    vadl = DataLoader(va_ds, batch_size=a.batch_size, shuffle=False, collate_fn=collate_pep)
    tsdl = DataLoader(te_ds, batch_size=a.batch_size, shuffle=False, collate_fn=collate_pep)

    print(f"\n=== Training ImmuneNet (fixed {a.immune_act_mode}) ===")
    base_model = ImmuneNet(len(AA_VOCAB), len(allele_to_id),
                           act_mode=a.immune_act_mode,
                           pep_len=a.pep_len, tcr_len=a.tcr_len).to(d)
    train_supervised_seq(base_model, trdl, vadl, d, epochs=a.epochs,
                         lr=a.base_lr, weight_decay=a.base_wd, betas=(a.base_b1,a.base_b2))
    # Snapshot so every PPO mode starts from the same trained weights.
    base_state = deepcopy(base_model.state_dict())

    before = eval_loader_metrics_seq(base_model, tsdl, d)
    print("[Before PPO] " + _format_metrics(before))

    ppo_modes = [m.strip() for m in a.ppo_modes.split(",") if m.strip()]
    for ppo_mode in ppo_modes:
        print(f"\n=== PPO Activation: {ppo_mode.upper()} (multistep={a.multistep_steps}) ===")
        model = ImmuneNet(len(AA_VOCAB), len(allele_to_id),
                          act_mode=a.immune_act_mode,
                          pep_len=a.pep_len, tcr_len=a.tcr_len).to(d)
        model.load_state_dict(base_state)

        # Redefined per mode, so the cached iterator attribute starts fresh each time.
        def make_state_batch():
            if not hasattr(make_state_batch, "iter"):
                make_state_batch.iter = iter(trdl)
            try:
                pep, tcr, all_idx, cytok_init, y = next(make_state_batch.iter)
            except StopIteration:
                # Loader exhausted: rearm it and return None for this call.
                make_state_batch.iter = iter(trdl)
                return None
            with torch.no_grad():
                z = extract_state_seq(model, pep.to(d), tcr.to(d), all_idx.to(d))
            return {"state": z, "cytok_init": cytok_init, "pep": pep, "tcr": tcr, "all": all_idx}

        def env_step_wrapper(batch, action, device, cytok_prev=None):
            pep = batch["pep"].to(device)
            tcr = batch["tcr"].to(device)
            all_idx = batch["all"].to(device)
            return env_step_seq(model, pep, tcr, all_idx, action, device,
                                perturb_prob=a.cytokine_perturb_prob,
                                action_mag=a.cytokine_action_mag,
                                max_spike=a.cytokine_max_spike,
                                cytok_prev=cytok_prev)

        setup_name = f"Immune({a.immune_act_mode})+PPO({ppo_mode})"
        # NOTE(review): `model` reaches PPO only through the two closures above;
        # whether the AFTER metrics can differ from BEFORE depends on what
        # env_step_seq / ppo_train_multistep do with it — confirm.
        # state_dim=256 assumes extract_state_seq yields 256-dim states — TODO confirm.
        ppo_train_multistep(
            state_dim=256,
            make_state_batch=make_state_batch,
            env_step_fn=env_step_wrapper,
            policy_act_mode=ppo_mode,
            switch_epoch=a.ppo_switch_epoch,
            ppo_epochs=a.ppo_epochs, episodes=a.ppo_episodes,
            discount_gamma=a.discount_gamma, multistep_steps=a.multistep_steps,
            log_step_rewards=a.log_step_rewards, step_log_writer=step_log_writer, setup_name=setup_name
        )
        after = eval_loader_metrics_seq(model, tsdl, d)
        _record_improvement(setup_name, before, after, results, improvements)


def _run_pbmc_pipeline(a, d, step_log_writer, results, improvements):
    """Train VectorNet once on PBMC fused vectors, then run PPO per mode; False if muon is missing."""
    if not HAS_MUON:
        print("❌ --data_mode pbmc requires muon + mudatasets installed.")
        return False
    fused_train, y_train, fused_test, y_test = pbmc_build_embeddings(sample_size=a.pbmc_sample, seed=1)

    # Build datasets and carve a shuffled validation subset (at least one row) off the train set.
    tr_ds = VectorDataset(fused_train, y_train)
    te_ds = VectorDataset(fused_test, y_test)
    idx = np.arange(len(tr_ds))
    np.random.shuffle(idx)
    v = int(max(1, round(a.val_ratio * len(idx))))
    va_idx, tr_idx = idx[:v], idx[v:]
    va_ds = torch.utils.data.Subset(tr_ds, va_idx)
    tr_sub = torch.utils.data.Subset(tr_ds, tr_idx)

    trdl = DataLoader(tr_sub, batch_size=a.batch_size, shuffle=True,  collate_fn=collate_vec)
    vadl = DataLoader(va_ds,  batch_size=a.batch_size, shuffle=False, collate_fn=collate_vec)
    tsdl = DataLoader(te_ds,  batch_size=a.batch_size, shuffle=False, collate_fn=collate_vec)

    in_dim = fused_train.shape[1]
    print(f"\n=== Training VectorNet (fixed {a.immune_act_mode}) on PBMC fused dim={in_dim} ===")
    base_model = VectorNet(in_dim=in_dim, act_mode=a.immune_act_mode).to(d)
    train_supervised_vec(base_model, trdl, vadl, d, epochs=a.epochs,
                         lr=a.base_lr, weight_decay=a.base_wd, betas=(a.base_b1,a.base_b2))
    # Snapshot so every PPO mode starts from the same trained weights.
    base_state = deepcopy(base_model.state_dict())

    before = eval_loader_metrics_vec(base_model, tsdl, d)
    print("[Before PPO] " + _format_metrics(before))

    ppo_modes = [m.strip() for m in a.ppo_modes.split(",") if m.strip()]
    for ppo_mode in ppo_modes:
        print(f"\n=== PPO Activation: {ppo_mode.upper()} (multistep={a.multistep_steps}) ===")
        model = VectorNet(in_dim=in_dim, act_mode=a.immune_act_mode).to(d)
        model.load_state_dict(base_state)

        # Redefined per mode, so the cached iterator attribute starts fresh each time.
        def make_state_batch():
            if not hasattr(make_state_batch, "iter"):
                make_state_batch.iter = iter(trdl)
            try:
                X, cytok_init, y = next(make_state_batch.iter)
            except StopIteration:
                # Loader exhausted: rearm it and return None for this call.
                make_state_batch.iter = iter(trdl)
                return None
            with torch.no_grad():
                z = extract_state_vec(model, X.to(d))
            return {"state": z, "cytok_init": cytok_init, "X": X}

        def env_step_wrapper(batch, action, device, cytok_prev=None):
            X = batch["X"].to(device)
            return env_step_vec(model, X, action, device,
                                perturb_prob=a.cytokine_perturb_prob,
                                action_mag=a.cytokine_action_mag,
                                max_spike=a.cytokine_max_spike,
                                cytok_prev=cytok_prev)

        setup_name = f"Vector({a.immune_act_mode})+PPO({ppo_mode})"
        # state_dim=256 assumes extract_state_vec yields 256-dim states — TODO confirm.
        ppo_train_multistep(
            state_dim=256,
            make_state_batch=make_state_batch,
            env_step_fn=env_step_wrapper,
            policy_act_mode=ppo_mode,
            switch_epoch=a.ppo_switch_epoch,
            ppo_epochs=a.ppo_epochs, episodes=a.ppo_episodes,
            discount_gamma=a.discount_gamma, multistep_steps=a.multistep_steps,
            log_step_rewards=a.log_step_rewards, step_log_writer=step_log_writer, setup_name=setup_name
        )
        after = eval_loader_metrics_vec(model, tsdl, d)
        _record_improvement(setup_name, before, after, results, improvements)
    return True


def _write_reports(results, improvements):
    """Write both result CSVs, try to save the improvement plot, and print the best ΔR2 setup."""
    # --- Save results ---
    # Explicit UTF-8: the improvement header contains 'Δ', which the locale
    # default encoding cannot represent on some platforms, and pandas reads
    # the file back as UTF-8 below.
    with open("gated_relu_multistep.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Setup", "MSE", "RMSE", "MAE", "R2", "Pearson"])
        writer.writerows(results)
    print("\n✅ Results saved to gated_relu_multistep.csv")

    with open("performance_improvements.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["Setup", "ΔMSE", "ΔR2", "ΔPearson"])
        for imp in improvements:
            writer.writerow([imp["Setup"], imp["ΔMSE"], imp["ΔR2"], imp["ΔPearson"]])
    print("📈 Improvement summary saved to performance_improvements.csv")

    # --- Optional plot (if matplotlib present) ---
    # Best effort: any failure (missing matplotlib/pandas, backend issues) is
    # reported and skipped rather than aborting the run.
    try:
        import matplotlib.pyplot as plt
        import pandas as pd
        if os.path.exists("performance_improvements.csv"):
            df = pd.read_csv("performance_improvements.csv")
            plt.figure(figsize=(8,4))
            for metric in ["ΔMSE", "ΔR2", "ΔPearson"]:
                plt.plot(df["Setup"], df[metric], marker='o', label=metric)
            plt.xticks(rotation=30, ha='right')
            # NOTE(review): ΔMSE is stored as before - after, so this axis
            # label is only literally accurate for ΔR2 and ΔPearson.
            plt.ylabel("Improvement (After - Before)")
            plt.title("Performance Improvement per Method")
            plt.legend()
            plt.tight_layout()
            plt.savefig("performance_improvement_plot.png", dpi=160)
            print("📊 Saved performance plot: performance_improvement_plot.png")
    except Exception as e:
        print(f"(Plotting skipped) {e}")

    # --- Best setup by ΔR2 ---
    if len(improvements) > 0:
        best = max(improvements, key=lambda x: x["ΔR2"])
        print(f"🏆 Best R2 improvement: {best['Setup']} (ΔR2={best['ΔR2']:.4f})")


def main():
    """Command-line entry point: train the base net, run PPO per mode, and write reports.

    Flow: parse arguments, seed, bail out if PyTorch is unavailable, apply any
    failure preset, optionally open the per-step reward log (append mode),
    run the peptide or PBMC pipeline, then write gated_relu_multistep.csv,
    performance_improvements.csv and (best effort) a plot. The per-step log
    file is closed on every exit path, including early returns and errors.
    """
    import sys
    # Drop the whole argv tail when a "-f..." or "...kernel..." argument is
    # present. NOTE(review): presumably the arguments Jupyter passes to a
    # kernel process — confirm; any user argument containing "kernel" also
    # triggers this reset.
    if any(arg.startswith("-f") or "kernel" in arg for arg in sys.argv):
        sys.argv = [sys.argv[0]]

    a, _ = _build_arg_parser().parse_known_args()
    seed_everything(1)

    if not HAS_TORCH:
        print("⚠️ PyTorch not available.")
        return
    d = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Apply failure presets (affects batch size, env noise, and base optimizer)
    _apply_failure_preset(a)

    step_log_writer = None
    step_log_file = None
    try:
        # Step reward logging (optional)
        if a.log_step_rewards:
            step_log_path = "ppo_multistep_rewards.csv"
            # Header is written only when the file did not exist before this run.
            new_file = not os.path.exists(step_log_path)
            step_log_file = open(step_log_path, "a", newline="", encoding="utf-8")
            step_log_writer = csv.writer(step_log_file)
            if new_file:
                step_log_writer.writerow(["Setup","Episode","Step","RewardMean"])

        results = []
        improvements = []

        if a.data_mode == "peptide":
            # ======== PEPTIDE PIPELINE ========
            _run_peptide_pipeline(a, d, step_log_writer, results, improvements)
        else:
            # ======== PBMC PIPELINE ========
            if not _run_pbmc_pipeline(a, d, step_log_writer, results, improvements):
                return

        _write_reports(results, improvements)
    finally:
        # Runs on success, on the early PBMC return and on exceptions, so the
        # log handle is never leaked.
        if step_log_file is not None:
            step_log_file.close()

    if step_log_writer is not None:
        print("✅ Per-step curves saved to ppo_multistep_rewards.csv")


if __name__ == "__main__":
    # Script entry point: run main() and swallow SystemExit (which argparse
    # raises for --help or an invalid option value) instead of propagating it.
    # NOTE(review): presumably done so a notebook cell does not surface a
    # SystemExit traceback — confirm. A swallowed SystemExit also leaves the
    # process exit status at 0 when the script is run from a shell.
    try:
        main()
    except SystemExit:
        pass



=== Training ImmuneNet (fixed convex) ===
[Epoch 1] Val MSE=0.0745, R2=-0.1134
[Epoch 2] Val MSE=0.0640, R2=0.0430
[Epoch 3] Val MSE=0.0638, R2=0.0462
[Epoch 4] Val MSE=0.0622, R2=0.0707
