# In [None]:  (Jupyter notebook cell marker left over from export)
"""
Immune RL with PPO + Gated ReLU (Fixed ImmuneNet Initialization + PPO-Specific Activation)
------------------------------------------------------------------------------------------
- ImmuneNet is trained once with a fixed activation (convex or nonconvex).
- Its trained weights are reused identically across all PPO runs.
- PPO actor/critic activation varies per run: convex, nonconvex, or twostage.
- Nonconvex loss is used for ImmuneNet and PPO critic.
- Results saved to gated_relu_comparison.csv
"""

from __future__ import annotations
import csv, argparse, random
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
from scipy.stats import pearsonr
from copy import deepcopy

# =========================
# Optional PyTorch
# =========================
# HAS_TORCH records whether the PyTorch imports below succeeded. Every
# torch-dependent definition in this file is guarded by `if HAS_TORCH:`.
HAS_TORCH = True
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader, Subset
except Exception:
    # Any failure while importing torch (not only ImportError) disables the
    # torch-dependent code paths instead of aborting the module import.
    HAS_TORCH = False

# =========================
# Constants & utils
# =========================
# The 20 single-letter amino-acid codes, in the order used for token ids.
AA_VOCAB = [*"ACDEFGHIKLMNPQRSTVWY"]
# Token ids start at 1; id 0 is what SeqTokenizer uses for padding/unknown.
AA_TO_ID = dict(zip(AA_VOCAB, range(1, len(AA_VOCAB) + 1)))
# Discrete action set of the PPO policy; position in the list is the action id.
CYTOKINES = ["NONE", "IL2", "IFNG", "IL10", "TNFA"]
CYTOKINE_TO_ID = {name: position for position, name in enumerate(CYTOKINES)}

def seed_everything(seed=42):
    """Seed the Python, NumPy and (when importable) PyTorch random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    if not HAS_TORCH:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

@dataclass
class Example:
    """One peptide/allele record with its regression target."""
    # Amino-acid sequence of the peptide (first column of the input table).
    peptide: str
    # Allele name (third column of the input table); mapped to an id via allele_to_id.
    allele: str
    # Numeric target parsed from the second column of the input table.
    score: float
    # Optional motif annotation; not read anywhere in the visible code.
    motif: Optional[str] = None
    # Optional TCR sequence; PeptideDataset substitutes a fixed default when it is missing.
    tcr: Optional[str] = None

# =========================
# I/O + Tokenizer
# =========================
def smart_read_table(path: str) -> List[List[str]]:
    """Read a delimited text file into a list of stripped string rows.

    The delimiter is sniffed from the first 2048 characters among tab, comma
    and semicolon. If sniffing fails for any reason the standard comma
    dialect is used. Rows that the csv reader yields as empty are dropped.

    Args:
        path: Path of the text file to read.

    Returns:
        One list of whitespace-stripped cell strings per non-empty row.
    """
    with open(path, "r", newline="") as handle:
        head = handle.read(2048)
        handle.seek(0)  # rewind so the reader sees the sniffed sample again
        try:
            fmt = csv.Sniffer().sniff(head, delimiters="\t,;")
        except Exception:
            # Best-effort fallback: plain comma-separated parsing.
            fmt = csv.excel
        return [
            [cell.strip() for cell in record]
            for record in csv.reader(handle, fmt)
            if record
        ]

def load_alleles(path: str) -> Dict[str, int]:
    """Build an allele-name -> integer-id mapping from an allele list file.

    Every cell of the table is split on commas and whitespace; each distinct
    token is treated as an allele name. Ids are assigned in sorted name
    order, starting at 0.

    Args:
        path: Path of the allele list file (read with ``smart_read_table``).

    Returns:
        Dictionary mapping each distinct token to its id.
    """
    # A set makes de-duplication O(1) per token (the list scan was O(n) each);
    # the final ordering comes from sorted(), so the result is unchanged.
    tokens = {
        token
        for row in smart_read_table(path)
        for cell in row
        for token in cell.replace(",", " ").split()
    }
    return {allele: idx for idx, allele in enumerate(sorted(tokens))}

def parse_examples(path: str, allele_to_id: Dict[str, int]) -> List[Example]:
    """Parse ``peptide, score, allele`` rows from a delimited file.

    Rows with fewer than three columns, or whose second column is not a
    number (for example a header line), are skipped.

    Note:
        ``allele_to_id`` is updated in place: an allele that is not yet in
        the mapping is registered with the next free id
        (``len(allele_to_id)``).

    Args:
        path: Path of the table to read (via ``smart_read_table``).
        allele_to_id: Existing allele -> id mapping; extended as described.

    Returns:
        The parsed examples, in file order.
    """
    examples = []
    for row in smart_read_table(path):
        if len(row) < 3:
            continue
        peptide, score_text, allele = row[:3]
        try:
            score = float(score_text)
        except ValueError:
            # Only a non-numeric score is an expected condition here; any
            # other error should surface instead of being swallowed.
            continue
        allele_to_id.setdefault(allele, len(allele_to_id))
        examples.append(Example(peptide=peptide, allele=allele, score=score))
    return examples

class SeqTokenizer:
    """Fixed-length integer encoder for amino-acid strings."""

    def __init__(self, max_len=32):
        # Every encoded sequence is truncated or zero-padded to this length.
        self.max_len = max_len

    def encode_ids(self, s: str):
        """Return ``max_len`` ids for ``s``; characters not in AA_TO_ID map to 0."""
        limit = self.max_len
        ids = [AA_TO_ID.get(residue, 0) for residue in s[:limit]]
        # Right-pad with 0; a non-positive count adds nothing.
        ids.extend([0] * (limit - len(ids)))
        return ids

    def encode(self, s: str):
        """Encode ``s`` as an int64 array, or a long tensor when torch is available."""
        arr = np.asarray(self.encode_ids(s), dtype=np.int64)
        if HAS_TORCH:
            return torch.tensor(arr, dtype=torch.long)
        return arr

# =========================
# Loss + Metrics
# =========================
def nonconvex_loss(pred, target, eps: float = 1e-6):
    """Return the mean of ``sqrt(|pred - target| + eps)`` over all elements.

    Args:
        pred: Predicted tensor.
        target: Target tensor, broadcastable against ``pred``.
        eps: Constant added under the square root so the value stays
            differentiable when the error is exactly zero.
    """
    abs_err = (pred - target).abs()
    return (abs_err + eps).sqrt().mean()

def regression_metrics(preds: np.ndarray, targets: np.ndarray) -> Dict[str, float]:
    """Compute MSE, RMSE, MAE, R2 and Pearson correlation for flattened inputs.

    Args:
        preds: Predicted values (any shape; flattened).
        targets: Ground-truth values (any shape; flattened).

    Returns:
        Dictionary with keys ``MSE``, ``RMSE``, ``MAE``, ``R2`` and
        ``Pearson``. Pearson is reported as 0.0 when there are fewer than two
        points or either input has zero standard deviation.
    """
    p = np.asarray(preds).flatten()
    t = np.asarray(targets).flatten()

    residual = p - t
    mse = float(np.mean(residual ** 2))
    mae = float(np.mean(np.abs(residual)))

    # The 1e-12 keeps the division finite when all targets are identical.
    ss_tot = float(np.sum((t - np.mean(t)) ** 2) + 1e-12)
    ss_res = np.sum((t - p) ** 2)
    r2 = float(1.0 - ss_res / ss_tot)

    if len(p) > 1 and np.std(p) > 0 and np.std(t) > 0:
        pear = float(pearsonr(p, t)[0])
    else:
        pear = 0.0

    return {
        "MSE": mse,
        "RMSE": float(np.sqrt(mse)),
        "MAE": mae,
        "R2": r2,
        "Pearson": pear,
    }

# =========================
# Gated ReLU + Models
# =========================
if HAS_TORCH:
    class GatedReLU(nn.Module):
        """Elementwise gate between ``x`` and a learned affine map ``a * x + b``.

        ``mode`` selects how the two branches are combined:
          * ``"convex"``    -> elementwise maximum of the two branches
          * ``"nonconvex"`` -> elementwise minimum of the two branches
          * ``"twostage"``  -> maximum until the switch, minimum afterwards

        ``a`` starts at ones and ``b`` at zeros, so both branches coincide at
        initialisation.
        """

        def __init__(self, dim: int, mode="convex", switch_epoch=5):
            super().__init__()
            self.mode = mode
            self.switch_epoch = switch_epoch
            # Latched flag: once set, "twostage" always uses the minimum branch.
            self.switched = False
            self.a = nn.Parameter(torch.ones(dim))
            self.b = nn.Parameter(torch.zeros(dim))
            # Non-persistent buffer: follows .to(device) but is not stored in state_dict.
            self.register_buffer("_epoch", torch.zeros(1, dtype=torch.long), persistent=False)

        def set_epoch(self, epoch: int):
            """Record the current epoch/episode index used by ``"twostage"``."""
            self._epoch[0] = epoch

        def trigger_switch(self):
            """Latch the switch to the minimum branch (announced once)."""
            if self.switched:
                return
            print(f"⚡ GatedReLU switched to nonconvex at episode {self._epoch.item()}")
            self.switched = True

        def _affine(self, x):
            # Shared learned branch used by both gate variants.
            return self.a * x + self.b

        def _convex(self, x):
            return torch.maximum(self._affine(x), x)

        def _nonconvex(self, x):
            return torch.minimum(self._affine(x), x)

        def forward(self, x):
            mode = self.mode
            if mode == "convex":
                return self._convex(x)
            if mode == "nonconvex":
                return self._nonconvex(x)
            if mode != "twostage":
                raise ValueError(f"Unknown mode: {self.mode}")
            # Either the latch or the recorded epoch can activate the second stage.
            second_stage = self.switched or (self._epoch.item() >= self.switch_epoch)
            return self._nonconvex(x) if second_stage else self._convex(x)

    class MiniGAT(nn.Module):
        """Small sequence encoder: token + position embeddings into a Transformer encoder.

        Despite the name, the layers are ``nn.TransformerEncoderLayer``s. No
        padding mask is passed to the encoder, so every position (including
        id-0 tokens) takes part in attention and in the mean pooling.
        """

        def __init__(self, vocab_size, dim, max_len=32, heads=4, layers=2):
            super().__init__()
            self.max_len = max_len
            # One extra embedding row so that id 0 is a valid index.
            self.emb = nn.Embedding(vocab_size + 1, dim)
            self.pos = nn.Embedding(max_len, dim)
            layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
            self.encoder = nn.TransformerEncoder(layer, num_layers=layers)

        def forward(self, x):
            """Encode id tensor ``x`` (batch, seq) into a (batch, 2 * dim) vector."""
            batch = x.size(0)
            # Columns beyond max_len are ignored; there is no position embedding for them.
            seq = min(x.size(1), self.max_len)
            positions = torch.arange(seq, device=x.device).unsqueeze(0).expand(batch, seq)
            tokens = self.emb(x[:, :seq])
            hidden = self.encoder(tokens + self.pos(positions))
            pooled = hidden.mean(dim=1)
            first = hidden[:, 0, :]
            # Concatenate the mean over positions with the first position's state.
            return torch.cat([pooled, first], dim=-1)

    class ImmuneNet(nn.Module):
        """Peptide/TCR/allele model with binding, recognition and response heads.

        A shared backbone fuses the two sequence encodings with an allele
        embedding into a 256-wide state. ``binding`` and ``recognition`` are
        sigmoid heads on that state; ``response`` additionally consumes an
        embedded cytokine one-hot vector and is unbounded.
        """

        def __init__(self, vocab_size, allele_count, dim=128, pep_len=32, tcr_len=24,
                     act_mode="convex"):
            super().__init__()
            hid = 256
            self.pep_enc = MiniGAT(vocab_size, dim, pep_len)
            self.tcr_enc = MiniGAT(vocab_size, dim, tcr_len)
            self.all_emb = nn.Embedding(allele_count + 1, dim)
            # Each MiniGAT emits 2 * dim features; the allele embedding adds dim.
            fused_dim = 5 * dim
            self.backbone = nn.Sequential(
                nn.Linear(fused_dim, hid),
                GatedReLU(hid, act_mode),
                nn.Linear(hid, hid),
                GatedReLU(hid, act_mode),
            )
            self.binding = nn.Linear(hid, 1)
            self.recognition = nn.Linear(hid, 1)
            self.cyt_fc = nn.Linear(len(CYTOKINES), 32)
            self.response = nn.Sequential(
                nn.Linear(hid + 32, 128),
                nn.ReLU(),
                nn.Linear(128, 1),
            )

        def encode_backbone(self, pep, tcr, allele):
            """Return the shared backbone state for a batch of inputs."""
            parts = [self.pep_enc(pep), self.tcr_enc(tcr), self.all_emb(allele)]
            return self.backbone(torch.cat(parts, dim=-1))

        def forward(self, pep, tcr, allele, cytok_onehot):
            """Return ``(bind, recog, resp)`` for a batch.

            ``bind`` and ``recog`` are sigmoid outputs; ``resp`` is the raw
            output of the response head.
            """
            state = self.encode_backbone(pep, tcr, allele)
            bind = torch.sigmoid(self.binding(state))
            recog = torch.sigmoid(self.recognition(state))
            cyt_feat = F.relu(self.cyt_fc(cytok_onehot))
            resp = self.response(torch.cat([state, cyt_feat], dim=-1))
            return bind, recog, resp

# =========================
# PPO
# =========================
if HAS_TORCH:
    class PPOPolicy(nn.Module):
        """Actor-critic pair of one-hidden-layer MLPs using GatedReLU activations."""

        def __init__(self, input_dim, num_actions, hidden=256, act_mode="convex", switch_epoch=5):
            super().__init__()
            self.actor = nn.Sequential(
                nn.Linear(input_dim, hidden),
                GatedReLU(hidden, act_mode, switch_epoch),
                nn.Linear(hidden, num_actions),
            )
            self.critic = nn.Sequential(
                nn.Linear(input_dim, hidden),
                GatedReLU(hidden, act_mode, switch_epoch),
                nn.Linear(hidden, 1),
            )
            self.act_mode = act_mode
            self.switch_epoch = switch_epoch

        def set_epoch(self, epoch: int):
            """Forward ``epoch`` to every GatedReLU and latch the two-stage switch when due."""
            switch_due = self.act_mode == "twostage" and epoch >= self.switch_epoch
            for gate in (m for m in self.modules() if isinstance(m, GatedReLU)):
                gate.set_epoch(epoch)
                if switch_due and not gate.switched:
                    gate.trigger_switch()

        def act(self, x):
            """Sample an action per row of ``x``.

            Returns:
                Tuple ``(action, log_prob, value)``, where ``value`` is the
                critic output with its last dimension squeezed.
            """
            dist = torch.distributions.Categorical(logits=self.actor(x))
            action = dist.sample()
            log_prob = dist.log_prob(action)
            value = self.critic(x).squeeze(-1)
            return action, log_prob, value

    @torch.no_grad()
    def extract_state(model, pep, tcr, allele):
        """Return ``model.encode_backbone(pep, tcr, allele)`` as a gradient-free tensor."""
        state = model.encode_backbone(pep, tcr, allele)
        return state.detach()

    @torch.no_grad()
    def env_step(model, pep, tcr, allele, act, device):
        """Score the chosen cytokine actions with ``model`` and return a reward per row.

        ``act`` holds one cytokine index per row; it is turned into a one-hot
        vector and fed to the model. The reward is
        ``0.7 * response + 0.3 * recognition``.
        """
        batch = pep.size(0)
        onehot = torch.zeros((batch, len(CYTOKINES)), device=device)
        rows = torch.arange(batch)
        onehot[rows, act] = 1.0
        _bind, recog, resp = model(pep, tcr, allele, onehot)
        reward = 0.7 * resp + 0.3 * recog
        return reward.squeeze(-1).detach()

    def ppo_train(model, loader, device, ppo_epochs=3, episodes=5, clip=0.2, act_mode="convex", switch_epoch=5):
        """Train a fresh PPOPolicy against ``model`` used as a fixed environment.

        Per episode, one pass over ``loader`` collects states (backbone
        features), sampled actions, their log-probabilities and the rewards
        from ``env_step``; the policy is then updated ``ppo_epochs`` times
        with the clipped PPO surrogate plus ``0.5 * nonconvex_loss`` on the
        critic.

        Only the policy's parameters are optimised. ``model`` is evaluated
        under ``torch.no_grad()`` and its weights are never modified here.

        Args:
            model: Environment model exposing ``encode_backbone`` and a
                forward returning ``(bind, recog, resp)``. Its backbone output
                must be 256 wide to match the policy input size.
            loader: Iterable of ``(pep, tcr, allele_idx, cytok, y)`` batches.
            device: Device for the policy and the rollout tensors.
            ppo_epochs: Gradient steps per episode on the collected rollout.
            episodes: Number of rollout/update rounds.
            clip: PPO probability-ratio clipping range.
            act_mode: GatedReLU mode for the policy networks.
            switch_epoch: Episode index at which ``"twostage"`` switches.

        Raises:
            ValueError: If ``loader`` yields no batches.
        """
        policy = PPOPolicy(256, len(CYTOKINES), hidden=256, act_mode=act_mode, switch_epoch=switch_epoch).to(device)
        opt = torch.optim.Adam(policy.parameters(), lr=5e-4)

        # The environment must be deterministic: in train mode the dropout in
        # the model's transformer layers would make the state from
        # extract_state and the reward from env_step come from two different
        # random sub-networks. Switch to eval for the rollouts and restore the
        # caller's mode afterwards.
        was_training = model.training
        model.eval()
        try:
            for ep in range(episodes):
                policy.set_epoch(ep)
                states, actions, old_logp, returns, adv = [], [], [], [], []
                for pep, tcr, all_idx, _cytok, _y in loader:
                    pep, tcr, all_idx = pep.to(device), tcr.to(device), all_idx.to(device)
                    with torch.no_grad():
                        z = extract_state(model, pep, tcr, all_idx)
                        a, logp, v = policy.act(z)
                        G = env_step(model, pep, tcr, all_idx, a, device)
                        A = (G - v).detach()
                    states.append(z); actions.append(a); old_logp.append(logp); returns.append(G); adv.append(A)
                if not states:
                    raise ValueError("ppo_train: loader yielded no batches")
                states, actions = torch.cat(states), torch.cat(actions)
                old_logp, returns, adv = torch.cat(old_logp), torch.cat(returns), torch.cat(adv)
                # std() of a single element is NaN, which would poison every
                # parameter on the first update; normalise only when defined.
                if adv.numel() > 1:
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                for _ in range(ppo_epochs):
                    logits = policy.actor(states)
                    dist = torch.distributions.Categorical(logits=logits)
                    new_logp = dist.log_prob(actions)
                    ratio = torch.exp(new_logp - old_logp)
                    surr1 = ratio * adv
                    surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * adv
                    actor_loss = -torch.min(surr1, surr2).mean()
                    critic_loss = nonconvex_loss(policy.critic(states).squeeze(-1), returns)
                    loss = actor_loss + 0.5*critic_loss
                    opt.zero_grad(); loss.backward(); opt.step()
                print(f"[PPO] Episode {ep+1}/{episodes} — Return={returns.mean().item():.4f}")
        finally:
            model.train(was_training)

# =========================
# Dataset + Training + Eval
# =========================
if HAS_TORCH:
    class PeptideDataset(Dataset):
        """Dataset yielding ``(peptide_ids, tcr_ids, allele_id, score)`` per example."""

        def __init__(self, examples, allele_to_id, pep_len=32, tcr_len=24):
            self.examples = examples
            self.allele_to_id = allele_to_id
            self.tok_p = SeqTokenizer(pep_len)
            self.tok_t = SeqTokenizer(tcr_len)

        def __len__(self):
            return len(self.examples)

        def __getitem__(self, idx):
            ex = self.examples[idx]
            # Examples without a TCR sequence all share this fixed placeholder.
            tcr_seq = ex.tcr if ex.tcr else "CASSIRSSYEQYF"
            pep_ids = self.tok_p.encode(ex.peptide)
            tcr_ids = self.tok_t.encode(tcr_seq)
            # Alleles missing from the mapping fall back to id 0.
            allele_id = self.allele_to_id.get(ex.allele, 0)
            return pep_ids, tcr_ids, allele_id, float(ex.score)

    def collate(batch):
        """Stack dataset items into batch tensors and attach a "NONE" cytokine one-hot.

        Returns:
            Tuple ``(pep, tcr, all_idx, cytok, y)``: ``y`` has a trailing unit
            dimension and ``cytok`` has the "NONE" column set to 1 in every row.
        """
        peps, tcrs, allele_ids, scores = zip(*batch)
        pep = torch.stack(peps)
        tcr = torch.stack(tcrs)
        all_idx = torch.tensor(allele_ids, dtype=torch.long)
        y = torch.tensor(scores, dtype=torch.float32).reshape(-1, 1)
        none_col = CYTOKINE_TO_ID["NONE"]
        cytok = torch.zeros((pep.size(0), len(CYTOKINES)), dtype=torch.float32)
        cytok[:, none_col] = 1.0
        return pep, tcr, all_idx, cytok, y

    @torch.no_grad()
    def eval_loader_metrics(model, loader, device):
        """Evaluate the binding head of ``model`` over ``loader``.

        Puts ``model`` in eval mode (and leaves it there), collects the first
        model output for every batch and returns ``regression_metrics``
        against the batch targets.
        """
        model.eval()
        pred_chunks, target_chunks = [], []
        for pep, tcr, all_idx, cytok, y in loader:
            bind, _recog, _resp = model(
                pep.to(device), tcr.to(device), all_idx.to(device), cytok.to(device)
            )
            pred_chunks.append(bind.cpu().numpy())
            target_chunks.append(y.numpy())
        # An empty loader yields empty arrays, exactly like building from empty lists.
        preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        targets = np.concatenate(target_chunks) if target_chunks else np.array([])
        return regression_metrics(preds, targets)

    def train_supervised(model, train_loader, val_loader, device, epochs=10):
        """Train the binding head of ``model`` with ``nonconvex_loss`` and Adam.

        After every epoch the model is scored on ``val_loader``. The weights
        of the epoch with the lowest validation MSE are kept and restored at
        the end, so a late divergence does not become the model that callers
        snapshot and reuse.

        Args:
            model: Model whose forward returns ``(bind, recog, resp)``.
            train_loader: Batches of ``(pep, tcr, allele_idx, cytok, y)``.
            val_loader: Validation batches in the same format.
            device: Device the batches are moved to.
            epochs: Number of passes over ``train_loader``.
        """
        opt = torch.optim.Adam(model.parameters(), lr=1e-3)
        best_mse, best_epoch, best_state = float("inf"), None, None
        for ep in range(epochs):
            model.train()
            for pep, tcr, all_idx, cytok, y in train_loader:
                pep, tcr, all_idx, cytok, y = pep.to(device), tcr.to(device), all_idx.to(device), cytok.to(device), y.to(device)
                bind, _, _ = model(pep, tcr, all_idx, cytok)
                loss = nonconvex_loss(bind, y)
                opt.zero_grad(); loss.backward(); opt.step()
            vm = eval_loader_metrics(model, val_loader, device)
            print(f"[Epoch {ep+1}] Val MSE={vm['MSE']:.4f}, R2={vm['R2']:.4f}")
            # NaN never compares smaller, so a diverged epoch cannot become "best".
            if vm["MSE"] < best_mse:
                best_mse, best_epoch = vm["MSE"], ep + 1
                # deepcopy: state_dict() returns references to the live tensors.
                best_state = deepcopy(model.state_dict())
        if best_state is not None and best_epoch != epochs:
            model.load_state_dict(best_state)
            print(f"[Train] Restored best checkpoint from epoch {best_epoch} (Val MSE={best_mse:.4f})")

# =========================
# Main
# =========================
def main():
    """Train ImmuneNet once, run PPO with each activation mode, and write a CSV.

    Steps:
      1. Parse command-line options (unknown options are ignored).
      2. Load alleles and train/test examples; split the training set into
         train/validation subsets.
      3. Train one ImmuneNet and snapshot its weights.
      4. For each PPO activation mode, rebuild the model from the snapshot,
         run ``ppo_train`` and evaluate on the test set.
      5. Write the per-mode metrics to ``gated_relu_comparison.csv``.

    The RNGs are reseeded with ``--seed`` before every PPO run so that policy
    initialisation, batch shuffling and action sampling start from the same
    state for every mode; the activation is then the only intended difference.
    """
    import sys
    # When launched from a notebook kernel, drop the kernel's own arguments.
    if any(a.startswith("-f") or "kernel" in a for a in sys.argv):
        sys.argv = [sys.argv[0]]

    p = argparse.ArgumentParser()
    p.add_argument('--train', default='train_BA1.txt')
    p.add_argument('--test', default='test_BA1.txt')
    p.add_argument('--alleles', default='allelelist.txt')
    p.add_argument('--epochs', type=int, default=8)
    p.add_argument('--batch_size', type=int, default=64)
    p.add_argument('--ppo_episodes', type=int, default=20)
    p.add_argument('--ppo_epochs', type=int, default=3)
    p.add_argument('--pep_len', type=int, default=32)
    p.add_argument('--tcr_len', type=int, default=24)
    p.add_argument('--immune_act_mode', type=str, default='convex', choices=['convex','nonconvex'])
    p.add_argument('--ppo_switch_epoch', type=int, default=4)
    p.add_argument('--val_ratio', type=float, default=0.2)
    p.add_argument('--seed', type=int, default=1)
    a, _ = p.parse_known_args()

    seed_everything(a.seed)
    if not HAS_TORCH:
        print("⚠️ PyTorch not available."); return

    d = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    allele_to_id = load_alleles(a.alleles)
    # parse_examples also registers alleles missing from the allele list, so
    # the model below must be built only after both files have been parsed.
    tr = parse_examples(a.train, allele_to_id)
    ts = parse_examples(a.test, allele_to_id)

    full_train_ds = PeptideDataset(tr, allele_to_id, a.pep_len, a.tcr_len)
    N = len(full_train_ds)
    if N < 2:
        print(f"Need at least 2 training examples for a train/validation split; found {N} in {a.train}.")
        return
    idx = list(range(N)); random.shuffle(idx)
    # At least one validation example, and always at least one left for training.
    n_val = min(N - 1, int(max(1, round(a.val_ratio * N))))
    tr_ds = Subset(full_train_ds, idx[n_val:])
    va_ds = Subset(full_train_ds, idx[:n_val])
    te_ds = PeptideDataset(ts, allele_to_id, a.pep_len, a.tcr_len)

    trdl = DataLoader(tr_ds, batch_size=a.batch_size, shuffle=True, collate_fn=collate)
    vadl = DataLoader(va_ds, batch_size=a.batch_size, shuffle=False, collate_fn=collate)
    tsdl = DataLoader(te_ds, batch_size=a.batch_size, shuffle=False, collate_fn=collate)

    # Step 1: train ImmuneNet once and snapshot its weights.
    print(f"\n=== Training ImmuneNet (fixed {a.immune_act_mode}) ===")
    base_model = ImmuneNet(len(AA_VOCAB), len(allele_to_id),
                           act_mode=a.immune_act_mode,
                           pep_len=a.pep_len, tcr_len=a.tcr_len).to(d)
    train_supervised(base_model, trdl, vadl, d, epochs=a.epochs)
    base_state = deepcopy(base_model.state_dict())  # fixed initialization snapshot

    before = eval_loader_metrics(base_model, tsdl, d)
    print("[Before PPO] " + ", ".join([f"{k}={v:.4f}" for k, v in before.items()]))

    # Step 2: PPO runs that differ only in the policy activation mode.
    ppo_modes = ["convex", "nonconvex", "twostage"]
    results = []

    for ppo_mode in ppo_modes:
        print(f"\n=== PPO Activation: {ppo_mode.upper()} ===")
        # Same RNG state for every mode: identical policy init and sampling stream.
        seed_everything(a.seed)
        model = ImmuneNet(len(AA_VOCAB), len(allele_to_id),
                          act_mode=a.immune_act_mode,
                          pep_len=a.pep_len, tcr_len=a.tcr_len).to(d)
        model.load_state_dict(base_state)  # identical weights

        ppo_train(model, trdl, d,
                  ppo_epochs=a.ppo_epochs,
                  episodes=a.ppo_episodes,
                  act_mode=ppo_mode,
                  switch_epoch=a.ppo_switch_epoch)

        # NOTE(review): ppo_train optimises only its own policy network, so
        # these metrics are computed on the same ImmuneNet weights as
        # `before` -- confirm this is the intended comparison.
        after = eval_loader_metrics(model, tsdl, d)
        print("[After PPO] " + ", ".join([f"{k}={v:.4f}" for k, v in after.items()]))
        results.append([f"Immune({a.immune_act_mode}) + PPO({ppo_mode})",
                        after["MSE"], after["RMSE"], after["MAE"], after["R2"], after["Pearson"]])

    with open("gated_relu_comparison.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["Setup", "MSE", "RMSE", "MAE", "R2", "Pearson"])
        writer.writerows(results)
    print("\n✅ Results saved to gated_relu_comparison.csv")

# Script entry point.
if __name__ == "__main__":
    try:
        main()
    except SystemExit:
        # Swallow SystemExit raised inside main() (e.g. by argparse) so the
        # hosting interpreter is not terminated -- presumably for notebook use.
        pass



=== Training ImmuneNet (fixed convex) ===
[Epoch 1] Val MSE=0.0745, R2=-0.1134
[Epoch 2] Val MSE=0.0640, R2=0.0430
[Epoch 3] Val MSE=0.0638, R2=0.0462
[Epoch 4] Val MSE=0.0622, R2=0.0707
[Epoch 5] Val MSE=0.4577, R2=-5.8401
[Epoch 6] Val MSE=0.4577, R2=-5.8401
[Epoch 7] Val MSE=0.4577, R2=-5.8401
[Epoch 8] Val MSE=0.4577, R2=-5.8401
[Before PPO] MSE=0.4382, RMSE=0.6619, MAE=0.6062, R2=-5.2021, Pearson=0.0000

=== PPO Activation: CONVEX ===
[PPO] Episode 1/20 — Return=-0.1873
[PPO] Episode 2/20 — Return=-0.1861
[PPO] Episode 3/20 — Return=-0.1862
[PPO] Episode 4/20 — Return=-0.1861
[PPO] Episode 5/20 — Return=-0.1862
[PPO] Episode 6/20 — Return=-0.1861
[PPO] Episode 7/20 — Return=-0.1864
[PPO] Episode 8/20 — Return=-0.1862
[PPO] Episode 9/20 — Return=-0.1864
[PPO] Episode 10/20 — Return=-0.1863
[PPO] Episode 11/20 — Return=-0.1862
[PPO] Episode 12/20 — Return=-0.1860
[PPO] Episode 13/20 — Return=-0.1863
[PPO] Episode 14/20 — Return=-0.1862
[PPO] Episode 15/20 — Return=-0.1863
[PPO] Epi