In [None]:
"""
Immune RL + PPO Distillation (Waddington Landscape)
---------------------------------------------------
- Gated + Adaptive-Threshold BioActivation (convex / nonconvex / twostage)
- Dynamic Nonconvex Loss (Waddington landscape for critic stability)
- Compares PPO modes on critic value metrics
"""

from __future__ import annotations
import os, csv, argparse, random, warnings, math
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Optional
from copy import deepcopy
from scipy.stats import pearsonr

# Suppress every RuntimeWarning process-wide.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ============================================================
# Torch setup
# ============================================================
# HAS_TORCH records whether the PyTorch imports below succeeded.
# NOTE(review): several later definitions in this file reference torch names
# at import time (e.g. classes deriving from `nn.Module` / `Dataset` and
# `@torch.no_grad()` decorators), so the HAS_TORCH=False path presumably
# still fails with NameError before `main` can report it — confirm.
HAS_TORCH = True
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader, Subset
except Exception:
    HAS_TORCH = False

def seed_everything(seed=42):
    """Seed the Python, NumPy and (if importable) PyTorch random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    if not HAS_TORCH:
        return
    torch.manual_seed(seed)
    # Seeds all CUDA devices as well.
    torch.cuda.manual_seed_all(seed)

def get_device():
    """Pick the compute device and print which one was chosen.

    Returns:
        ``None`` when PyTorch is unavailable, otherwise ``torch.device("cuda")``
        if CUDA is available, else ``torch.device("cpu")``.
    """
    if not HAS_TORCH:
        return None
    if not torch.cuda.is_available():
        print("🟡 Using CPU.")
        return torch.device("cpu")
    device = torch.device("cuda")
    print(f"🟢 Using CUDA: {torch.cuda.get_device_name(0)}")
    return device

# ============================================================
# Constants
# ============================================================
# Amino-acid alphabet used for tokenization (20 single-letter codes).
AA_VOCAB = list("ACDEFGHIKLMNPQRSTVWY")
# Token ids start at 1; id 0 is what SeqTokenizer uses for padding and for
# characters that are not in AA_VOCAB.
AA_TO_ID = {a: i + 1 for i, a in enumerate(AA_VOCAB)}
# Cytokine channel names; the list position is the channel index.
CYTOKINES = ["NONE", "IL2", "IFNG", "IL10", "TNFA"]
CYTOKINE_TO_ID = {c: i for i, c in enumerate(CYTOKINES)}

# ============================================================
# Data utilities
# ============================================================
@dataclass
class Example:
    """One training/evaluation record parsed from a data table."""
    # Peptide sequence string (first column of the table, see parse_examples).
    peptide: str
    # Allele name (third column of the table).
    allele: str
    # Numeric target parsed from the second column.
    score: float
    # Optional TCR sequence; PeptideDataset substitutes a fixed default when unset.
    tcr: Optional[str] = None

def smart_read_table(path: str) -> List[List[str]]:
    """Read a delimited text table, auto-detecting the delimiter.

    The first 2048 characters are sniffed for a tab, comma or semicolon
    delimiter. When sniffing fails (for example on a single-column or empty
    file) the standard comma-separated ``csv.excel`` dialect is used.

    Args:
        path: Path of the text file to read.

    Returns:
        One list of whitespace-stripped cell strings per non-empty row.
    """
    with open(path, "r", newline="") as f:
        sample = f.read(2048)
        f.seek(0)  # rewind so the reader sees the sniffed sample again
        try:
            dialect = csv.Sniffer().sniff(sample, delimiters="\t,;")
        except csv.Error:
            # Use a complete stdlib dialect rather than a partial stand-in
            # object; its settings equal the csv reader defaults.
            dialect = csv.excel
        reader = csv.reader(f, dialect)
        # Empty rows (blank lines) are dropped.
        return [[cell.strip() for cell in row] for row in reader if row]

def load_alleles(path: str) -> Dict[str, int]:
    """Build an allele-name → index map from an allele list file.

    Every cell of the table is split on commas and whitespace; the distinct
    tokens are sorted and numbered from 0.

    Args:
        path: Path of the allele list file (read with ``smart_read_table``).

    Returns:
        Mapping from allele token to its position in the sorted token list.
    """
    # A set keeps de-duplication linear; the final sort makes the result
    # independent of the order in which tokens were encountered.
    tokens = {
        token
        for row in smart_read_table(path)
        for cell in row
        for token in cell.replace(",", " ").split()
    }
    return {allele: idx for idx, allele in enumerate(sorted(tokens))}

def parse_examples(path: str, allele_to_id: Dict[str, int]) -> List[Example]:
    """Parse ``peptide, score, allele`` rows of a table into Example objects.

    Rows with fewer than three cells, or whose second cell is not a float
    (such as a header line), are skipped.

    Args:
        path: Path of the table file (read with ``smart_read_table``).
        allele_to_id: Allele → index map. It is extended IN PLACE: an allele
            not yet present is assigned the next free index.

    Returns:
        The parsed examples in file order.
    """
    examples = []
    for row in smart_read_table(path):
        if len(row) < 3:
            continue
        peptide, raw_score, allele = row[0], row[1], row[2]
        try:
            score = float(raw_score)
        except Exception:
            continue
        # len() is evaluated before insertion, so a new allele gets index == old size.
        allele_to_id.setdefault(allele, len(allele_to_id))
        examples.append(Example(peptide=peptide, allele=allele, score=score))
    return examples

class SeqTokenizer:
    """Turn an amino-acid string into a fixed-length tensor of token ids."""

    def __init__(self, max_len=32):
        self.max_len = max_len

    def encode(self, s: str):
        """Encode ``s`` as a ``torch.long`` tensor of length ``max_len``.

        The string is truncated to ``max_len`` characters; characters missing
        from ``AA_TO_ID`` map to 0 and short sequences are right-padded with 0.
        """
        clipped = s[:self.max_len]
        token_ids = [AA_TO_ID.get(residue, 0) for residue in clipped]
        # Multiplying a list by a non-positive count yields [], so this is a
        # no-op when the sequence already fills max_len.
        token_ids.extend([0] * (self.max_len - len(token_ids)))
        return torch.tensor(token_ids, dtype=torch.long)

class PeptideDataset(Dataset):
    """Dataset yielding ``(peptide_ids, tcr_ids, allele_index, score)`` tuples."""

    def __init__(self, exs, allele_to_id, pep_len=32, tcr_len=24):
        self.exs = exs
        self.allele_to_id = allele_to_id
        self.tok_p = SeqTokenizer(pep_len)
        self.tok_t = SeqTokenizer(tcr_len)

    def __len__(self):
        return len(self.exs)

    def __getitem__(self, i):
        example = self.exs[i]
        # A missing or empty TCR falls back to one fixed placeholder sequence.
        tcr_seq = example.tcr if example.tcr else "CASSIRSSYEQYF"
        pep_ids = self.tok_p.encode(example.peptide)
        tcr_ids = self.tok_t.encode(tcr_seq)
        # Alleles missing from the map share index 0.
        allele_index = self.allele_to_id.get(example.allele, 0)
        return pep_ids, tcr_ids, allele_index, float(example.score)

def collate_pep(batch):
    """Collate PeptideDataset items into batched tensors.

    Args:
        batch: Iterable of ``(pep, tcr, allele_index, score)`` tuples.

    Returns:
        ``(pep, tcr, all_idx, cytok, y)`` where ``pep``/``tcr`` are the stacked
        token tensors, ``all_idx`` is a long tensor, ``y`` is a float32 column
        of scores, and ``cytok`` is a float32 matrix with one row per item in
        which only the "NONE" channel is set to 1.
    """
    pep_items, tcr_items, allele_items, score_items = zip(*batch)
    pep = torch.stack(pep_items)
    tcr = torch.stack(tcr_items)
    all_idx = torch.tensor(allele_items, dtype=torch.long)
    y = torch.tensor(score_items, dtype=torch.float32).unsqueeze(-1)
    n_rows = pep.size(0)
    cytok = torch.zeros((n_rows, len(CYTOKINES)), dtype=torch.float32)
    none_channel = CYTOKINE_TO_ID["NONE"]
    cytok[:, none_channel] = 1.0
    return pep, tcr, all_idx, cytok, y

# ============================================================
# Dynamic Nonconvex Loss (Waddington Landscape)
# ============================================================
def dynamic_nonconvex_loss(pred, target, epoch=0, eps=1e-6, freq=6.0, amp=0.3, basin_depth=0.2):
    """Mean of a square-root error term plus epoch-decayed nonconvex terms.

    With ``e = pred - target`` the per-element loss is
    ``sqrt(|e| + eps) + exp(-0.01 * epoch) * (amp * sin(freq * e)**2
    + basin_depth * (e**4 - e**2))``.

    Args:
        pred: Prediction tensor.
        target: Target tensor (combined with ``pred`` by subtraction).
        epoch: Current epoch; larger values shrink the ripple/basin terms.
        eps: Added under the square root.
        freq: Frequency of the sinusoidal ripple.
        amp: Amplitude of the sinusoidal ripple.
        basin_depth: Scale of the quartic-minus-quadratic basin term.

    Returns:
        Scalar tensor: the mean over all elements.
    """
    residual = pred - target
    sqrt_term = (residual.abs() + eps).sqrt()
    ripple_term = amp * torch.sin(freq * residual).pow(2)
    basin_term = basin_depth * (residual.pow(4) - residual.pow(2))
    anneal = math.exp(-0.01 * epoch)
    per_element = sqrt_term + anneal * (ripple_term + basin_term)
    return per_element.mean()

def regression_metrics(preds, targets):
    """Compute standard regression metrics between predictions and targets.

    Both inputs are converted to flat NumPy arrays first.

    Returns:
        Dict with float values for "MSE", "RMSE", "MAE", "R2" and "Pearson".
        Pearson is reported as 0.0 when there is at most one sample.
    """
    p = np.asarray(preds).flatten()
    t = np.asarray(targets).flatten()
    err = p - t
    sq_err = err ** 2
    mse = np.mean(sq_err)
    # The small constant keeps R2 finite when all targets are equal.
    ss_tot = np.sum((t - np.mean(t)) ** 2) + 1e-12
    r2 = 1.0 - np.sum(sq_err) / ss_tot
    if len(p) > 1:
        pearson = pearsonr(p, t)[0]
    else:
        pearson = 0.0
    return {
        "MSE": float(mse),
        "RMSE": float(np.sqrt(mse)),
        "MAE": float(np.mean(np.abs(err))),
        "R2": float(r2),
        "Pearson": float(pearson),
    }

# ============================================================
# Gated + Adaptive Threshold BioActivation
# ============================================================
class BioActivation(nn.Module):
    """Gated activation with a per-feature threshold shift.

    The input is shifted by ``threshold``, passed through a base nonlinearity
    chosen by ``mode`` and multiplied by a sigmoid gate computed from the
    unshifted input:

    * ``"convex"``    – softplus
    * ``"nonconvex"`` – ``x * sigmoid(x)``
    * ``"twostage"``  – softplus while the epoch set via ``set_epoch`` is below
      ``switch_epoch``, ``x * sigmoid(x)`` afterwards

    Args:
        dim: Feature dimension (size of the gate layer and of the threshold).
        mode: One of the modes above; validated on the forward pass.
        switch_epoch: Epoch at which "twostage" switches nonlinearity.
        adaptive: When True the threshold is a learnable parameter, otherwise
            a fixed zero buffer.
    """

    def __init__(self, dim=128, mode="nonconvex", switch_epoch=5, adaptive=True):
        super().__init__()
        self.mode = mode
        self.switch_epoch = switch_epoch
        self._ep = 0
        self.gate = nn.Linear(dim, dim)
        initial_threshold = torch.zeros(dim)
        if adaptive:
            self.threshold = nn.Parameter(initial_threshold)
        else:
            self.register_buffer("threshold", initial_threshold)

    def set_epoch(self, ep: int):
        """Record the current epoch (only consulted in "twostage" mode)."""
        self._ep = ep

    def forward(self, x):
        gate_value = torch.sigmoid(self.gate(x))
        shifted = x - self.threshold
        # Resolve "twostage" to whichever concrete nonlinearity is active now.
        if self.mode == "twostage":
            active = "convex" if self._ep < self.switch_epoch else "nonconvex"
        else:
            active = self.mode
        if active == "convex":
            activated = F.softplus(shifted)
        elif active == "nonconvex":
            activated = shifted * torch.sigmoid(shifted)
        else:
            raise ValueError(f"Unknown mode {self.mode}")
        return activated * gate_value

# ============================================================
# ImmuneNet backbone
# ============================================================
class MiniGAT(nn.Module):
    """Small sequence encoder: token + position embeddings into a Transformer encoder.

    Despite the name, the layers used are ``nn.TransformerEncoderLayer``s.

    Args:
        vocab_size: Number of real tokens; one extra embedding row is allocated.
        dim: Embedding / model width.
        max_len: Maximum sequence length; longer inputs are truncated.
        heads: Attention heads per encoder layer.
        layers: Number of encoder layers.
    """

    def __init__(self, vocab_size, dim, max_len=32, heads=4, layers=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size + 1, dim)
        self.pos = nn.Embedding(max_len, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=layers)
        self.max_len = max_len

    def forward(self, x):
        """Encode token ids ``x`` (batch, length) into a (batch, 2 * dim) vector.

        The result concatenates the mean over positions with the encoding of
        the first position.
        """
        batch_size = x.size(0)
        seq_len = min(x.size(1), self.max_len)
        tokens = x[:, :seq_len]
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, seq_len)
        hidden = self.encoder(self.emb(tokens) + self.pos(positions))
        mean_pooled = hidden.mean(dim=1)
        first_token = hidden[:, 0, :]
        return torch.cat([mean_pooled, first_token], dim=-1)

class ImmuneNet(nn.Module):
    """Peptide / TCR / allele model with binding, recognition and response heads.

    Peptide and TCR token sequences are encoded by separate ``MiniGAT``
    encoders, concatenated with an allele embedding and passed through a
    two-layer ``BioActivation`` backbone. Three heads read the backbone
    features: sigmoid binding, sigmoid recognition, and an unbounded response
    that additionally sees an embedding of the cytokine vector.

    Args:
        vocab_size: Token vocabulary size forwarded to the encoders.
        allele_count: Number of alleles; one extra embedding row is allocated.
        dim: Encoder / allele embedding width.
        pep_len: Maximum peptide length.
        tcr_len: Maximum TCR length.
        act_mode: ``BioActivation`` mode used in the backbone.
        switch_epoch: ``BioActivation`` switch epoch.
    """

    def __init__(self, vocab_size, allele_count, dim=128, pep_len=32, tcr_len=24,
                 act_mode="nonconvex", switch_epoch=5):
        super().__init__()
        self.pep_enc = MiniGAT(vocab_size, dim, pep_len)
        self.tcr_enc = MiniGAT(vocab_size, dim, tcr_len)
        self.all_emb = nn.Embedding(allele_count + 1, dim)

        hidden = 256
        # Each encoder emits 2*dim features (mean + first token); the allele adds dim.
        fused_dim = 2 * dim + 2 * dim + dim
        self.backbone = nn.Sequential(
            nn.Linear(fused_dim, hidden),
            BioActivation(dim=hidden, mode=act_mode, switch_epoch=switch_epoch),
            nn.Linear(hidden, hidden),
            BioActivation(dim=hidden, mode=act_mode, switch_epoch=switch_epoch),
        )
        self.binding = nn.Linear(hidden, 1)
        self.recognition = nn.Linear(hidden, 1)
        self.cyt_fc = nn.Linear(len(CYTOKINES), 32)
        self.response = nn.Sequential(
            nn.Linear(hidden + 32, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def encode_backbone(self, pep, tcr, all_idx):
        """Return the backbone features for a batch (cytokines not involved)."""
        pep_feat = self.pep_enc(pep)
        tcr_feat = self.tcr_enc(tcr)
        allele_feat = self.all_emb(all_idx)
        fused = torch.cat([pep_feat, tcr_feat, allele_feat], dim=-1)
        return self.backbone(fused)

    def forward(self, pep, tcr, all_idx, cytok):
        """Return ``(binding, recognition, response)`` predictions for a batch."""
        features = self.encode_backbone(pep, tcr, all_idx)
        bind = torch.sigmoid(self.binding(features))
        recog = torch.sigmoid(self.recognition(features))
        cyt_feat = F.relu(self.cyt_fc(cytok))
        resp = self.response(torch.cat([features, cyt_feat], dim=-1))
        return bind, recog, resp

@torch.no_grad()
def eval_metrics(model, loader, d):
    """Evaluate the model's binding head on ``loader``.

    Puts ``model`` into eval mode (it is not switched back), runs every batch
    on device ``d`` and compares the first model output with the targets.

    Returns:
        The dict produced by ``regression_metrics``.
    """
    model.eval()
    collected_preds = []
    collected_targets = []
    for pep, tcr, all_idx, cytok, y in loader:
        inputs = (pep.to(d), tcr.to(d), all_idx.to(d), cytok.to(d))
        binding, _recog, _resp = model(*inputs)
        collected_preds.extend(binding.cpu().numpy())
        collected_targets.extend(y.numpy())
    return regression_metrics(collected_preds, collected_targets)

def train_supervised(model, tr, val, d, epochs=3, lr=1e-3):
    """Fit the model's binding head with Adam and ``dynamic_nonconvex_loss``.

    Args:
        model: Network returning ``(bind, recog, resp)``; only ``bind`` is trained on.
        tr: Iterable of training batches ``(pep, tcr, all_idx, cytok, y)``.
        val: Validation loader passed to ``eval_metrics`` after every epoch.
        d: Device the batches are moved to.
        epochs: Number of passes over ``tr``.
        lr: Adam learning rate.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        # Tell every BioActivation which epoch it is (drives "twostage" mode).
        for module in model.modules():
            if isinstance(module, BioActivation):
                module.set_epoch(epoch)
        # eval_metrics leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        for pep, tcr, all_idx, cytok, y in tr:
            pep = pep.to(d)
            tcr = tcr.to(d)
            all_idx = all_idx.to(d)
            cytok = cytok.to(d)
            y = y.to(d)
            bind, _, _ = model(pep, tcr, all_idx, cytok)
            loss = dynamic_nonconvex_loss(bind, y, epoch=epoch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        val_metrics = eval_metrics(model, val, d)
        print(f"[ImmuneNet] Epoch {epoch+1} | Val MSE={val_metrics['MSE']:.5f}")

        
# ============================================================
# Adaptive Cytokine Environment (non-stationary drift)
# ============================================================
class AdaptiveCytokineEnv(nn.Module):
    """Non-stationary cytokine environment used to score policy actions.

    The environment keeps an orthogonal matrix ``R_t`` and a bias ``b_t`` as
    buffers. Actions are rotated by ``R_t``, scaled by ``action_gain``, shifted
    by ``b_t`` and added to the current cytokine vector. Both buffers are
    re-drawn by ``set_epoch`` and perturbed after every ``step``.

    Args:
        num_cyt: Number of cytokine channels (size of ``R_t`` / ``b_t``).
        rot_eps: Scale of the Gaussian noise added to ``R_t`` on each drift.
        base_bias: Scale of the Gaussian bias drawn in ``set_epoch``.
        action_gain: Multiplier applied to the rotated action.
    """

    def __init__(self, num_cyt=len(CYTOKINES), rot_eps=0.02, base_bias=0.1, action_gain=0.25):
        super().__init__()
        self.num_cyt = num_cyt
        self.rot_eps = rot_eps
        self.base_bias = base_bias
        self.action_gain = action_gain
        self.register_buffer("R_t", torch.eye(num_cyt))
        self.register_buffer("b_t", torch.zeros(num_cyt))

    @staticmethod
    def _orthonormalize(M):
        """Return the Q factor of ``M`` with the QR sign ambiguity removed."""
        Q, R = torch.linalg.qr(M)
        signs = torch.sign(torch.diag(R))
        # sign(0) == 0 would wipe out a whole column of Q; treat it as +1.
        signs = torch.where(signs == 0, torch.ones_like(signs), signs)
        return Q * signs

    @torch.no_grad()
    def set_epoch(self, ep: int):
        """Draw a fresh random rotation and bias (``ep`` itself is not used).

        The new values are written into the existing buffers in place, so the
        buffers stay on whatever device the module was moved to instead of
        being replaced by freshly created CPU tensors.
        """
        device = self.R_t.device
        M = torch.randn(self.num_cyt, self.num_cyt, device=device)
        self.R_t.copy_(self._orthonormalize(M))
        self.b_t.copy_(self.base_bias * torch.randn(self.num_cyt, device=device))

    @torch.no_grad()
    def _drift(self):
        """Perturb the rotation slightly and let the bias random-walk."""
        noisy = self.R_t + self.rot_eps * torch.randn_like(self.R_t)
        self.R_t.copy_(self._orthonormalize(noisy))
        self.b_t.copy_(0.95 * self.b_t + 0.05 * torch.randn_like(self.b_t))

    @torch.no_grad()
    def step(self, model, pep, tcr, all_idx, a, z, d, cytok_prev=None):
        """Apply action ``a`` to the cytokine state and score the result.

        Args:
            model: Callable returning ``(bind, recog, resp)`` for
                ``(pep, tcr, all_idx, cytok)``.
            pep, tcr, all_idx: Model inputs for the batch.
            a: Action tensor, multiplied as ``a @ R_t``.
            z: Unused; kept for interface compatibility.
            d: Device used for the computation.
            cytok_prev: Previous cytokine state; zeros are used when None.

        Returns:
            ``(reward, cytok_next)`` where ``reward = 0.7 * resp + 0.3 * recog``
            (last dimension squeezed) and ``cytok_next`` is clamped to [0, 1].
        """
        B, N = pep.size(0), self.num_cyt
        cytok = torch.zeros((B, N), device=d) if cytok_prev is None else cytok_prev.clone()
        # The "NONE" channel is pinned to 1 before the action is applied.
        cytok[:, CYTOKINE_TO_ID["NONE"]] = 1.0
        a_rot = a @ self.R_t.to(d)
        cytok_next = torch.clamp(cytok + self.action_gain * a_rot + self.b_t.to(d), 0.0, 1.0)
        _, recog, resp = model(pep, tcr, all_idx, cytok_next)
        reward = (0.7 * resp.squeeze(-1) + 0.3 * recog.squeeze(-1)).detach()
        # The environment changes after every step (non-stationary dynamics).
        self._drift()
        return reward, cytok_next


# ============================================================
# PPO Policy (teacher/student)
# ============================================================
class PPOPolicy(nn.Module):
    """Actor-critic MLP with a tanh-squashed Gaussian policy of fixed std.

    The actor and the critic are separate stacks of ``depth`` hidden Linear
    layers of size ``width`` followed by an output layer. The hidden
    activation depends on ``mode``: ReLU for "convex", SiLU for any other
    value, and for "twostage" ReLU until the epoch set via ``set_epoch``
    reaches ``switch_epoch``, SiLU afterwards.

    Args:
        input_dim: Size of the state vector.
        num_actions: Size of the action vector.
        width: Hidden layer width.
        depth: Number of hidden layers per stack.
        mode: Activation mode (see above).
        std: Fixed standard deviation of the pre-tanh Gaussian.
        switch_epoch: Epoch at which "twostage" switches to SiLU.
    """

    def __init__(self, input_dim, num_actions, width=256, depth=2,
                 mode="nonconvex", std=0.1, switch_epoch=3):
        super().__init__()
        self.mode = mode
        self.switch_epoch = switch_epoch
        self._ep = 0
        self.relu = nn.ReLU()
        self.swish = nn.SiLU()
        self._std = std

        def build_stack(out_dim):
            # Layer sizes: input_dim -> width (x depth) -> out_dim.
            sizes = [input_dim] + [width] * depth + [out_dim]
            return nn.ModuleList(
                [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
            )

        self.actor_layers = build_stack(num_actions)
        self.critic_layers = build_stack(1)

        # Orthogonal weights (gain sqrt(2)) and zero biases for every Linear layer.
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.orthogonal_(module.weight, gain=math.sqrt(2))
            nn.init.zeros_(module.bias)

    def set_epoch(self, ep: int):
        """Record the current epoch (only consulted in "twostage" mode)."""
        self._ep = ep

    def _hidden_activation(self):
        """Return the activation module that applies for the current mode/epoch."""
        if self.mode == "twostage":
            return self.swish if self._ep >= self.switch_epoch else self.relu
        if self.mode == "convex":
            return self.relu
        return self.swish

    def _apply_stack(self, layers, x):
        """Run ``x`` through ``layers``, activating all but the last layer."""
        activation = self._hidden_activation()
        last = len(layers) - 1
        for idx in range(last):
            x = activation(layers[idx](x))
        return layers[-1](x)

    def actor_forward(self, x):
        """Return the pre-tanh action mean."""
        return self._apply_stack(self.actor_layers, x)

    def critic_forward(self, x):
        """Return the value estimate with a trailing singleton dimension."""
        return self._apply_stack(self.critic_layers, x)

    def act(self, x):
        """Sample an action.

        Returns:
            ``(action, log_prob, value)``: the tanh-squashed sample, its
            log-probability including the tanh change-of-variables correction
            (summed over action dimensions), and the critic value.
        """
        mean = self.actor_forward(x)
        gaussian = torch.distributions.Normal(mean, self._std)
        raw_action = gaussian.rsample()
        action = torch.tanh(raw_action)
        gaussian_logp = gaussian.log_prob(raw_action).sum(-1)
        # 1e-6 keeps the log finite when tanh saturates at +/-1.
        squash_term = torch.log(1 - action.pow(2) + 1e-6).sum(-1)
        value = self.critic_forward(x).squeeze(-1)
        return action, gaussian_logp - squash_term, value

    def act_deterministic(self, x):
        """Return ``(tanh(mean), value)`` without sampling."""
        mean = self.actor_forward(x)
        value = self.critic_forward(x).squeeze(-1)
        return torch.tanh(mean), value


# ============================================================
# PPO Teacher Training
# ============================================================
def ppo_train_multistep(policy, make_batch, env, immune_model, d,
                        episodes=5, ppo_epochs=10, multistep_steps=5,
                        gamma=0.99, clip=0.2, max_grad_norm=0.5, lr=5e-4):
    """Train ``policy`` with a clipped-surrogate PPO loop over multi-step rollouts.

    Per episode: every batch from ``make_batch`` (until it returns None) is
    rolled out for ``multistep_steps`` environment steps to get a discounted
    return, then ``ppo_epochs`` full-batch Adam updates are run on the
    collected data. The critic is fit with ``dynamic_nonconvex_loss``.

    Args:
        policy: Actor-critic exposing ``set_epoch``, ``act``, ``act_deterministic``,
            ``actor_forward``, ``critic_forward``, ``_std`` and ``mode``.
        make_batch: Zero-argument callable returning a dict with keys
            "state", "pep", "tcr", "all", "cytok_init", or None at end of data.
        env: Environment exposing ``set_epoch`` and ``step``.
        immune_model: Model forwarded to ``env.step``.
        d: Device tensors are moved to.
        episodes: Number of collect-then-update rounds.
        ppo_epochs: Gradient steps per round.
        multistep_steps: Environment steps per rollout.
        gamma: Discount factor.
        clip: PPO ratio clipping range.
        max_grad_norm: Gradient-norm clip; values <= 0 disable clipping.
        lr: Adam learning rate.

    Returns:
        The trained ``policy`` (same object).
    """
    opt = torch.optim.Adam(policy.parameters(), lr=lr)
    ent_coef, vf_coef = 0.01, 0.5
    for ep in range(episodes):
        policy.set_epoch(ep); env.set_epoch(ep)
        # Rollout buffers: states, actions, old log-probs, returns, advantages.
        S,A,OL,Gt,Adv=[],[],[],[],[]
        while True:
            batch = make_batch()
            if batch is None: break
            z = batch["state"].to(d)
            pep,tcr,all_idx,cytok = [batch[k].to(d) for k in ("pep","tcr","all","cytok_init")]
            with torch.no_grad():
                disc = torch.zeros(z.size(0), device=d)
                cy = cytok
                # The state z is the same at every step; only the cytokine
                # vector cy evolves through the environment.
                for t in range(multistep_steps):
                    a, lp, v = policy.act(z)
                    r, cy = env.step(immune_model, pep, tcr, all_idx, a, z, d, cy)
                    disc += (gamma**t) * r
                # Advantage = full discounted return minus the critic's value of z.
                _, vL = policy.act_deterministic(z)
                Aadv = (disc - vL).detach()
            # NOTE(review): `a` and `lp` are those of the LAST rollout step only,
            # while `disc` sums the rewards of all steps — confirm this credit
            # assignment is intended. With multistep_steps == 0 they are undefined.
            S.append(z); A.append(a); OL.append(lp); Gt.append(disc); Adv.append(Aadv)
        if not S: continue
        S,A,OL,Gt,Adv = map(torch.cat, (S,A,OL,Gt,Adv))
        # Normalize advantages over the whole episode buffer.
        Adv = (Adv - Adv.mean()) / (Adv.std() + 1e-8)
        for _ in range(ppo_epochs):
            mu = policy.actor_forward(S)
            dist = torch.distributions.Normal(mu, policy._std)
            # Recover the pre-tanh sample via atanh; clamping avoids +/-inf at |a| == 1.
            A_clamp = torch.clamp(A, -1 + 1e-6, 1 - 1e-6)
            u = 0.5 * (torch.log1p(A_clamp) - torch.log1p(-A_clamp))
            nlp = dist.log_prob(u).sum(-1) - torch.log(1 - A_clamp.pow(2) + 1e-6).sum(-1)
            ratio = torch.exp(nlp - OL)
            aloss = -torch.min(ratio*Adv, torch.clamp(ratio, 1-clip, 1+clip)*Adv).mean()
            vpred = policy.critic_forward(S).squeeze(-1)
            closs = dynamic_nonconvex_loss(vpred, Gt, epoch=ep)
            # NOTE(review): the Normal's scale is the fixed `policy._std`, so this
            # entropy presumably does not depend on any trainable parameter — verify.
            entropy = dist.entropy().sum(-1).mean()
            loss = aloss + vf_coef * closs - ent_coef * entropy
            opt.zero_grad(); loss.backward()
            if max_grad_norm>0: torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
            opt.step()
        print(f"[PPO Teacher/{policy.mode}] Ep {ep+1}/{episodes} return={Gt.mean():.4f}")
    return policy


# ============================================================
# PPO Distillation (Teacher → Student)
# ============================================================
def ppo_distill_loss(mu_s, mu_t, v_s, v_t, temp=1.0, alpha=0.7, std=0.1):
    """Distillation loss between student and teacher policy means and values.

    The policy term is the mean of ``(mu_s - mu_t)**2 / (2 * (std * temp)**2)``
    and the value term is the MSE between ``v_s`` and ``v_t``.

    Args:
        mu_s, mu_t: Student / teacher action means.
        v_s, v_t: Student / teacher value estimates.
        temp: Temperature multiplying ``std``.
        alpha: Weight of the policy term; the value term gets ``1 - alpha``.
        std: Standard deviation of the policy.

    Returns:
        Scalar tensor ``alpha * policy_term + (1 - alpha) * value_term``.
    """
    variance = (std * temp) ** 2
    mean_gap_sq = (mu_s - mu_t) ** 2
    policy_term = (mean_gap_sq / (2 * variance)).mean()
    value_term = F.mse_loss(v_s, v_t)
    return alpha * policy_term + (1 - alpha) * value_term

def ppo_distill(teacher, student, make_batch, d, epochs=3, lr=1e-4, temp=1.0, alpha=0.7):
    """Distill a teacher policy into a student by matching means and values.

    For each epoch every batch from ``make_batch`` (until it returns None) is
    used for one Adam step on ``ppo_distill_loss`` between the student's and
    the teacher's action means and critic values.

    Args:
        teacher: Trained policy; put into eval mode and never updated.
        student: Policy to train.
        make_batch: Zero-argument callable returning a dict with key "state",
            or None at end of data.
        d: Device the states are moved to.
        epochs: Number of passes over the data.
        lr: Adam learning rate.
        temp: Temperature forwarded to ``ppo_distill_loss``.
        alpha: Policy/value weighting forwarded to ``ppo_distill_loss``.

    Returns:
        The trained ``student`` (same object).
    """
    teacher.eval()
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    for ep in range(epochs):
        student.set_epoch(ep)
        total, nb = 0.0, 0
        while True:
            batch = make_batch()
            if batch is None:
                break
            z = batch["state"].to(d)
            # Call the critic directly: going through act_deterministic just
            # for the value would evaluate the actor a second time per batch.
            with torch.no_grad():
                mu_t = teacher.actor_forward(z)
                v_t = teacher.critic_forward(z).squeeze(-1)
            mu_s = student.actor_forward(z)
            v_s = student.critic_forward(z).squeeze(-1)
            loss = ppo_distill_loss(mu_s, mu_t, v_s, v_t, temp=temp, alpha=alpha, std=teacher._std)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
            nb += 1
        print(f"[PPO Distill/{teacher.mode}] Ep {ep+1}/{epochs} loss={total/max(nb,1):.6f}")
    return student


# ============================================================
# Evaluation (Student critic vs Monte-Carlo)
# ============================================================
@torch.no_grad()
def evaluate_student_with_value_metrics(student, make_batch, env, immune_model, d,
                                        episodes=2, multistep_steps=5, gamma=0.99):
    """Compare the student's critic with rollout returns of its deterministic policy.

    For each episode and batch, the deterministic action is applied for
    ``multistep_steps`` environment steps and the discounted return is
    accumulated; the critic's value of the batch state is paired with it.

    Args:
        student: Policy exposing ``set_epoch`` and ``act_deterministic``.
        make_batch: Zero-argument callable returning a dict with keys
            "state", "pep", "tcr", "all", "cytok_init", or None at end of data.
        env: Environment exposing ``set_epoch`` and ``step``.
        immune_model: Model forwarded to ``env.step``.
        d: Device tensors are moved to.
        episodes: Number of passes over the data.
        multistep_steps: Environment steps per rollout.
        gamma: Discount factor.

    Returns:
        ``(mean_return, metrics)``: the mean of per-batch mean returns and the
        ``regression_metrics`` dict of critic values vs. returns. When no batch
        was produced, ``0.0`` and an all-zero metrics dict.
    """
    v_list, G_list, returns = [], [], []
    for ep in range(episodes):
        student.set_epoch(ep); env.set_epoch(ep)
        while True:
            batch = make_batch()
            if batch is None:
                break
            z = batch["state"].to(d)
            pep, tcr, all_idx, cytok = [batch[k].to(d) for k in ("pep", "tcr", "all", "cytok_init")]
            # The state z never changes during the rollout and the policy is
            # deterministic, so one forward pass gives the action for every
            # step as well as the value to evaluate.
            a, v0 = student.act_deterministic(z)
            disc = torch.zeros(z.size(0), device=d)
            cy = cytok
            for t in range(multistep_steps):
                r, cy = env.step(immune_model, pep, tcr, all_idx, a, z, d, cy)
                disc += (gamma**t) * r
            v_list.append(v0.cpu().numpy())
            G_list.append(disc.cpu().numpy())
            returns.append(disc.mean().item())
    if not v_list:
        return 0.0, {"MSE":0,"RMSE":0,"MAE":0,"R2":0,"Pearson":0}
    V = np.concatenate(v_list).ravel()
    G = np.concatenate(G_list).ravel()
    return np.mean(returns), regression_metrics(V, G)


# ============================================================
# Main entry
# ============================================================
def main():
    """Run the full pipeline: supervised ImmuneNet training, then per-mode PPO.

    Steps: parse CLI arguments, load allele/train/test tables, train and test
    ImmuneNet, then for each activation mode train a PPO teacher, distill it
    into a student, evaluate the student's critic, save the student weights to
    ``ppo_student_<mode>.pt`` and finally write all results to
    ``ppo_distilled_comparison.csv``.
    """
    import sys
    # Drop all CLI arguments when any starts with "-f".
    # NOTE(review): presumably targets the kernel-file flag a notebook passes — confirm.
    if any(a.startswith("-f") for a in sys.argv): sys.argv=[sys.argv[0]]
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="train_BA1.txt")
    parser.add_argument("--test", default="test_BA1.txt")
    parser.add_argument("--alleles", default="allelelist.txt")
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--switch_epoch", type=int, default=1)
    # PPO config
    parser.add_argument("--ppo_multistep_steps", type=int, default=5)
    parser.add_argument("--ppo_gamma", type=float, default=0.99)
    parser.add_argument("--ppo_clip", type=float, default=0.2)
    parser.add_argument("--ppo_std", type=float, default=0.1)
    parser.add_argument("--ppo_teacher_width", type=int, default=512)
    parser.add_argument("--ppo_teacher_depth", type=int, default=4)
    parser.add_argument("--ppo_teacher_episodes", type=int, default=2)
    parser.add_argument("--ppo_teacher_epochs", type=int, default=10)
    parser.add_argument("--ppo_teacher_lr", type=float, default=5e-4)
    parser.add_argument("--ppo_student_width", type=int, default=128)
    parser.add_argument("--ppo_student_depth", type=int, default=2)
    parser.add_argument("--ppo_distill_epochs", type=int, default=10)
    parser.add_argument("--ppo_distill_lr", type=float, default=1e-4)
    parser.add_argument("--ppo_distill_temp", type=float, default=1.0)
    parser.add_argument("--ppo_distill_alpha", type=float, default=0.7)
    parser.add_argument("--modes", type=str, default="convex, nonconvex, twostage") # comma-separated modes to compare
    # Unknown arguments are tolerated rather than rejected.
    args,_ = parser.parse_known_args()

    d = get_device(); seed_everything(1)
    if not HAS_TORCH: return print("❌ PyTorch missing")

    # parse_examples extends allele_to_id in place with alleles not in the list file.
    allele_to_id = load_alleles(args.alleles)
    tr = parse_examples(args.train, allele_to_id)
    ts = parse_examples(args.test, allele_to_id)

    tr_ds = PeptideDataset(tr, allele_to_id); ts_ds = PeptideDataset(ts, allele_to_id)
    # Shuffled 80/20 train/validation split with at least one validation item.
    N=len(tr_ds); idx=list(range(N)); random.shuffle(idx)
    val=max(1,int(0.2*N))
    tr_dl=DataLoader(Subset(tr_ds,idx[val:]),batch_size=args.batch_size,shuffle=True,collate_fn=collate_pep)
    val_dl=DataLoader(Subset(tr_ds,idx[:val]),batch_size=args.batch_size,shuffle=False,collate_fn=collate_pep)
    ts_dl=DataLoader(ts_ds,batch_size=args.batch_size,shuffle=False,collate_fn=collate_pep)

    # Train ImmuneNet (frozen for PPO)
    immune=ImmuneNet(len(AA_VOCAB),len(allele_to_id),act_mode="twostage",
                     switch_epoch=args.switch_epoch).to(d)
    print("=== Supervised training ImmuneNet ===")
    train_supervised(immune,tr_dl,val_dl,d,epochs=args.epochs)
    print("Test metrics:",eval_metrics(immune,ts_dl,d))

    # Batch builder
    def make_batch_gen():
        # Each call returns a fresh make_batch with its own iterator over tr_dl.
        def make_batch():
            # The iterator is stored as an attribute on the function object itself.
            if not hasattr(make_batch,"it"): make_batch.it=iter(tr_dl)
            try: pep,tcr,all_idx,cytok,y=next(make_batch.it)
            except StopIteration:
                # End of one pass: rewind for the next pass and signal with None.
                make_batch.it=iter(tr_dl); return None
            with torch.no_grad():
                # The PPO state is ImmuneNet's backbone feature vector.
                z=immune.encode_backbone(pep.to(d),tcr.to(d),all_idx.to(d)).detach()
            return {"state":z,"pep":pep,"tcr":tcr,"all":all_idx,"cytok_init":cytok}
        return make_batch

    results=[]
    for mode in [m.strip() for m in args.modes.split(",") if m.strip()]:
        print(f"\n===== Mode {mode.upper()} =====")
        env=AdaptiveCytokineEnv().to(d)
        # 256 is the policy input size; it matches ImmuneNet's backbone width.
        teacher=PPOPolicy(256,len(CYTOKINES),args.ppo_teacher_width,args.ppo_teacher_depth,
                          mode,args.ppo_std,args.switch_epoch).to(d)
        make_batch=make_batch_gen()
        teacher=ppo_train_multistep(teacher,make_batch,env,immune,d,
                                    args.ppo_teacher_episodes,args.ppo_teacher_epochs,
                                    args.ppo_multistep_steps,args.ppo_gamma,
                                    args.ppo_clip,lr=args.ppo_teacher_lr)

        student=PPOPolicy(256,len(CYTOKINES),args.ppo_student_width,args.ppo_student_depth,
                          mode,args.ppo_std,args.switch_epoch).to(d)
        make_batch=make_batch_gen()
        student=ppo_distill(teacher,student,make_batch,d,args.ppo_distill_epochs,
                            args.ppo_distill_lr,args.ppo_distill_temp,args.ppo_distill_alpha)

        # Evaluation draws its batches from the training loader (tr_dl) as well.
        make_batch=make_batch_gen()
        mean_ret,metrics=evaluate_student_with_value_metrics(student,make_batch,env,immune,d,
                                                             2,args.ppo_multistep_steps,args.ppo_gamma)
        print(
          f"[Eval/{mode}] Return={mean_ret:.4f} | "
          f"MSE={metrics['MSE']:.4f} | RMSE={metrics['RMSE']:.4f} | "
          f"MAE={metrics['MAE']:.4f} | R2={metrics['R2']:.3f} | "
          f"Pearson={metrics['Pearson']:.3f}"
        )
        torch.save(student.state_dict(),f"ppo_student_{mode}.pt")
        results.append([mode,mean_ret,metrics["MSE"],metrics["RMSE"],metrics["MAE"],
                        metrics["R2"],metrics["Pearson"]])

    with open("ppo_distilled_comparison.csv","w",newline="") as f:
        w=csv.writer(f); w.writerow(["Mode","Return","MSE","RMSE","MAE","R2","Pearson"]); w.writerows(results)
    print("✅ Saved ppo_distilled_comparison.csv")

# Script entry point. SystemExit (e.g. raised by argparse) is swallowed so the
# interpreter session is not terminated.
if __name__=="__main__":
    try: main()
    except SystemExit: pass
        

🟡 Using CPU.
=== Supervised training ImmuneNet ===
