In [3]:
import os
import math
import random
from dataclasses import dataclass
from typing import List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr

# -------------------------------
# Config & Utils
# -------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_IDX = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}  # 0 for pad
IDX_TO_AA = {v: k for k, v in AA_TO_IDX.items()}

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def encode_seq(seq: str, max_len: int) -> List[int]:
    seq = (seq or "")[:max_len]
    enc = [AA_TO_IDX.get(ch, 0) for ch in seq]
    if len(enc) < max_len:
        enc += [0] * (max_len - len(enc))
    return enc

def one_hot_indices(idx_seq: List[int], vocab_size: int=21) -> np.ndarray:
    # idx_seq values in [0..vocab_size-1]; 0=padding -> one-hot zeros
    L = len(idx_seq)
    out = np.zeros((L, vocab_size), dtype=np.float32)
    for i, idx in enumerate(idx_seq):
        if idx > 0:
            out[i, idx] = 1.0
    return out

# -------------------------------
# Data Loading
# -------------------------------
def load_dataset(train_path: str, test_path: str):
    # Expect 4 columns: peptide, binding_aff, mhc_seq, mutant
    train_df = pd.read_csv(train_path, sep="\t", header=None,
                           names=["peptide", "binding_aff", "mhc_seq", "mutant"])
    test_df  = pd.read_csv(test_path, sep="\t", header=None,
                           names=["peptide", "binding_aff", "mhc_seq", "mutant"])

    # Normalize binding to 0..1 for 'cytokine' proxy
    def normalize_col(col):
        mn, mx = col.min(), col.max()
        if mx - mn < 1e-12:
            return np.zeros_like(col, dtype=np.float32)
        return ((col - mn) / (mx - mn)).astype(np.float32)

    train_df["cytokine"] = normalize_col(train_df["binding_aff"].values)
    test_df["cytokine"]  = normalize_col(test_df["binding_aff"].values)

    return train_df, test_df

class PMHCDataset(Dataset):
    def __init__(self, df: pd.DataFrame, max_pep_len: int = 15, max_mhc_len: int = 10):
        self.max_pep = max_pep_len
        self.max_mhc = max_mhc_len
        self.peptides = [encode_seq(p, self.max_pep) for p in df["peptide"].astype(str).tolist()]
        self.mhcs     = [encode_seq(m, self.max_mhc) for m in df["mhc_seq"].astype(str).tolist()]
        self.y        = df["cytokine"].astype(np.float32).values.reshape(-1, 1)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        pep = torch.tensor(self.peptides[idx], dtype=torch.long)
        mhc = torch.tensor(self.mhcs[idx], dtype=torch.long)
        y   = torch.tensor(self.y[idx], dtype=torch.float32)
        return pep, mhc, y

# -------------------------------
# Supervised Model
# -------------------------------
class ImmuneResponseModel(nn.Module):
    def __init__(self, vocab_size=21, emb_dim=64, hidden=128):
        super().__init__()
        self.pep_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.mhc_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.conv_p  = nn.Conv1d(emb_dim, hidden, kernel_size=3, padding=1)
        self.conv_m  = nn.Conv1d(emb_dim, hidden, kernel_size=3, padding=1)
        self.ff = nn.Sequential(
            nn.Linear(hidden*2, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, pep, mhc):
        p = self.pep_emb(pep).permute(0,2,1)     # [B, D, L]
        m = self.mhc_emb(mhc).permute(0,2,1)     # [B, D, L]
        p = torch.max(self.conv_p(p), dim=2)[0]  # [B, H]
        m = torch.max(self.conv_m(m), dim=2)[0]  # [B, H]
        z = torch.cat([p, m], dim=1)
        return self.ff(z)                        # [B,1]

def train_supervised(model, train_loader, val_loader=None, epochs=10, lr=1e-3):
    model.to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    for ep in range(epochs):
        model.train()
        total = 0.0
        for pep, mhc, y in train_loader:
            pep, mhc, y = pep.to(DEVICE), mhc.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            pred = model(pep, mhc)
            loss = mse(pred, y)
            loss.backward()
            opt.step()
            total += loss.item()
        msg = f"[Supervised] Epoch {ep+1}/{epochs} loss={total/len(train_loader):.4f}"
        if val_loader:
            model.eval()
            with torch.no_grad():
                vy_true, vy_pred = [], []
                for pep, mhc, y in val_loader:
                    pred = model(pep.to(DEVICE), mhc.to(DEVICE)).cpu().numpy()
                    vy_pred.extend(pred.reshape(-1))
                    vy_true.extend(y.numpy().reshape(-1))
            v_mse = mean_squared_error(vy_true, vy_pred)
            msg += f" val_mse={v_mse:.4f}"
        print(msg)

# -------------------------------
# Multi-step Mutation Environment
# -------------------------------
class MutationEnv:
    """
    State: one-hot peptide (L_p x 21) + one-hot mhc (L_m x 21) flattened.
    Action: choose (position, amino-acid). Discrete space of size L_p * 20.
    Reward: delta predicted cytokine (new - old) to drive improvements.
    Episode length: max_steps.
    """
    def __init__(self, model: ImmuneResponseModel, mhc_seq: str,
                 max_pep_len: int=15, max_mhc_len: int=10, max_steps: int=8):
        self.model = model.to(DEVICE).eval()
        self.mhc_seq = mhc_seq
        self.max_pep = max_pep_len
        self.max_mhc = max_mhc_len
        self.max_steps = max_steps

        self.action_space_n = self.max_pep * 20  # 20 amino acids
        self.obs_dim = (self.max_pep + self.max_mhc) * 21

        # buffers
        self._pep = ""
        self._mhc_idx = encode_seq(self.mhc_seq, self.max_mhc)
        self._mhc_oh  = one_hot_indices(self._mhc_idx)  # [L_m, 21]
        self._steps = 0
        self._current_score = 0.0

    def _predict(self, peptide: str) -> float:
        pep_t = torch.tensor([encode_seq(peptide, self.max_pep)], dtype=torch.long, device=DEVICE)
        mhc_t = torch.tensor([self._mhc_idx], dtype=torch.long, device=DEVICE)
        with torch.no_grad():
            y = self.model(pep_t, mhc_t).item()
        return float(y)

    def _obs(self, peptide: str) -> torch.Tensor:
        pep_idx = encode_seq(peptide, self.max_pep)
        pep_oh  = one_hot_indices(pep_idx)  # [L_p, 21]
        cat = np.concatenate([pep_oh, self._mhc_oh], axis=0).reshape(-1).astype(np.float32)
        return torch.tensor(cat, dtype=torch.float32, device=DEVICE).unsqueeze(0)  # [1, obs_dim]

    def reset(self, peptide: str):
        self._pep = peptide
        self._steps = 0
        self._current_score = self._predict(peptide)
        return self._obs(peptide)

    def step(self, action: int):
        pos = action // 20
        aa_idx = action % 20
        aa = AMINO_ACIDS[aa_idx]
        pos = min(max(pos, 0), max(0, len(self._pep)-1))
        new_pep = self._pep[:pos] + aa + self._pep[pos+1:]
        new_score = self._predict(new_pep)
        reward = new_score - self._current_score  # delta improvement
        self._pep = new_pep
        self._current_score = new_score
        self._steps += 1
        done = (self._steps >= self.max_steps)
        return self._obs(self._pep), reward, done, {"peptide": self._pep, "score": new_score}

# -------------------------------
# PPO
# -------------------------------
class Actor(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, hidden: int=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, action_dim)
        )
    def forward(self, x):
        return torch.distributions.Categorical(logits=self.net(x))

class Critic(nn.Module):
    def __init__(self, obs_dim: int, hidden: int=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1)
        )
    def forward(self, x):
        return self.net(x).squeeze(-1)

@dataclass
class PPOConfig:
    epochs: int = 10
    steps_per_epoch: int = 2048
    gamma: float = 0.99
    lam: float = 0.95
    clip_ratio: float = 0.2
    lr_actor: float = 3e-4
    lr_critic: float = 1e-3
    train_iters: int = 10
    max_grad_norm: float = 0.5

def compute_gae(rews, vals, dones, gamma, lam):
    T = len(rews)
    adv = torch.zeros(T, device=DEVICE)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        nextnonterminal = 1.0 - dones[t]
        nextvalue = vals[t+1] if t+1 < len(vals) else 0.0
        delta = rews[t] + gamma * nextvalue * nextnonterminal - vals[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        adv[t] = lastgaelam
    ret = adv + vals[:T]
    return adv, ret

def ppo_train(env: MutationEnv, seed_peptides: List[str], cfg: PPOConfig):
    obs_dim = env.obs_dim
    act_dim = env.action_space_n
    actor = Actor(obs_dim, act_dim).to(DEVICE)
    critic = Critic(obs_dim).to(DEVICE)
    opt_a = optim.Adam(actor.parameters(), lr=cfg.lr_actor)
    opt_c = optim.Adam(critic.parameters(), lr=cfg.lr_critic)

    for ep in range(cfg.epochs):
        obs_buf, act_buf, logp_buf, rew_buf, done_buf, val_buf = [], [], [], [], [], []
        # rollouts across seed peptides
        steps = 0
        while steps < cfg.steps_per_epoch:
            pep0 = random.choice(seed_peptides)
            obs = env.reset(pep0)
            done = False
            while not done and steps < cfg.steps_per_epoch:
                with torch.no_grad():
                    pi = actor(obs)
                    a = pi.sample()
                    logp = pi.log_prob(a)
                    v = critic(obs)
                nobs, r, done, info = env.step(int(a.item()))
                # store
                obs_buf.append(obs)
                act_buf.append(a)
                logp_buf.append(logp)
                rew_buf.append(torch.tensor([r], device=DEVICE))
                done_buf.append(torch.tensor([float(done)], device=DEVICE))
                val_buf.append(v)

                obs = nobs
                steps += 1

        # stack
        obs_b = torch.cat(obs_buf, dim=0)
        act_b = torch.stack(act_buf).squeeze(-1)
        logp_b = torch.stack(logp_buf).squeeze(-1)
        rew_b = torch.cat(rew_buf).squeeze(-1)
        done_b = torch.cat(done_buf).squeeze(-1)
        val_b = torch.stack(val_buf).squeeze(-1)

        # bootstrap last value
        with torch.no_grad():
            last_v = critic(obs_b[-1].unsqueeze(0))
        vals_ext = torch.cat([val_b, last_v.reshape(1)])

        adv, ret = compute_gae(rew_b, vals_ext, done_b, cfg.gamma, cfg.lam)
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

        # optimize
        for _ in range(cfg.train_iters):
            pi = actor(obs_b)
            logp = pi.log_prob(act_b)
            ratio = torch.exp(logp - logp_b)
            surr1 = ratio * adv
            surr2 = torch.clamp(ratio, 1 - cfg.clip_ratio, 1 + cfg.clip_ratio) * adv
            actor_loss = -torch.min(surr1, surr2).mean() - 0.01 * pi.entropy().mean()

            v = critic(obs_b)
            critic_loss = ((v - ret) ** 2).mean()

            opt_a.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(actor.parameters(), cfg.max_grad_norm)
            opt_a.step()

            opt_c.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(critic.parameters(), cfg.max_grad_norm)
            opt_c.step()

        print(f"[PPO] epoch {ep+1}/{cfg.epochs}  actor_loss={actor_loss.item():.4f}  critic_loss={critic_loss.item():.4f}")

    return actor, critic

# -------------------------------
# Evaluation metrics
# -------------------------------
def evaluate_model(model: ImmuneResponseModel, loader: DataLoader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for pep, mhc, y in loader:
            y_hat = model(pep.to(DEVICE), mhc.to(DEVICE)).cpu().numpy().reshape(-1)
            y_true.extend(y.numpy().reshape(-1))
            y_pred.extend(y_hat)
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    mse = mean_squared_error(y_true, y_pred)
    rmse = math.sqrt(mse)
    r2 = r2_score(y_true, y_pred)
    try:
        pearson, _ = pearsonr(y_true, y_pred)
    except Exception:
        pearson = float('nan')
    return {"MSE": mse, "RMSE": rmse, "R2": r2, "Pearson": pearson}

# -------------------------------
# Main
# -------------------------------
def main():
    set_seed(123)
    train_path = "train_BA1.txt"
    test_path  = "test_BA1.txt"
    if not (os.path.exists(train_path) and os.path.exists(test_path)):
        raise FileNotFoundError("Expected train_BA1.txt and test_BA1.txt under /mnt/data/")

    train_df, test_df = load_dataset(train_path, test_path)

    # data loaders
    tr_ds = PMHCDataset(train_df)
    te_ds = PMHCDataset(test_df)
    tr_loader = DataLoader(tr_ds, batch_size=64, shuffle=True)
    te_loader = DataLoader(te_ds, batch_size=128, shuffle=False)

    # supervised model
    model = ImmuneResponseModel().to(DEVICE)
    train_supervised(model, tr_loader, te_loader, epochs=12, lr=1e-3)

    # evaluation
    metrics = evaluate_model(model, te_loader)
    print("\n== Supervised Test Metrics ==")
    for k, v in metrics.items():
        print(f"{k}: {v:.6f}")

    # RL environment with a default MHC from the test set
    default_mhc = str(test_df["mhc_seq"].iloc[0]) if len(test_df) > 0 else "HLA-DPA10103-DPB10201"
    env = MutationEnv(model, mhc_seq=default_mhc, max_steps=8)

    # seed peptides from test set
    seed_peptides = test_df["peptide"].astype(str).sample(min(32, len(test_df)), random_state=123).tolist()

    # train PPO
    ppo_cfg = PPOConfig(epochs=6, steps_per_epoch=1024, train_iters=8)
    actor, critic = ppo_train(env, seed_peptides, ppo_cfg)

    # demonstrate improvements for a few peptides
    print("\n== PPO Mutation Demonstration ==")
    for pep in seed_peptides[:5]:
        obs = env.reset(pep)
        best = {"pep": pep, "score": env._current_score}
        for _ in range(env.max_steps):
            with torch.no_grad():
                pi = actor(obs)
                a = pi.probs.argmax()  # greedy for demo
            obs, r, done, info = env.step(int(a.item()))
            if info["score"] > best["score"]:
                best = {"pep": info["peptide"], "score": info["score"]}
            if done:
                break
        print(f"Start: {pep} | Best mutant: {best['pep']} | Pred cytokine: {best['score']:.4f}")

if __name__ == "__main__":
    main()


[Supervised] Epoch 1/12 loss=0.0474 val_mse=0.0451
[Supervised] Epoch 2/12 loss=0.0425 val_mse=0.0444
[Supervised] Epoch 3/12 loss=0.0410 val_mse=0.0453
[Supervised] Epoch 4/12 loss=0.0401 val_mse=0.0440
[Supervised] Epoch 5/12 loss=0.0394 val_mse=0.0453
[Supervised] Epoch 6/12 loss=0.0389 val_mse=0.0461
[Supervised] Epoch 7/12 loss=0.0383 val_mse=0.0448
[Supervised] Epoch 8/12 loss=0.0379 val_mse=0.0446
[Supervised] Epoch 9/12 loss=0.0376 val_mse=0.0457
[Supervised] Epoch 10/12 loss=0.0373 val_mse=0.0445
[Supervised] Epoch 11/12 loss=0.0369 val_mse=0.0464
[Supervised] Epoch 12/12 loss=0.0366 val_mse=0.0450

== Supervised Test Metrics ==
MSE: 0.045001
RMSE: 0.212134
R2: 0.363033
Pearson: 0.610868
[PPO] epoch 1/6  actor_loss=-0.1475  critic_loss=0.0086
[PPO] epoch 2/6  actor_loss=-0.1315  critic_loss=0.0111
[PPO] epoch 3/6  actor_loss=-0.1272  critic_loss=0.0099
[PPO] epoch 4/6  actor_loss=-0.1246  critic_loss=0.0097
[PPO] epoch 5/6  actor_loss=-0.1258  critic_loss=0.0079
[PPO] epoch 6/