In [4]:
"""
Immune RL + PPO Distillation (Dynamic Cytokine Control Comparison)
------------------------------------------------------------------
- One frozen ImmuneNet encoder (linear-only)
- PPO Teacher→Student distillation across modes: convex / nonconvex / twostage
- Dynamic evaluation (Option 2): cytokine control behavior during rollouts
- Saves: ppo_dynamic_cytokine_comparison.csv
"""

from __future__ import annotations
import os, csv, argparse, random, warnings, math
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
from scipy.stats import pearsonr

warnings.filterwarnings("ignore", category=RuntimeWarning)

# ============================================================
# Torch setup
# ============================================================
HAS_TORCH = True
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader, Subset
except Exception:
    HAS_TORCH = False

def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed)
    if HAS_TORCH:
        torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def get_device():
    if not HAS_TORCH: return None
    if torch.cuda.is_available():
        d = torch.device("cuda")
        print(f"🟢 Using CUDA: {torch.cuda.get_device_name(0)}")
        return d
    print("🟡 Using CPU.")
    return torch.device("cpu")

# ============================================================
# Constants
# ============================================================
AA_VOCAB = list("ACDEFGHIKLMNPQRSTVWY")
AA_TO_ID = {a: i + 1 for i, a in enumerate(AA_VOCAB)}
CYTOKINES = ["NONE", "IL2", "IFNG", "IL10", "TNFA"]
CYTOKINE_TO_ID = {c: i for i, c in enumerate(CYTOKINES)}
IDX_NONE = CYTOKINE_TO_ID["NONE"]

# ============================================================
# Data utilities
# ============================================================
@dataclass
class Example:
    peptide: str
    allele: str
    score: float
    tcr: Optional[str] = None

def smart_read_table(path: str) -> List[List[str]]:
    rows = []
    with open(path, "r", newline="") as f:
        sample = f.read(2048); f.seek(0)
        import csv as _csv
        try: dialect = _csv.Sniffer().sniff(sample, delimiters="\t,;")
        except Exception:
            class dialect: delimiter = ","
        reader = _csv.reader(f, dialect)
        for row in reader:
            if row: rows.append([c.strip() for c in row])
    return rows

def load_alleles(path: str) -> Dict[str, int]:
    uniq = []
    for r in smart_read_table(path):
        for c in r:
            for token in c.replace(",", " ").split():
                if token and token not in uniq:
                    uniq.append(token)
    return {a: i for i, a in enumerate(sorted(uniq))}

def parse_examples(path: str, allele_to_id: Dict[str, int]) -> List[Example]:
    rows = smart_read_table(path); exs = []
    for r in rows:
        if len(r) < 3: continue
        pep, score_str, allele = r[0], r[1], r[2]
        try: score = float(score_str)
        except Exception: continue
        if allele not in allele_to_id:
            allele_to_id[allele] = len(allele_to_id)
        exs.append(Example(peptide=pep, allele=allele, score=score))
    return exs

class SeqTokenizer:
    def __init__(self, max_len=32): self.max_len=max_len
    def encode(self, s: str):
        ids = [AA_TO_ID.get(ch, 0) for ch in s[:self.max_len]]
        if len(ids) < self.max_len: ids += [0]*(self.max_len - len(ids))
        return torch.tensor(ids, dtype=torch.long)

class PeptideDataset(Dataset):
    def __init__(self, exs, allele_to_id, pep_len=32, tcr_len=24):
        self.exs = exs; self.allele_to_id = allele_to_id
        self.tok_p = SeqTokenizer(pep_len); self.tok_t = SeqTokenizer(tcr_len)
    def __len__(self): return len(self.exs)
    def __getitem__(self, i):
        e = self.exs[i]
        pep = self.tok_p.encode(e.peptide)
        tcr = self.tok_t.encode(e.tcr or "CASSIRSSYEQYF")
        all_idx = self.allele_to_id.get(e.allele, 0)
        return pep, tcr, all_idx, float(e.score)

def collate_pep(batch):
    pep, tcr, all_idx, y = zip(*batch)
    pep = torch.stack(pep); tcr = torch.stack(tcr)
    all_idx = torch.tensor(all_idx, dtype=torch.long)
    y = torch.tensor(y, dtype=torch.float32).unsqueeze(-1)
    cytok = torch.zeros((pep.size(0), len(CYTOKINES)), dtype=torch.float32)
    cytok[:, IDX_NONE] = 1.0
    return pep, tcr, all_idx, cytok, y

# ============================================================
# Loss & Metrics
# ============================================================
#def dynamic_nonconvex_loss(pred, target, epoch=0, eps=1e-6, freq=4.0, amp=0.15, basin_depth=0.1):
def dynamic_nonconvex_loss(pred, target, epoch=0, eps=1e-6):
    e = pred - target
    base = torch.sqrt(torch.abs(e) + eps)
   # ripple = amp * torch.sin(freq * e) ** 2
   # basin = basin_depth * (e ** 4 - e ** 2)
   # decay = math.exp(-0.02 * epoch)     # faster decay
    #return torch.mean(base + decay * (ripple + basin))
    return torch.mean(base)

def regression_metrics(preds, targets):
    preds = np.asarray(preds).flatten(); targets = np.asarray(targets).flatten()
    mse = np.mean((preds - targets) ** 2)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(preds - targets))
    denom = np.sum((targets - np.mean(targets)) ** 2) + 1e-12
    r2 = 1.0 - np.sum((targets - preds) ** 2) / denom
    pear = pearsonr(preds, targets)[0] if len(preds) > 1 else 0.0
    return {"MSE": float(mse), "RMSE": float(rmse), "MAE": float(mae), "R2": float(r2), "Pearson": float(pear)}

# ============================================================
# ImmuneNet (linear-only)
# ============================================================
class MiniGAT(nn.Module):
    def __init__(self, vocab_size, dim, max_len=32, heads=4, layers=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size+1, dim)
        self.pos = nn.Embedding(max_len, dim)
        enc = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc, num_layers=layers)
        self.max_len = max_len
    def forward(self, x):
        L = min(x.size(1), self.max_len)
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(x.size(0), L)
        h = self.emb(x[:, :L]) + self.pos(pos)
        h = self.encoder(h)
        return torch.cat([h.mean(dim=1), h[:, 0, :]], dim=-1)

class ImmuneNet(nn.Module):
    def __init__(self, vocab_size, allele_count, dim=128, pep_len=32, tcr_len=24):
        super().__init__()
        self.pep_enc = MiniGAT(vocab_size, dim, pep_len)
        self.tcr_enc = MiniGAT(vocab_size, dim, tcr_len)
        self.all_emb = nn.Embedding(allele_count+1, dim)
        hid = 256
        in_dim = 2*dim + 2*dim + dim
        self.backbone = nn.Sequential(nn.Linear(in_dim, hid), nn.Linear(hid, hid))
        self.binding = nn.Linear(hid, 1)
        self.recognition = nn.Linear(hid, 1)
        self.cyt_fc = nn.Linear(len(CYTOKINES), 32)
        self.response = nn.Sequential(nn.Linear(hid+32, 128), nn.Linear(128, 1))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.zeros_(m.bias)
    def encode_backbone(self, pep, tcr, all_idx):
        pep_h = self.pep_enc(pep); tcr_h = self.tcr_enc(tcr); all_h = self.all_emb(all_idx)
        return self.backbone(torch.cat([pep_h, tcr_h, all_h], dim=-1))
    def forward(self, pep, tcr, all_idx, cytok):
        z = self.encode_backbone(pep, tcr, all_idx)
        bind = torch.sigmoid(self.binding(z))
        recog = torch.sigmoid(self.recognition(z))
        c = self.cyt_fc(cytok)
        resp = self.response(torch.cat([z, c], dim=-1))
        return bind, recog, resp

# ============================================================
# Training ImmuneNet
# ============================================================
def train_supervised(model, tr_dl, val_dl, device, epochs=3, lr=2e-4):
    model.train(); opt = torch.optim.Adam(model.parameters(), lr=lr)
    for ep in range(1, epochs+1):
        losses=[]
        for pep,tcr,all_idx,cytok,y in tr_dl:
            pep,tcr,all_idx,cytok,y=[t.to(device) for t in (pep,tcr,all_idx,cytok,y)]
            bind,_,_=model(pep,tcr,all_idx,cytok)
            loss=dynamic_nonconvex_loss(bind,y,epoch=ep)
            opt.zero_grad(); loss.backward(); opt.step(); losses.append(loss.item())
        metrics=evaluate_supervised(model,val_dl,device)
        print(f"[Supervised] Epoch {ep} | Loss={np.mean(losses):.6f} | MSE={metrics['MSE']:.4f} | R2={metrics['R2']:.3f} | Pearson={metrics['Pearson']:.3f}")
    return model

@torch.no_grad()
def evaluate_supervised(model, dl, device):
    model.eval(); preds=[]; targs=[]
    for pep,tcr,all_idx,cytok,y in dl:
        pep,tcr,all_idx,cytok=[t.to(device) for t in (pep,tcr,all_idx,cytok)]
        _,_,resp=model(pep,tcr,all_idx,cytok)
        preds.append(resp.cpu().numpy()); targs.append(y.numpy())
    return regression_metrics(np.concatenate(preds), np.concatenate(targs))

# ============================================================
# Adaptive Cytokine Env
# ============================================================
class AdaptiveCytokineEnv(nn.Module):
    def __init__(self, num_cyt=len(CYTOKINES), rot_eps=0.02, base_bias=0.1, action_gain=0.25):
        super().__init__()
        self.num_cyt=num_cyt; self.rot_eps=rot_eps; self.base_bias=base_bias; self.action_gain=action_gain
        self.register_buffer("R_t", torch.eye(num_cyt)); self.register_buffer("b_t", torch.zeros(num_cyt))
    @torch.no_grad()
    def set_epoch(self,ep:int):
        M=torch.randn(self.num_cyt,self.num_cyt,device=self.R_t.device)
        Q,R=torch.linalg.qr(M); self.R_t=Q*torch.sign(torch.diag(R))
        self.b_t=self.base_bias*torch.randn(self.num_cyt,device=self.R_t.device)
    @torch.no_grad()
    def _drift(self):
        Q,R=torch.linalg.qr(self.R_t+self.rot_eps*torch.randn_like(self.R_t))
        self.R_t=Q*torch.sign(torch.diag(R))
        self.b_t=0.95*self.b_t+0.05*torch.randn_like(self.b_t)
    @torch.no_grad()
    def step(self,model,pep,tcr,all_idx,a,z,d,cytok_prev=None):
        B,N=pep.size(0),self.num_cyt
        cytok=torch.zeros((B,N),device=d) if cytok_prev is None else cytok_prev.clone()
        cytok[:,IDX_NONE]=1.0
        a_rot=a@self.R_t.to(d)
        cytok_next=torch.clamp(cytok+self.action_gain*a_rot+self.b_t.to(d),0.0,1.0)
        _,recog,resp=model(pep,tcr,all_idx,cytok_next)
        reward=(0.7*resp.squeeze(-1)+0.3*recog.squeeze(-1)).detach()
        self._drift()
        return reward,cytok_next

# ============================================================
# PPO Policy
# ============================================================
class PPOPolicy(nn.Module):
    def __init__(self,input_dim,num_actions,width=256,depth=2,mode="convex",std=0.1,switch_epoch=3):
        super().__init__()
        self.mode=mode; self.switch_epoch=switch_epoch; self._ep=0; self._std=std
        def mlp(out_dim):
            layers=[]; d_in=input_dim
            for _ in range(depth): layers+=[nn.Linear(d_in,width)]; d_in=width
            layers.append(nn.Linear(d_in,out_dim)); return nn.ModuleList(layers)
        self.actor_layers=mlp(num_actions); self.critic_layers=mlp(1)
        for m in self.modules():
            if isinstance(m,nn.Linear):
                nn.init.orthogonal_(m.weight,gain=math.sqrt(2)); nn.init.zeros_(m.bias)
    def set_epoch(self,ep:int): self._ep=ep
    def _apply_stack(self,layers,x):
        for layer in layers[:-1]: x=layer(x)
        return layers[-1](x)
    def actor_forward(self,x): return self._apply_stack(self.actor_layers,x)
    def critic_forward(self,x): return self._apply_stack(self.critic_layers,x).squeeze(-1)
    def act(self,x):
        mu=self.actor_forward(x)
        dist=torch.distributions.Normal(mu,self._std)
        u=dist.rsample(); a=torch.tanh(u)
        logp=dist.log_prob(u).sum(-1)-torch.log(1-a.pow(2)+1e-6).sum(-1)
        v=self.critic_forward(x); return a,logp,v
    def act_deterministic(self,x):
        mu=self.actor_forward(x); v=self.critic_forward(x)
        return torch.tanh(mu),v

# ============================================================
# PPO training & distillation
# ============================================================
def ppo_train_multistep(policy,make_batch,env,immune,d,episodes=3,ppo_epochs=8,multistep_steps=5,
                        gamma=0.99,clip=0.2,max_grad_norm=0.5,lr=5e-4):
    opt=torch.optim.Adam(policy.parameters(),lr=lr)
    for ep in range(episodes):
        policy.set_epoch(ep); env.set_epoch(ep)
        S,A,OL,Gt,Adv=[],[],[],[],[]
        while True:
            batch=make_batch()
            if batch is None: break
            z=batch["state"].to(d); pep,tcr,all_idx,cytok=[batch[k].to(d) for k in ("pep","tcr","all","cytok_init")]
            with torch.no_grad():
                disc=torch.zeros(z.size(0),device=d); cy=cytok
                for t in range(multistep_steps):
                    a,lp,v=policy.act(z)
                    r,cy=env.step(immune,pep,tcr,all_idx,a,z,d,cy)
                    disc+=(gamma**t)*r
                _,vL=policy.act_deterministic(z)
                Aadv=(disc-vL).detach()
            S.append(z);A.append(a);OL.append(lp);Gt.append(disc);Adv.append(Aadv)
        if not S: continue
        S,A,OL,Gt,Adv=map(torch.cat,(S,A,OL,Gt,Adv))
        Adv=(Adv-Adv.mean())/(Adv.std()+1e-8)
        for _ in range(ppo_epochs):
            mu=policy.actor_forward(S)
            dist=torch.distributions.Normal(mu,policy._std)
            A_clamp=torch.clamp(A,-1+1e-6,1-1e-6)
            u=0.5*(torch.log1p(A_clamp)-torch.log1p(-A_clamp))
            nlp=dist.log_prob(u).sum(-1)-torch.log(1-A_clamp.pow(2)+1e-6).sum(-1)
            ratio=torch.exp(nlp-OL)
            aloss=-torch.min(ratio*Adv,torch.clamp(ratio,1-clip,1+clip)*Adv).mean()
            vpred=policy.critic_forward(S)
            closs=dynamic_nonconvex_loss(vpred,Gt,epoch=ep)
            entropy=dist.entropy().sum(-1).mean()
            loss=aloss+0.5*closs-0.01*entropy
            opt.zero_grad();loss.backward()
            if max_grad_norm>0: torch.nn.utils.clip_grad_norm_(policy.parameters(),max_grad_norm)
            opt.step()
        print(f"[PPO/{policy.mode}] Ep {ep+1}/{episodes} Return={Gt.mean():.4f}")
    return policy

def ppo_distill_loss(mu_s,mu_t,v_s,v_t,temp=1.0,alpha=0.7,std=0.1,mode="convex",epoch=0):
    var=(std*temp)**2
    kl=((mu_s-mu_t)**2/(2*var)).mean()
    v_mse=F.mse_loss(v_s,v_t)
    if mode=="convex":
        return alpha*kl + (1-alpha)*v_mse
    elif mode=="nonconvex":
        ripple = 0.10 * torch.sin(6.0 * (v_s - v_t))**2
        basin  = 0.05 * ((v_s - v_t)**4 - (v_s - v_t)**2)
        return alpha*kl + (1-alpha)*(v_mse + ripple.mean() + basin.mean())
    elif mode=="twostage":
        lam = torch.tensor(min(1.0, epoch/10.0), dtype=torch.float32, device=v_s.device)
        ripple = 0.08 * torch.sin(6.0 * (v_s - v_t))**2
        basin  = 0.04 * ((v_s - v_t)**4 - (v_s - v_t)**2)
        nonconv = v_mse + ripple.mean() + basin.mean()
        return alpha*kl + (1-alpha)*((1-lam)*v_mse + lam*nonconv)
    else:
        return alpha*kl + (1-alpha)*v_mse

def ppo_distill(teacher,student,make_batch,d,epochs=6,lr=1e-4,temp=1.0,alpha=0.7):
    teacher.eval(); opt=torch.optim.Adam(student.parameters(),lr=lr)
    for ep in range(1,epochs+1):
        student.set_epoch(ep); total=0.0; nb=0
        while True:
            batch=make_batch()
            if batch is None: break
            z=batch["state"].to(d)
            with torch.no_grad():
                mu_t=teacher.actor_forward(z); _,v_t=teacher.act_deterministic(z)
            mu_s=student.actor_forward(z); _,v_s=student.act_deterministic(z)
            loss=ppo_distill_loss(mu_s,mu_t,v_s,v_t,temp=temp,alpha=alpha,std=teacher._std,mode=teacher.mode,epoch=ep)
            opt.zero_grad(); loss.backward(); opt.step()
            total+=loss.item(); nb+=1
        print(f"[Distill/{teacher.mode}] Ep {ep:02d}/{epochs} loss={total/max(nb,1):.6f}")
    return student

# ============================================================
# Dynamic Evaluation: Cytokine Control (Option 2) with full stats
# ============================================================
@torch.no_grad()
def evaluate_dynamic_cytokine(student, make_batch, env, immune_model, d,
                              episodes=3, multistep_steps=10, gamma=0.99):
    """
    Runs rollouts and computes dynamic cytokine control metrics:
      - Return (mean cumulative reward)
      - ΔCyt (mean per-step absolute change in cytokines, excluding NONE)
      - Stability (1 / variance of cytokine magnitude)
      - Corr(Reward, ||Cyt||) over time
      - R2 between Immune Response and Reward
      - MSE, RMSE, MAE, R2, Pearson between cytokine magnitude and reward
    """
    returns = []
    delta_c_list = []
    stability_list = []
    corr_list = []
    r2_list = []
    mse_list, rmse_list, mae_list, r2_cyt_list, pear_list = [], [], [], [], []

    for ep in range(episodes):
        student.set_epoch(ep); env.set_epoch(ep)
        R_t_series = []
        C_t_series = []
        Resp_t_series = []

        while True:
            batch = make_batch()
            if batch is None: break
            z = batch["state"].to(d)
            pep,tcr,all_idx,cytok = [batch[k].to(d) for k in ("pep","tcr","all","cytok_init")]
            ep_return = torch.zeros(z.size(0), device=d)
            cy = cytok

            for t in range(multistep_steps):
                a, v = student.act_deterministic(z)
                r, cy = env.step(immune_model, pep, tcr, all_idx, a, z, d, cy)
                _, _, resp = immune_model(pep, tcr, all_idx, cy)
                ep_return += (gamma**t) * r

                # track per-step series
                R_t_series.append(r.detach().cpu().numpy())
                C_t_series.append(cy.detach().cpu().numpy())
                Resp_t_series.append(resp.squeeze(-1).detach().cpu().numpy())

            returns.append(ep_return.mean().item())

        if not C_t_series:
            continue

        # Convert to arrays [T*B, ...]
        C = np.concatenate(C_t_series, axis=0)          # (TB, num_cyt)
        R = np.concatenate(R_t_series, axis=0).ravel()  # (TB,)
        RESP = np.concatenate(Resp_t_series, axis=0).ravel()

        # ΔCyt: mean L1 change per step (excluding NONE channel)
        C_eff = C[:, 1:]  # exclude NONE column
        if len(C_eff) > 1:
            dC = np.abs(np.diff(C_eff, axis=0))
            delta_c = float(np.mean(dC))
            # Stability: 1 / variance of cytokine magnitude over time
            C_mag = np.linalg.norm(C_eff, axis=1)  # Euclidean magnitude
            stability = float(1.0 / (np.var(C_mag) + 1e-8))
        else:
            delta_c, stability = 0.0, 0.0

        # Corr(Reward, ||C||)
        try:
            C_mag = np.linalg.norm(C_eff, axis=1)
            corr = float(pearsonr(R, C_mag)[0])
        except Exception:
            corr = 0.0

        # R2(Resp, Reward)
        try:
            denom = np.sum((R - np.mean(R)) ** 2) + 1e-12
            r2 = float(1.0 - np.sum((R - RESP) ** 2) / denom)
        except Exception:
            r2 = 0.0

        # Cytokine-vs-Reward stats (magnitude vs reward)
        try:
            cyt_norm = np.linalg.norm(C_eff, axis=1)
            mse = float(np.mean((cyt_norm - R) ** 2))
            rmse = float(np.sqrt(mse))
            mae = float(np.mean(np.abs(cyt_norm - R)))
            denom_c = np.sum((cyt_norm - np.mean(cyt_norm)) ** 2) + 1e-12
            r2_cyt = float(1.0 - np.sum((cyt_norm - R) ** 2) / denom_c)
            pear = float(pearsonr(cyt_norm, R)[0])
        except Exception:
            mse = rmse = mae = r2_cyt = pear = 0.0

        delta_c_list.append(delta_c)
        stability_list.append(stability)
        corr_list.append(corr)
        r2_list.append(r2)
        mse_list.append(mse); rmse_list.append(rmse); mae_list.append(mae)
        r2_cyt_list.append(r2_cyt); pear_list.append(pear)

    metrics = {
        "Return": float(np.mean(returns) if returns else 0.0),
        "DeltaCyt": float(np.mean(delta_c_list) if delta_c_list else 0.0),
        "Stability": float(np.mean(stability_list) if stability_list else 0.0),
        "CorrRewardCyt": float(np.mean(corr_list) if corr_list else 0.0),
        "R2RespReward": float(np.mean(r2_list) if r2_list else 0.0),
        "MSE": float(np.mean(mse_list) if mse_list else 0.0),
        "RMSE": float(np.mean(rmse_list) if rmse_list else 0.0),
        "MAE": float(np.mean(mae_list) if mae_list else 0.0),
        "R2": float(np.mean(r2_cyt_list) if r2_cyt_list else 0.0),
        "Pearson": float(np.mean(pear_list) if pear_list else 0.0),
    }
    return metrics

# ============================================================
# Main
# ============================================================
def main():
    import sys
    if any(a.startswith("-f") for a in sys.argv): sys.argv=[sys.argv[0]]
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="train_BA1.txt")
    parser.add_argument("--test", default="test_BA1.txt")
    parser.add_argument("--alleles", default="allelelist.txt")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=64)
    # PPO config
    parser.add_argument("--ppo_multistep_steps", type=int, default=2)
    parser.add_argument("--ppo_gamma", type=float, default=0.99)
    parser.add_argument("--ppo_clip", type=float, default=0.2)
    parser.add_argument("--ppo_std", type=float, default=0.1)
    parser.add_argument("--ppo_teacher_width", type=int, default=2)
    parser.add_argument("--ppo_teacher_depth", type=int, default=2)
    parser.add_argument("--ppo_teacher_episodes", type=int, default=2)
    parser.add_argument("--ppo_teacher_epochs", type=int, default=8)
    parser.add_argument("--ppo_teacher_lr", type=float, default=1e-4)
    parser.add_argument("--ppo_student_width", type=int, default=2)
    parser.add_argument("--ppo_student_depth", type=int, default=2)
    parser.add_argument("--ppo_distill_epochs", type=int, default=10)
    parser.add_argument("--ppo_distill_lr", type=float, default=1e-4)
    parser.add_argument("--ppo_distill_temp", type=float, default=1.0)
    parser.add_argument("--ppo_distill_alpha", type=float, default=0.7)
    parser.add_argument("--modes", type=str, default="convex,nonconvex,twostage")
    args,_ = parser.parse_known_args()

    d = get_device(); seed_everything(1)
    if not HAS_TORCH: 
        print("❌ PyTorch not available"); return

    # Load data
    allele_to_id = load_alleles(args.alleles)
    tr = parse_examples(args.train, allele_to_id)
    ts = parse_examples(args.test, allele_to_id)

    tr_ds = PeptideDataset(tr, allele_to_id); ts_ds = PeptideDataset(ts, allele_to_id)
    N=len(tr_ds); idx=list(range(N)); random.shuffle(idx)
    val=max(1,int(0.2*N))
    tr_dl=DataLoader(Subset(tr_ds,idx[val:]),batch_size=args.batch_size,shuffle=True,collate_fn=collate_pep)
    val_dl=DataLoader(Subset(tr_ds,idx[:val]),batch_size=args.batch_size,shuffle=False,collate_fn=collate_pep)
    ts_dl=DataLoader(ts_ds,batch_size=args.batch_size,shuffle=False,collate_fn=collate_pep)

    # Train ImmuneNet (linear-only) and freeze for PPO
    immune=ImmuneNet(len(AA_VOCAB),len(allele_to_id)).to(d)
    print("=== Supervised training ImmuneNet (linear-only) ===")
    train_supervised(immune,tr_dl,val_dl,d,epochs=args.epochs)
    test_metrics = evaluate_supervised(immune, ts_dl, d)
    print(
        f"[Test ImmuneNet] MSE={test_metrics['MSE']:.4f} | RMSE={test_metrics['RMSE']:.4f} | "
        f"MAE={test_metrics['MAE']:.4f} | R2={test_metrics['R2']:.3f} | "
        f"Pearson={test_metrics['Pearson']:.3f}"
    )
    immune.eval()

    # Batch builder for PPO
    def make_batch_gen():
        def make_batch():
            if not hasattr(make_batch,"it"): make_batch.it=iter(tr_dl)
            try: pep,tcr,all_idx,cytok,y=next(make_batch.it)
            except StopIteration:
                make_batch.it=iter(tr_dl); return None
            with torch.no_grad():
                z=immune.encode_backbone(pep.to(d),tcr.to(d),all_idx.to(d)).detach()
            return {"state":z,"pep":pep,"tcr":tcr,"all":all_idx,"cytok_init":cytok}
        return make_batch

    # Compare modes
    modes_list=[m.strip() for m in args.modes.split(",") if m.strip()]
    results_metrics=[]  # list of (mode, metrics_dict)
    for mode in modes_list:
        print(f"\n===== Mode: {mode.upper()} =====")
        env=AdaptiveCytokineEnv().to(d)

        # Teacher
        teacher=PPOPolicy(256,len(CYTOKINES),
                          width=args.ppo_teacher_width, depth=args.ppo_teacher_depth,
                          mode=mode, std=args.ppo_std).to(d)
        make_batch=make_batch_gen()
        teacher=ppo_train_multistep(teacher,make_batch,env,immune,d,
                                    episodes=args.ppo_teacher_episodes,
                                    ppo_epochs=args.ppo_teacher_epochs,
                                    multistep_steps=args.ppo_multistep_steps,
                                    gamma=args.ppo_gamma, clip=args.ppo_clip,
                                    lr=args.ppo_teacher_lr)

        # Student (distillation)
        student=PPOPolicy(256,len(CYTOKINES),
                          width=args.ppo_student_width, depth=args.ppo_student_depth,
                          mode=mode, std=args.ppo_std).to(d)
        make_batch=make_batch_gen()
        student=ppo_distill(teacher,student,make_batch,d,
                            epochs=args.ppo_distill_epochs, lr=args.ppo_distill_lr,
                            temp=args.ppo_distill_temp, alpha=args.ppo_distill_alpha)

        # Dynamic evaluation (Option 2)
        make_batch=make_batch_gen()
        dyn = evaluate_dynamic_cytokine(student, make_batch, env, immune, d,
                                        episodes=3, multistep_steps=args.ppo_multistep_steps,
                                        gamma=args.ppo_gamma)

        print(
          f"[Dynamic/{mode}] Return={dyn['Return']:.4f} | "
          f"ΔCyt={dyn['DeltaCyt']:.4f} | Stability={dyn['Stability']:.3f} | "
          f"Corr(Reward,||C||)={dyn['CorrRewardCyt']:.3f} | R2(Resp,Reward)={dyn['R2RespReward']:.3f} | "
          f"MSE={dyn['MSE']:.4f} | RMSE={dyn['RMSE']:.4f} | MAE={dyn['MAE']:.4f} | "
          f"R2={dyn['R2']:.3f} | Pearson={dyn['Pearson']:.3f}"
        )

        # Save student for reuse
        torch.save(student.state_dict(),f"ppo_student_{mode}.pt")

        results_metrics.append((mode, dyn))

    # Write comparison CSV
    out_csv = "ppo_dynamic_cytokine_comparison.csv"
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow([
            "Mode","Return","DeltaCyt","Stability",
            "CorrRewardCyt","R2RespReward",
            "MSE","RMSE","MAE","R2","Pearson"
        ])
        for mode, dyn in results_metrics:
            w.writerow([
                mode, dyn["Return"], dyn["DeltaCyt"], dyn["Stability"],
                dyn["CorrRewardCyt"], dyn["R2RespReward"],
                dyn["MSE"], dyn["RMSE"], dyn["MAE"], dyn["R2"], dyn["Pearson"]
            ])
    print(f"✅ Saved {out_csv}")

if __name__=="__main__":
    try: 
        main()
    except SystemExit: 
        pass


🟡 Using CPU.
=== Supervised training ImmuneNet (linear-only) ===
[Supervised] Epoch 1 | Loss=0.416644 | MSE=0.2204 | R2=-2.294 | Pearson=0.143
[Supervised] Epoch 2 | Loss=0.398956 | MSE=0.3443 | R2=-4.145 | Pearson=0.180
[Test ImmuneNet] MSE=0.3533 | RMSE=0.5944 | MAE=0.4947 | R2=-4.001 | Pearson=0.148

===== Mode: CONVEX =====
[PPO/convex] Ep 1/2 Return=0.2883
[PPO/convex] Ep 2/2 Return=0.2700
[Distill/convex] Ep 01/10 loss=33.465093
[Distill/convex] Ep 02/10 loss=24.684720
[Distill/convex] Ep 03/10 loss=20.100252
[Distill/convex] Ep 04/10 loss=16.113921
[Distill/convex] Ep 05/10 loss=12.540359
[Distill/convex] Ep 06/10 loss=9.232134
[Distill/convex] Ep 07/10 loss=6.122596
[Distill/convex] Ep 08/10 loss=3.245446
[Distill/convex] Ep 09/10 loss=1.290944
[Distill/convex] Ep 10/10 loss=0.435034
[Dynamic/convex] Return=0.2796 | ΔCyt=0.1102 | Stability=13.405 | Corr(Reward,||C||)=0.017 | R2(Resp,Reward)=0.354 | MSE=0.1925 | RMSE=0.4384 | MAE=0.3434 | R2=-1.579 | Pearson=0.017

===== Mode: N