End-to-end training from CSV file

In [6]:
import os
import numpy as np
from collections import defaultdict

import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModel

# =========================
# 0. Config
# =========================
# All settings for this cell live here as module-level constants; the code
# below reads them directly.
# NOTE: CSV_PATH is an absolute Windows path and must be edited per machine.
CSV_PATH        = r"F:\20251201\ms2_with_ids.csv"  # <-- path to your CSV
DEVICE          = "cuda" if torch.cuda.is_available() else "cpu"

EMB_DIM         = 64           # embedding dimension for MS2 contrastive encoder
BATCH_P_CLASSES = 32           # number of labels per batch
BATCH_K_SPECTRA = 4            # spectra per label per batch
EPOCHS          = 50           # MS2 contrastive training epochs
LR              = 1e-3         # learning rate of the spectrum-encoder optimizer
TEMPERATURE     = 0.07         # divisor applied to similarities in the contrastive loss

RANDOM_SEED     = 42

# ---- per-label controls ----
MAX_SPECTRA_PER_LABEL = 1000    # hard cap per label
MIN_SPECTRA_PER_LABEL = 100     # drop labels with fewer than this

# ---- configurable label frequency rank range (1-based) ----
START_RANK      = 1            # e.g. 1
END_RANK        = 11           # e.g. 11 (match your ranks_1_11 checkpoint if you want)

# ---- PLM fine-tuning config ----
PLM_NAME        = "facebook/esm2_t6_8M_UR50D"
MAX_SEQ_LEN     = 1024         # tokenizer max_length (sequences are truncated/padded to it)
PLM_BATCH_SIZE  = 8
PLM_EPOCHS      = 20
PLM_LR          = 1e-4
FREEZE_BACKBONE = True  # set False to fine-tune backbone as well

# Seed the *global* NumPy and torch RNGs. The per-label subsampling and the
# batch sampler below draw from np.random, so their results depend on this
# seed and on the order in which the draws happen.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print("Using device:", DEVICE)
print(f"Configured LABEL rank range: {START_RANK} .. {END_RANK} (1-based, inclusive)")
print(f"MIN_SPECTRA_PER_LABEL = {MIN_SPECTRA_PER_LABEL}")
print(f"MAX_SPECTRA_PER_LABEL = {MAX_SPECTRA_PER_LABEL}")


# =========================
# 1. Load CSV & initial clean
# =========================
# Produces: df_id (identified rows), accession (stripped label strings) and
# ms2_lib (float32 matrix, one max-normalized spectrum per row).
print("\nLoading CSV:", CSV_PATH)
df = pd.read_csv(CSV_PATH)

# Identify MS2 columns: every column whose name starts with "cast_"
# (e.g. cast_00000, cast_00001, ...)
ms2_cols = [col for col in df.columns if col.startswith("cast_")]
if len(ms2_cols) == 0:
    raise ValueError("No columns starting with 'cast_' found in CSV.")

# Plain string sort: this equals numeric bin order only as long as the numeric
# suffix is zero-padded to a fixed width.
ms2_cols = sorted(ms2_cols)

print(f"Found {len(ms2_cols)} MS2 columns:", ms2_cols[:5], "...")

# Labels from 'Accession'
if "Accession" not in df.columns:
    raise ValueError("CSV does not contain 'Accession' column.")

accession_raw = df["Accession"]

# Keep only rows with non-empty Accession (identified)
mask_identified = accession_raw.notna() & (accession_raw.astype(str).str.strip() != "")
# reset_index so positional indices used later line up with ms2_lib rows
df_id = df[mask_identified].reset_index(drop=True)
accession = df_id["Accession"].astype(str).str.strip().values

print("Total rows in CSV:", len(df))
print("Rows with non-empty Accession (used for supervised training):", len(df_id))

if len(df_id) == 0:
    raise ValueError("No rows with non-empty Accession. Cannot run supervised contrastive training.")

# Extract MS2 intensities
ms2_lib = df_id[ms2_cols].values.astype(np.float32)

# Replace NaN/inf in ms2 with 0
bad_mask = ~np.isfinite(ms2_lib)
if bad_mask.any():
    print("Found NaN/inf in ms2_lib, replacing with 0")
    ms2_lib[bad_mask] = 0.0

# Per-spectrum max normalization
max_intensity = np.max(ms2_lib, axis=1, keepdims=True)
# Rows whose maximum is exactly 0 are divided by 1 instead, avoiding 0/0.
max_intensity[max_intensity == 0] = 1.0
ms2_lib = ms2_lib / max_intensity

print("Any NaN left in ms2_lib?", np.isnan(ms2_lib).any())
print("Any inf left in ms2_lib?", np.isinf(ms2_lib).any())

N_total, n_bins = ms2_lib.shape
print(f"Spectra shape before per-label construction: (N, n_bins) = {ms2_lib.shape}")


# =========================
# 2. Build label = Accession string and enforce MIN/MAX per label
# =========================
# Labels with fewer than MIN_SPECTRA_PER_LABEL spectra are removed entirely;
# labels with more than MAX_SPECTRA_PER_LABEL are randomly subsampled down to
# the cap. ms2_lib, label_str and the metadata frame are filtered with the
# same index array so their rows stay aligned.
label_str = np.array(accession, dtype=object)
print("Example raw labels (Accession):", label_str[:5])

print("\nApplying per-label limits:")
# label string -> list of row positions carrying that label
label_to_indices = defaultdict(list)
for i, lab in enumerate(label_str):
    label_to_indices[lab].append(i)

keep_indices = []
dropped_too_few = 0
capped_too_many = 0

for lab, idxs in label_to_indices.items():
    n = len(idxs)
    if n < MIN_SPECTRA_PER_LABEL:
        dropped_too_few += 1
        continue  # drop this label entirely
    if n > MAX_SPECTRA_PER_LABEL:
        capped_too_many += 1
        # Draws from the global NumPy RNG seeded in the config section.
        chosen = np.random.choice(idxs, size=MAX_SPECTRA_PER_LABEL, replace=False)
        keep_indices.extend(chosen.tolist())
    else:
        keep_indices.extend(idxs)

# Rows end up grouped by label (in first-appearance order of the labels),
# not in their original file order.
keep_indices = np.array(keep_indices, dtype=int)

print(f"Labels dropped for having < {MIN_SPECTRA_PER_LABEL} spectra: {dropped_too_few}")
print(f"Labels capped at {MAX_SPECTRA_PER_LABEL} spectra: {capped_too_many}")
print("Total spectra after per-label MIN/MAX filter:", len(keep_indices))

if len(keep_indices) == 0:
    raise ValueError(
        "No spectra left after applying MIN_SPECTRA_PER_LABEL and MAX_SPECTRA_PER_LABEL. "
        "Relax thresholds or check dataset."
    )

# Apply per-label filter
ms2_lib   = ms2_lib[keep_indices]
label_str = label_str[keep_indices]

N_total, n_bins = ms2_lib.shape
print(f"Spectra shape after per-label MIN/MAX: (N, n_bins) = {ms2_lib.shape}")

# Keep metadata aligned (including sequence)
df_id_filtered = df_id.iloc[keep_indices].reset_index(drop=True)


# =========================
# 3. Restrict to chosen LABEL rank range (frequency-based)
# =========================
# Labels are ranked by how many spectra they have (rank 1 = most frequent) and
# only ranks START_RANK..END_RANK (clipped to what exists) are kept.
all_unique_lbl, all_lbl_ids = np.unique(label_str, return_inverse=True)
counts = np.bincount(all_lbl_ids)
print("Total unique labels before rank-frequency filter:", len(all_unique_lbl))

# Sort labels by frequency (descending)
# NOTE(review): argsort is called with its default sort kind, so the relative
# rank of labels with identical counts is whatever that sort produces — confirm
# this is acceptable if ties can straddle the rank boundary.
sorted_class_indices = np.argsort(-counts)  # 0..(n_classes-1), high→low frequency
n_classes_total = len(sorted_class_indices)

# Clip requested rank range to available classes
start_rank = max(1, START_RANK)
end_rank   = min(END_RANK, n_classes_total)

if end_rank < start_rank:
    raise ValueError(
        f"No valid LABEL ranks after clipping to available classes. "
        f"Requested {START_RANK}-{END_RANK}, available 1-{n_classes_total}."
    )

start_idx = start_rank - 1  # inclusive
end_idx   = end_rank        # exclusive in slicing

block_indices = sorted_class_indices[start_idx:end_idx]
print(f"\nUsing LABEL ranks {start_rank}..{end_rank} (1-based).")
print(f"Number of LABEL classes in this block: {len(block_indices)}")

if len(block_indices) == 0:
    raise ValueError(
        "No LABEL classes in the requested rank range. "
        "Try a different range (e.g. 1..50)."
    )

block_set = set(block_indices.tolist())

# Filter spectra to only these labels
mask_block        = np.isin(all_lbl_ids, list(block_set))
label_str_block   = label_str[mask_block]
ms2_lib_block     = ms2_lib[mask_block]
# Label ids in the pre-filter numbering; not referenced again in the visible
# part of this cell.
lbl_ids_block_raw = all_lbl_ids[mask_block]

print("Spectra after LABEL-block (rank) filter:", len(label_str_block))

# Align metadata for this block
df_block = df_id_filtered[mask_block].reset_index(drop=True)

# Re-encode LABELs compactly 0..(K-1)
# np.unique returns the labels sorted, so id k refers to unique_lbl_block[k].
unique_lbl_block, lbl_ids_block = np.unique(label_str_block, return_inverse=True)
unique_lbl_block = np.array(unique_lbl_block, dtype=object)
n_classes = len(unique_lbl_block)
print("Number of unique LABELs in block:", n_classes)

N, n_bins = ms2_lib_block.shape
print(f"Final spectra shape after LABEL block filter: (N, n_bins) = {ms2_lib_block.shape}")


# =========================
# 4. Use ALL filtered spectra as training data (NO test split)
# =========================
# No held-out data exists in this cell: every later metric is computed on
# spectra the encoder was trained on.
train_ms2   = ms2_lib_block
train_lblid = lbl_ids_block

print(f"\nTraining on ALL spectra after filtering: {len(train_ms2)}")


# =========================
# 5. Dataset & balanced sampler
# =========================
class SpectraDataset(Dataset):
    """
    In-memory dataset pairing each spectrum with its integer label id.

    ms2    : array-like of spectra, converted to a float32 tensor
    labels : array-like of label ids, converted to an int64 tensor

    Indexing returns a dict with the keys "spectrum" and "label".
    """
    def __init__(self, ms2, labels):
        # torch.tensor copies its input, so the dataset holds its own storage.
        self.ms2 = torch.tensor(ms2, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        # One item per row of the spectrum tensor.
        return self.ms2.size(0)

    def __getitem__(self, idx):
        spectrum = self.ms2[idx]
        label = self.labels[idx]
        return {"spectrum": spectrum, "label": label}


class ClassBalancedBatchSampler(Sampler):
    """
    Batch sampler that yields class-balanced lists of dataset indices.

    Each batch contains:
      - classes_per_batch distinct labels (all labels, if fewer exist)
      - samples_per_class spectra for each chosen label; a label with fewer
        spectra than that is sampled with replacement

    One pass over the sampler yields exactly len(self) batches, so it can be
    consumed by a plain ``for`` loop or ``list(...)`` without a manual break.
    Random draws come from the global NumPy RNG (np.random), so seeding
    np.random makes the batches reproducible.

    labels            : array-like of integer label ids, one per dataset item
    classes_per_batch : number of distinct labels per batch
    samples_per_class : number of items drawn per chosen label
    """
    def __init__(self, labels, classes_per_batch=32, samples_per_class=4):
        self.labels = np.array(labels, dtype=np.int64)
        self.classes_per_batch = classes_per_batch
        self.samples_per_class = samples_per_class

        # label id -> list of dataset positions carrying that label
        self.class_to_indices = defaultdict(list)
        for i, lab in enumerate(self.labels):
            self.class_to_indices[lab].append(i)

        self.unique_classes = np.array(list(self.class_to_indices.keys()))
        # Nominal batch size; actual batches are smaller when fewer than
        # classes_per_batch labels exist.
        self.batch_size = self.classes_per_batch * self.samples_per_class

    def _sample_batch(self):
        """Draw one class-balanced batch and return it as a list of ints."""
        if len(self.unique_classes) <= self.classes_per_batch:
            chosen_classes = self.unique_classes
        else:
            chosen_classes = np.random.choice(
                self.unique_classes,
                size=self.classes_per_batch,
                replace=False
            )

        batch_indices = []
        for c in chosen_classes:
            idxs = self.class_to_indices[c]
            # Only fall back to sampling with replacement when the class is
            # too small to supply samples_per_class distinct items.
            need_replacement = len(idxs) < self.samples_per_class
            chosen = np.random.choice(
                idxs,
                size=self.samples_per_class,
                replace=need_replacement
            )
            batch_indices.extend(chosen.tolist())
        return batch_indices

    def __iter__(self):
        # Finite epoch: the previous `while True` generator never terminated,
        # so any consumer without its own break would loop forever.
        for _ in range(len(self)):
            yield self._sample_batch()

    def __len__(self):
        # approximate number of batches per epoch
        return max(1, len(self.labels) // self.batch_size)


# Wire the training data together: the dataset holds all filtered spectra and
# the sampler decides which indices form each batch.
train_dataset = SpectraDataset(train_ms2, train_lblid)
train_sampler = ClassBalancedBatchSampler(
    labels=train_lblid,
    classes_per_batch=BATCH_P_CLASSES,
    samples_per_class=BATCH_K_SPECTRA
)
# batch_sampler= hands whole index lists to the loader, so batch size and
# shuffling are controlled entirely by train_sampler.
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler
)


# =========================
# 6. Model: Spectrum Encoder
# =========================
class SpectrumEncoder(nn.Module):
    """
    MLP that maps a binned spectrum to an L2-normalized embedding.

    n_bins  : number of input features (intensity bins per spectrum)
    emb_dim : size of the output embedding

    forward(x) takes a float tensor whose last dimension is n_bins and returns
    a tensor with last dimension emb_dim and unit L2 norm along that dimension.
    """
    def __init__(self, n_bins, emb_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_bins, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, emb_dim),
        )

    def forward(self, x):
        z = self.net(x)
        # Guard on every non-finite value, not just NaN: a +/-inf activation
        # would pass an isnan() check and then become NaN inside F.normalize
        # (inf / inf).
        if not torch.isfinite(z).all():
            print("Warning: NaN/inf detected in encoder output, applying nan_to_num.")
            z = torch.nan_to_num(z, nan=0.0, posinf=1e6, neginf=-1e6)
        z = F.normalize(z, dim=-1)
        return z


# =========================
# 7. Safe supervised contrastive loss
# =========================
def supervised_contrastive_loss_safe(emb, labels, temperature=0.07):
    """
    Supervised contrastive loss with guards against non-finite values.

    For every anchor i that has at least one other sample with the same label,
    the per-anchor loss is
        -log( sum_{p in positives(i)} exp(s_ip) / sum_{a != i} exp(s_ia) )
    with s = (emb @ emb.T) / temperature, clamped to [-50, 50]. The result is
    the mean over all anchors that produced a finite loss.

    emb         : (B, D) L2-normalized embeddings
    labels      : (B,) int label ids
    temperature : positive scalar dividing the similarities

    Returns a 0-dim tensor. When no anchor is usable (B <= 1, no positives, or
    only non-finite terms) the result is a zero that requires grad whenever
    `emb` does, so the caller's `loss.backward()` is a harmless no-op instead
    of raising on a graph-less constant.
    """
    device = emb.device
    B, D = emb.shape

    def _zero_loss():
        # Fresh leaf: backward() succeeds and leaves model gradients untouched.
        return torch.zeros(
            (), device=device, dtype=emb.dtype, requires_grad=emb.requires_grad
        )

    if B <= 1:
        return _zero_loss()

    # Cosine similarity (inputs are expected to be unit-norm), temperature-scaled
    sim = emb @ emb.t()
    sim = torch.clamp(sim / temperature, min=-50.0, max=50.0)

    positions = torch.arange(B, device=device)  # hoisted out of the anchor loop

    total_loss = 0.0
    valid_anchors = 0

    for i in range(B):
        # Exclude the anchor itself from numerator and denominator.
        sim_i = sim[i].clone()
        sim_i[i] = float("-inf")

        pos_mask_i = (labels == labels[i]) & (positions != i)
        if not pos_mask_i.any():
            continue

        finite_mask = torch.isfinite(sim_i)
        if not finite_mask.any():
            continue

        # Subtract the row maximum before exponentiating for numerical
        # stability; the shift cancels in num / denom.
        max_sim = sim_i[finite_mask].max()
        shifted = sim_i - max_sim
        shifted[~finite_mask] = float("-inf")

        exp_all = torch.exp(shifted)
        denom = exp_all.sum()
        if not torch.isfinite(denom) or denom <= 0:
            continue

        exp_pos = exp_all * pos_mask_i.float()
        num = exp_pos.sum()
        if not torch.isfinite(num) or num <= 0:
            continue

        loss_i = -torch.log(num / denom)
        if not torch.isfinite(loss_i):
            continue

        total_loss += loss_i
        valid_anchors += 1

    if valid_anchors == 0:
        return _zero_loss()

    loss = total_loss / valid_anchors
    return loss


# =========================
# 8. Training loop (contrastive spectrum encoder)
# =========================
model = SpectrumEncoder(n_bins=n_bins, emb_dim=EMB_DIM).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# The sampler defines the (approximate) number of batches that make one epoch.
steps_per_epoch = len(train_loader)
print("\nsteps_per_epoch:", steps_per_epoch)

# Sanity check: one forward pass + loss on a single batch, without gradients
print("Running a sanity check on one batch...")
batch_example = next(iter(train_loader))
x_ex = batch_example["spectrum"].to(DEVICE)
y_ex = batch_example["label"].to(DEVICE)
with torch.no_grad():
    emb_ex = model(x_ex)
    print("Any NaN in emb_ex?", torch.isnan(emb_ex).any().item())
    loss_ex = supervised_contrastive_loss_safe(emb_ex, y_ex, temperature=TEMPERATURE)
    print("Initial loss (sanity check):", loss_ex.item())

print("\nStarting training (contrastive spectrum encoder)...")
spectrum_epoch_losses = []
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    effective_steps = 0

    for step, batch in enumerate(train_loader):
        if step >= steps_per_epoch:
            break  # prevent infinite epoch

        x = batch["spectrum"].to(DEVICE)
        y = batch["label"].to(DEVICE)

        emb = model(x)
        loss = supervised_contrastive_loss_safe(emb, y, temperature=TEMPERATURE)

        # isfinite() is False for NaN as well as +/-inf, so one check suffices.
        if not torch.isfinite(loss):
            print(f"NaN/inf loss at epoch {epoch+1}, step {step+1}. Skipping.")
            continue

        # The loss function falls back to a constant zero when no anchor in
        # the batch has a positive. Such a loss has no autograd graph, and
        # calling backward() on it would either raise or do nothing useful,
        # so the step is skipped and not counted in the epoch average.
        if loss.grad_fn is None:
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        effective_steps += 1

    # max(1, ...) keeps the average defined when every step was skipped.
    avg_loss = running_loss / max(1, effective_steps)
    spectrum_epoch_losses.append(avg_loss)
    print(f"[Spectrum] Epoch {epoch+1}/{EPOCHS}  Avg loss: {avg_loss:.4f}")

print("Contrastive training done.")
print("Spectrum encoder losses per epoch:", spectrum_epoch_losses)


# =========================
# 9. Build train embedding index
# =========================
# Embed every training spectrum with the trained encoder, in row order, so
# train_emb[k] corresponds to train_ms2[k] / train_lblid[k].
model.eval()
with torch.no_grad():
    train_emb = []
    batch_size_eval = 2048  # chunk size for inference only
    for i in range(0, len(train_ms2), batch_size_eval):
        batch = torch.tensor(
            train_ms2[i:i+batch_size_eval],
            dtype=torch.float32,
            device=DEVICE
        )
        z = model(batch)
        # Move each chunk back to host memory before collecting it.
        train_emb.append(z.cpu().numpy())
    train_emb = np.concatenate(train_emb, axis=0)

print("\nTrain embeddings shape:", train_emb.shape)
train_labels = train_lblid


# =========================
# 10. Save spectrum encoder model
# =========================
# The file name encodes the filtering settings and the (clipped) rank range.
# The path is relative, so the file lands in the current working directory.
MODEL_PATH = (
    f"contrastive_spectrum_encoder_Accession_"
    f"min{MIN_SPECTRA_PER_LABEL}_max{MAX_SPECTRA_PER_LABEL}_"
    f"ranks_{start_rank}_{end_rank}.pth"
)

# Besides the weights, store what is needed to rebuild the encoder
# (n_bins, emb_dim) and to map label ids back to accessions.
torch.save({
    "model_state_dict": model.state_dict(),
    "emb_dim": EMB_DIM,
    "n_bins": n_bins,
    "unique_label_block": list(unique_lbl_block),  # accession labels for this block
    "rank_start": start_rank,
    "rank_end": end_rank,
    "min_spectra_per_label": MIN_SPECTRA_PER_LABEL,
    "max_spectra_per_label": MAX_SPECTRA_PER_LABEL,
}, MODEL_PATH)

print(f"Contrastive spectrum model saved to {MODEL_PATH}")


# =========================
# 11. Save spectrum embedding index for retrieval
# =========================
# Accession string for every embedded spectrum, looked up from its label id.
acc_per_spectrum = np.array(
    [str(unique_lbl_block[label]) for label in train_labels],
    dtype=object
)

# NOTE: unlike MODEL_PATH this name does not encode the configuration, so
# every run writes to the same file.
INDEX_PATH = r"train_spectrum_embeddings.npz"
np.savez(
    INDEX_PATH,
    embeddings=train_emb,         # (N, 64)
    labels=train_labels,          # (N,)
    accessions=acc_per_spectrum   # (N,)
)
print(f"Spectrum embedding index saved to {INDEX_PATH}")


# =========================
# 12. Build per-Accession target embeddings for PLM
# =========================
# For every label in the block this produces one training pair for the PLM
# regressor: (representative protein sequence, mean spectrum embedding).
if "sequence" not in df_block.columns:
    raise ValueError(
        "df_block does not have a 'sequence' column. "
        "You must supply sequences in the CSV."
    )

acc_block = df_block["Accession"].astype(str).str.strip().values
# fillna("") before astype(str): converting a missing value directly would
# yield the literal string "nan", which could then be picked as a "sequence"
# and fed to the tokenizer. Empty strings are filtered out below instead.
seq_block = df_block["sequence"].fillna("").astype(str).str.strip().values  # one sequence per spectrum row

n_labels_plm = n_classes  # same as number of unique_lbl_block

label_ids = []
target_embs = []
seqs_for_labels = []

for lab_id in range(n_labels_plm):
    acc = str(unique_lbl_block[lab_id])

    # indices of spectra belonging to this label
    idxs = np.where(train_labels == lab_id)[0]
    if len(idxs) == 0:
        continue

    # mean spectrum embedding for this Accession (target for regression)
    mean_emb = train_emb[idxs].mean(axis=0)

    # candidate sequences: every non-empty sequence on rows with this accession
    seqs = [s for s in seq_block[acc_block == acc] if s]
    if not seqs:
        print(f"  Skipping {acc}: no non-empty sequence available.")
        continue

    # representative sequence: the longest among the candidates
    seq_rep = max(seqs, key=len)

    label_ids.append(lab_id)
    target_embs.append(mean_emb)
    seqs_for_labels.append(seq_rep)

# np.stack on an empty list fails with an unhelpful message; fail early with
# one that names the actual problem.
if not target_embs:
    raise ValueError(
        "No Accession has both spectrum embeddings and a non-empty sequence; "
        "cannot build targets for the PLM regressor."
    )

label_ids = np.array(label_ids, dtype=int)
target_embs = np.stack(target_embs, axis=0).astype(np.float32)

print(f"\nPer-Accession targets for PLM:")
print("  #labels with both sequence and embedding:", len(label_ids))
print("  target_embs shape:", target_embs.shape)


# =========================
# 13. Protein LM → MS2 embedding regressor
# =========================
# Load the pretrained tokenizer and backbone named by PLM_NAME. The backbone
# is wrapped by ProteinToSpectrumRegressor below.
print("\nLoading protein language model:", PLM_NAME)
tokenizer = AutoTokenizer.from_pretrained(PLM_NAME)
backbone = AutoModel.from_pretrained(PLM_NAME)


class ProteinToSpectrumRegressor(nn.Module):
    """
    Protein language model backbone plus a linear head that predicts an
    L2-normalized embedding of size emb_dim.

    backbone        : module exposing `config.hidden_size` and returning an
                      object with `last_hidden_state` when called with
                      input_ids / attention_mask
    emb_dim         : size of the predicted embedding
    freeze_backbone : if True, all backbone parameters get requires_grad=False
                      so only the head is trainable
    """
    def __init__(self, backbone, emb_dim, freeze_backbone=True):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            # Module-level switch; equivalent to clearing the flag on every
            # backbone parameter one by one.
            self.backbone.requires_grad_(False)

        self.reg_head = nn.Linear(backbone.config.hidden_size, emb_dim)

    def forward(self, input_ids, attention_mask):
        hidden = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state                      # (B, L, H)
        first_token = hidden[:, 0]               # (B, H): state at position 0
        projected = self.reg_head(first_token)   # (B, emb_dim)
        # Unit-normalize so predictions share the scale of the encoder output.
        return F.normalize(projected, dim=-1)


# Wrap the loaded backbone with the regression head and move everything to
# the selected device. The output size matches the spectrum encoder (EMB_DIM).
plm_model = ProteinToSpectrumRegressor(
    backbone,
    emb_dim=EMB_DIM,
    freeze_backbone=FREEZE_BACKBONE
).to(DEVICE)


class ProteinSeqDataset(Dataset):
    """
    Dataset of (protein sequence, target embedding) pairs.

    sequences : iterable of sequence strings
    targets   : array-like with one target row per sequence (stored as float32)
    tokenizer : callable invoked as
                tokenizer(seq, return_tensors="pt", truncation=True,
                          padding="max_length", max_length=max_len)
    max_len   : length every tokenized sequence is truncated/padded to

    Indexing tokenizes on the fly and returns a dict with the keys
    "input_ids", "attention_mask" and "target".
    """
    def __init__(self, sequences, targets, tokenizer, max_len=1024):
        self.sequences = list(sequences)
        self.targets = torch.tensor(targets, dtype=torch.float32)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence, target = self.sequences[idx], self.targets[idx]

        encoded = self.tokenizer(
            sequence,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_len
        )
        # The tokenizer output carries a leading dimension of size 1;
        # squeeze(0) drops it so the DataLoader can stack items itself.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "target": target,
        }


def cosine_regression_loss(pred, target):
    """
    Cosine-distance loss: 1 - mean cosine similarity over the batch.

    pred, target : (B, D) tensors; both are L2-normalized along the last
                   dimension here, so callers need not normalize them.

    Returns a 0-dim tensor that is 0 when every prediction points in the same
    direction as its target and 2 when every prediction points the opposite way.
    """
    unit_pred = F.normalize(pred, dim=-1)
    unit_target = F.normalize(target, dim=-1)
    # Dot product of unit vectors == cosine similarity, one value per row.
    per_row_cos = torch.sum(unit_pred * unit_target, dim=-1)  # (B,)
    return 1.0 - torch.mean(per_row_cos)


# One training example per Accession: its representative sequence paired with
# the mean spectrum embedding computed in section 12.
plm_dataset = ProteinSeqDataset(
    sequences=seqs_for_labels,
    targets=target_embs,
    tokenizer=tokenizer,
    max_len=MAX_SEQ_LEN
)

# Every item is padded to MAX_SEQ_LEN by the dataset, so the default collate
# function can stack them without a custom collator.
plm_loader = DataLoader(
    plm_dataset,
    batch_size=PLM_BATCH_SIZE,
    shuffle=True
)


# =========================
# 14. Train PLM → MS2 regressor
# =========================
# All parameters are handed to the optimizer; when FREEZE_BACKBONE is True the
# backbone parameters have requires_grad=False, so only the head receives
# gradients.
optimizer_plm = torch.optim.AdamW(plm_model.parameters(), lr=PLM_LR)

print("\nStarting PLM fine-tuning to predict MS2 embeddings...")
plm_epoch_losses = []
for epoch in range(PLM_EPOCHS):
    plm_model.train()
    running_loss = 0.0
    n_steps = 0

    for batch in plm_loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        target = batch["target"].to(DEVICE)

        pred = plm_model(input_ids=input_ids, attention_mask=attention_mask)
        # Loss is 1 - mean cosine similarity, so lower is better and values
        # near 1 mean predictions are roughly orthogonal to their targets.
        loss = cosine_regression_loss(pred, target)

        optimizer_plm.zero_grad()
        loss.backward()
        optimizer_plm.step()

        running_loss += loss.item()
        n_steps += 1

    # max(1, ...) keeps the average defined if the loader yields no batches.
    avg_loss = running_loss / max(1, n_steps)
    plm_epoch_losses.append(avg_loss)
    print(f"[PLM] Epoch {epoch+1}/{PLM_EPOCHS}  Avg loss: {avg_loss:.4f}")

print("PLM fine-tuning done.")
print("PLM regressor losses per epoch:", plm_epoch_losses)


# =========================
# 15. Save Protein→MS2 embedding model
# =========================
# "/" in the model name is replaced so it can be used inside a file name.
# The path is relative, so the file lands in the current working directory.
PLM_MODEL_PATH = (
    f"protein_to_ms2_plm_{PLM_NAME.replace('/', '_')}_"
    f"emb{EMB_DIM}_ranks_{start_rank}_{end_rank}.pth"
)

# The state dict covers the whole wrapper, i.e. backbone weights as well as
# the regression head. label_ids / accessions record which labels were
# actually used as training targets.
torch.save({
    "plm_name": PLM_NAME,
    "model_state_dict": plm_model.state_dict(),
    "emb_dim": EMB_DIM,
    "rank_start": start_rank,
    "rank_end": end_rank,
    "label_ids": label_ids.tolist(),
    "accessions": [str(unique_lbl_block[i]) for i in label_ids],
}, PLM_MODEL_PATH)

print(f"Protein→MS2 embedding model saved to {PLM_MODEL_PATH}")


# =========================
# 16. Function: embed a sequence with trained PLM
# =========================
def embed_sequence(seq: str):
    """
    Embed one protein sequence with the trained PLM regressor.

    Uses the module-level `tokenizer`, `plm_model`, `DEVICE` and `MAX_SEQ_LEN`.
    The model is switched to eval mode as a side effect.

    seq : amino-acid sequence string

    Returns a NumPy array holding the first (only) row of the model output,
    i.e. a vector of length EMB_DIM.
    """
    encoded = tokenizer(
        seq,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN
    )
    ids, mask = (encoded[key].to(DEVICE) for key in ("input_ids", "attention_mask"))

    plm_model.eval()
    with torch.no_grad():
        embedding = plm_model(input_ids=ids, attention_mask=mask)[0]

    return embedding.cpu().numpy()   # (EMB_DIM,)


# =========================
# 17. Quick sanity check on a test sequence
# =========================
# Smoke test only: confirms embed_sequence runs and returns a vector of the
# expected size. It says nothing about embedding quality.
test_seq = (
    "MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDRFKHLKTE"
    "AEMKASEDLKKHGATVLTALGGILKKKGKH"
)

z_seq = embed_sequence(test_seq)
print("\nTest sequence embedding shape:", z_seq.shape)
print("First 10 dims:", z_seq[:10])


Using device: cpu
Configured LABEL rank range: 1 .. 11 (1-based, inclusive)
MIN_SPECTRA_PER_LABEL = 100
MAX_SPECTRA_PER_LABEL = 1000

Loading CSV: F:\20251201\ms2_with_ids.csv
Found 1600 MS2 columns: ['cast_00000', 'cast_00001', 'cast_00002', 'cast_00003', 'cast_00004'] ...
Total rows in CSV: 19307
Rows with non-empty Accession (used for supervised training): 4485
Any NaN left in ms2_lib? False
Any inf left in ms2_lib? False
Spectra shape before per-label construction: (N, n_bins) = (4485, 1600)
Example raw labels (Accession): ['P05204' 'Q9Y2I7-2' 'P02686-5' 'P02686-5' 'P02686-5']

Applying per-label limits:
Labels dropped for having < 100 spectra: 58
Labels capped at 1000 spectra: 1
Total spectra after per-label MIN/MAX filter: 3343
Spectra shape after per-label MIN/MAX: (N, n_bins) = (3343, 1600)
Total unique labels before rank-frequency filter: 11

Using LABEL ranks 1..11 (1-based).
Number of LABEL classes in this block: 11
Spectra after LABEL-block (rank) filter: 3343
Number of uni

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Starting PLM fine-tuning to predict MS2 embeddings...
[PLM] Epoch 1/20  Avg loss: 1.0679
[PLM] Epoch 2/20  Avg loss: 1.0262
[PLM] Epoch 3/20  Avg loss: 1.0163
[PLM] Epoch 4/20  Avg loss: 1.0206
[PLM] Epoch 5/20  Avg loss: 1.0074
[PLM] Epoch 6/20  Avg loss: 0.9925
[PLM] Epoch 7/20  Avg loss: 0.9848
[PLM] Epoch 8/20  Avg loss: 0.9405
[PLM] Epoch 9/20  Avg loss: 0.9546
[PLM] Epoch 10/20  Avg loss: 0.9310
[PLM] Epoch 11/20  Avg loss: 0.9325
[PLM] Epoch 12/20  Avg loss: 0.9047
[PLM] Epoch 13/20  Avg loss: 0.8893
[PLM] Epoch 14/20  Avg loss: 0.8957
[PLM] Epoch 15/20  Avg loss: 0.8416
[PLM] Epoch 16/20  Avg loss: 0.8724
[PLM] Epoch 17/20  Avg loss: 0.8416
[PLM] Epoch 18/20  Avg loss: 0.8390
[PLM] Epoch 19/20  Avg loss: 0.8855
[PLM] Epoch 20/20  Avg loss: 0.7956
PLM fine-tuning done.
PLM regressor losses per epoch: [1.067890703678131, 1.0261692106723785, 1.0162689983844757, 1.0205867886543274, 1.007427841424942, 0.9924833178520203, 0.9847925007343292, 0.9405037462711334, 0.9545674324035645, 0

Training from HDF5 file

In [15]:
import os
import numpy as np
from collections import defaultdict

import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler

from transformers import AutoTokenizer, AutoModel
import pandas as pd  # for saving loss curves as CSV


# =========================
# 0. Config
# =========================
# All settings for this cell live here as module-level constants.
# NOTE: the three paths below are absolute Windows paths and must be edited
# per machine.
H5_PATH        = r"F:\20251115\spectra_h5\combined_annotated.h5"
FASTA_PATH     = r"F:\20251201\human.fasta"  # <-- UPDATE THIS TO YOUR human.fasta

# Folder where ALL outputs will be saved
OUTPUT_DIR     = r"F:\20251201\result"
# Created at config time (side effect); exist_ok avoids failing on reruns.
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE         = "cuda" if torch.cuda.is_available() else "cpu"

EMB_DIM         = 64           # embedding dimension for MS2 contrastive encoder
BATCH_P_CLASSES = 32           # number of labels per batch
BATCH_K_SPECTRA = 4            # spectra per label per batch
EPOCHS          = 50           # MS2 contrastive training epochs
LR              = 1e-3
TEMPERATURE     = 0.07

RANDOM_SEED     = 42

# ---- per-label controls ----
# NOTE(review): here MIN (1000) is larger than MAX (100), the reverse of the
# CSV cell earlier in this file (MIN=100, MAX=1000). As written, only labels
# with at least 1000 spectra survive and each is then subsampled to 100.
# Confirm this is intended and the two values were not swapped.
MAX_SPECTRA_PER_LABEL = 100    # hard cap per label
MIN_SPECTRA_PER_LABEL = 1000     # drop labels with fewer than this

# ---- configurable label frequency rank range (1-based) ----
START_RANK      = 1            # e.g. 1
END_RANK        = 11           # e.g. 11

# ---- PLM fine-tuning config ----
PLM_NAME        = "facebook/esm2_t6_8M_UR50D"
MAX_SEQ_LEN     = 1024
PLM_BATCH_SIZE  = 8
PLM_EPOCHS      = 20
PLM_LR          = 1e-4
FREEZE_BACKBONE = True  # set False to fine-tune backbone as well

# Seed the global NumPy and torch RNGs; the per-label subsampling below draws
# from np.random.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print("Using device:", DEVICE)
print(f"Output directory: {OUTPUT_DIR}")
print(f"Configured LABEL rank range: {START_RANK} .. {END_RANK} (1-based, inclusive)")
print(f"MIN_SPECTRA_PER_LABEL = {MIN_SPECTRA_PER_LABEL}")
print(f"MAX_SPECTRA_PER_LABEL = {MAX_SPECTRA_PER_LABEL}")


# =========================
# 1. Helpers
# =========================
def decode_bytes_array(arr):
    """
    Convert an iterable of HDF5-style values into an object array of str.

    - bytes / bytearray items are decoded as UTF-8, silently dropping
      undecodable bytes
    - None becomes the empty string
    - anything else is passed through str()

    Returns a 1-D NumPy array with dtype=object.
    """
    def _as_text(item):
        if isinstance(item, (bytes, bytearray)):
            return item.decode("utf-8", errors="ignore")
        return "" if item is None else str(item)

    return np.array([_as_text(item) for item in arr], dtype=object)


def parse_fasta_to_dict(fasta_path):
    """
    Parse UniProt-like FASTA into dict: accession -> AA sequence.

    Handles headers like:
      >sp|P02679-2|FIBG_HUMAN ...
      >tr|Q9H0H5|SOME_PROT ...
      >P02679-2 some description

    The accession is taken from the first whitespace-delimited token of the
    header: the second "|"-separated field when present and non-empty,
    otherwise the first field. Sequences are concatenated across lines and
    upper-cased.

    Entries are skipped (never raise) when the header is empty (a bare ">"),
    when no sequence lines follow, or when no accession can be extracted.
    Sequence lines before the first header are ignored. If an accession
    occurs more than once, the last entry wins.

    fasta_path : path of the FASTA file to read

    Returns a dict mapping accession string to sequence string.
    Raises OSError if the file cannot be opened.
    """
    acc_to_seq = {}
    header = None
    seq_chunks = []

    def flush_entry(h, chunks):
        if h is None or not chunks:
            return
        tokens = h.split()
        if not tokens:
            # Bare ">" header: indexing the first token used to raise
            # IndexError; there is no accession to store, so skip the entry.
            return
        parts = tokens[0].split("|")
        # UniProt: sp|ACC|NAME -> ACC; plain headers have a single field.
        acc = parts[1] if len(parts) > 1 and parts[1] else parts[0]
        # Every chunk is an already-stripped, non-empty line.
        seq = "".join(chunks).upper()
        if acc and seq:
            acc_to_seq[acc] = seq

    # Explicit encoding keeps parsing independent of the platform locale;
    # errors="replace" ensures a stray non-UTF-8 byte in a description cannot
    # abort the whole parse.
    with open(fasta_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # flush previous
                flush_entry(header, seq_chunks)
                header = line[1:]
                seq_chunks = []
            else:
                seq_chunks.append(line)

        # flush last
        flush_entry(header, seq_chunks)

    return acc_to_seq


# =========================
# 2. Load HDF5 & initial clean
# =========================
# Reads the datasets "ms2_lib" and "Accession" fully into memory, keeps only
# rows with a usable accession, then cleans and max-normalizes the spectra.
print("\nLoading HDF5:", H5_PATH)
with h5py.File(H5_PATH, "r") as h5:
    print("Keys inside the H5 file:")
    for key in h5.keys():
        print(" -", key)

    # spectra
    # [:] loads the whole dataset, so the arrays stay usable after the file
    # is closed.
    ms2_lib = h5["ms2_lib"][:].astype(np.float32)

    # labels
    accession_raw = decode_bytes_array(h5["Accession"][:])

print("ms2_lib shape:", ms2_lib.shape)
print("Accession array shape:", accession_raw.shape)

# Keep only rows with non-empty Accession (identified)
# The literal strings "nan" and "None" are treated as missing as well.
mask_identified = np.array(
    [s is not None and str(s).strip() not in ("", "nan", "None") for s in accession_raw],
    dtype=bool
)

ms2_lib = ms2_lib[mask_identified]
# NOTE: unlike the membership test above, the kept labels are not stripped.
accession = accession_raw[mask_identified]

print("Total rows in H5:", len(accession_raw))
print("Rows with non-empty Accession (used for supervised training):", len(accession))

if len(accession) == 0:
    raise ValueError("No rows with non-empty Accession. Cannot run supervised contrastive training.")

# Replace NaN/inf in ms2 with 0
bad_mask = ~np.isfinite(ms2_lib)
if bad_mask.any():
    print("Found NaN/inf in ms2_lib, replacing with 0")
    ms2_lib[bad_mask] = 0.0

# Per-spectrum max normalization
max_intensity = np.max(ms2_lib, axis=1, keepdims=True)
# Rows whose maximum is exactly 0 are divided by 1 instead, avoiding 0/0.
max_intensity[max_intensity == 0] = 1.0
ms2_lib = ms2_lib / max_intensity

print("Any NaN left in ms2_lib?", np.isnan(ms2_lib).any())
print("Any inf left in ms2_lib?", np.isinf(ms2_lib).any())

N_total, n_bins = ms2_lib.shape
print(f"Spectra shape before per-label construction: (N, n_bins) = {ms2_lib.shape}")


# =========================
# 3. Build label = Accession string and enforce MIN/MAX per label
# =========================
# Labels with fewer than MIN_SPECTRA_PER_LABEL spectra are removed entirely;
# labels with more than MAX_SPECTRA_PER_LABEL are randomly subsampled down to
# the cap. ms2_lib and label_str are filtered with the same index array so
# their rows stay aligned.
label_str = np.array(accession, dtype=object)

print("Example raw labels (Accession):", label_str[:5])

print("\nApplying per-label limits:")
# label string -> list of row positions carrying that label
label_to_indices = defaultdict(list)
for i, lab in enumerate(label_str):
    label_to_indices[lab].append(i)

keep_indices = []
dropped_too_few = 0
capped_too_many = 0

for lab, idxs in label_to_indices.items():
    n = len(idxs)
    if n < MIN_SPECTRA_PER_LABEL:
        dropped_too_few += 1
        continue  # drop this label entirely
    if n > MAX_SPECTRA_PER_LABEL:
        capped_too_many += 1
        # Draws from the global NumPy RNG seeded in the config section.
        chosen = np.random.choice(idxs, size=MAX_SPECTRA_PER_LABEL, replace=False)
        keep_indices.extend(chosen.tolist())
    else:
        keep_indices.extend(idxs)

# Rows end up grouped by label (in first-appearance order of the labels),
# not in their original file order.
keep_indices = np.array(keep_indices, dtype=int)

print(f"Labels dropped for having < {MIN_SPECTRA_PER_LABEL} spectra: {dropped_too_few}")
print(f"Labels capped at {MAX_SPECTRA_PER_LABEL} spectra: {capped_too_many}")
print("Total spectra after per-label MIN/MAX filter:", len(keep_indices))

if len(keep_indices) == 0:
    raise ValueError(
        "No spectra left after applying MIN_SPECTRA_PER_LABEL and MAX_SPECTRA_PER_LABEL. "
        "Relax thresholds or check dataset."
    )

# Apply per-label filter
ms2_lib   = ms2_lib[keep_indices]
label_str = label_str[keep_indices]

N_total, n_bins = ms2_lib.shape
print(f"Spectra shape after per-label MIN/MAX: (N, n_bins) = {ms2_lib.shape}")


# =========================
# 4. Restrict to chosen LABEL rank range (frequency-based)
# =========================
# Labels are ranked by how many spectra they have (rank 1 = most frequent) and
# only ranks START_RANK..END_RANK (clipped to what exists) are kept.
all_unique_lbl, all_lbl_ids = np.unique(label_str, return_inverse=True)
counts = np.bincount(all_lbl_ids)
print("Total unique labels before rank-frequency filter:", len(all_unique_lbl))

# Sort labels by frequency (descending)
# NOTE(review): argsort is called with its default sort kind, so the relative
# rank of labels with identical counts is whatever that sort produces — confirm
# this is acceptable if ties can straddle the rank boundary.
sorted_class_indices = np.argsort(-counts)  # 0..(n_classes-1), high→low frequency
n_classes_total = len(sorted_class_indices)

# Clip requested rank range to available classes
start_rank = max(1, START_RANK)
end_rank   = min(END_RANK, n_classes_total)

if end_rank < start_rank:
    raise ValueError(
        f"No valid LABEL ranks after clipping to available classes. "
        f"Requested {START_RANK}-{END_RANK}, available 1-{n_classes_total}."
    )

start_idx = start_rank - 1  # inclusive
end_idx   = end_rank        # exclusive in slicing

block_indices = sorted_class_indices[start_idx:end_idx]
print(f"\nUsing LABEL ranks {start_rank}..{end_rank} (1-based).")
print(f"Number of LABEL classes in this block: {len(block_indices)}")

if len(block_indices) == 0:
    raise ValueError(
        "No LABEL classes in the requested rank range. "
        "Try a different range (e.g. 1..50)."
    )

block_set = set(block_indices.tolist())

# Keep only spectra whose label id falls inside the selected rank block.
mask_block        = np.isin(all_lbl_ids, list(block_set))
label_str_block   = label_str[mask_block]
ms2_lib_block     = ms2_lib[mask_block]

print("Spectra after LABEL-block (rank) filter:", len(label_str_block))

# Re-encode LABELs compactly 0..(K-1)
# np.unique returns the labels sorted, so id k refers to unique_lbl_block[k].
unique_lbl_block, lbl_ids_block = np.unique(label_str_block, return_inverse=True)
unique_lbl_block = np.array(unique_lbl_block, dtype=object)
n_classes = len(unique_lbl_block)
print("Number of unique LABELs in block:", n_classes)

N, n_bins = ms2_lib_block.shape
print(f"Final spectra shape after LABEL block filter: (N, n_bins) = {ms2_lib_block.shape}")


# =========================
# 5. Use ALL filtered spectra as training data (NO test split)
# =========================
# No held-out data is created here: all filtered spectra are used for training.
train_ms2   = ms2_lib_block
train_lblid = lbl_ids_block

print(f"\nTraining on ALL spectra after filtering: {len(train_ms2)}")


# =========================
# 6. Dataset & balanced sampler
# =========================
class SpectraDataset(Dataset):
    """In-memory dataset pairing each MS2 spectrum with its integer label id."""

    def __init__(self, ms2, labels):
        # torch.tensor copies its input, so both arrays are converted once here:
        # spectra as float32, label ids as int64.
        self.ms2, self.labels = (
            torch.tensor(ms2, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.long),
        )

    def __len__(self):
        # Number of spectra = size of the first dimension of the spectrum tensor.
        return self.ms2.size(0)

    def __getitem__(self, idx):
        spectrum, label = self.ms2[idx], self.labels[idx]
        return dict(spectrum=spectrum, label=label)


class ClassBalancedBatchSampler(Sampler):
    """
    Batch sampler that yields class-balanced batches of dataset indices.

    Each batch contains:
      - classes_per_batch distinct labels (all labels when fewer exist)
      - samples_per_class spectra for each label (drawn with replacement only
        when a label has fewer spectra than requested)

    One pass over the sampler yields exactly len(self) batches, so iteration
    terminates and agrees with the length reported to the DataLoader.
    Randomness comes from the global NumPy RNG (np.random).

    Args:
        labels: sequence of integer label ids, one per dataset item.
        classes_per_batch: number of distinct labels per batch.
        samples_per_class: number of items drawn per chosen label.
    """
    def __init__(self, labels, classes_per_batch=32, samples_per_class=4):
        self.labels = np.array(labels, dtype=np.int64)
        self.classes_per_batch = classes_per_batch
        self.samples_per_class = samples_per_class

        # label id -> list of dataset indices carrying that label
        self.class_to_indices = defaultdict(list)
        for i, lab in enumerate(self.labels):
            self.class_to_indices[lab].append(i)

        self.unique_classes = np.array(list(self.class_to_indices.keys()))
        self.batch_size = self.classes_per_batch * self.samples_per_class

    def _sample_batch(self):
        """Draw one class-balanced batch and return it as a list of indices."""
        if len(self.unique_classes) <= self.classes_per_batch:
            chosen_classes = self.unique_classes
        else:
            chosen_classes = np.random.choice(
                self.unique_classes,
                size=self.classes_per_batch,
                replace=False
            )

        batch_indices = []
        for c in chosen_classes:
            idxs = self.class_to_indices[c]
            # Sample without replacement whenever the class is large enough.
            need_replacement = len(idxs) < self.samples_per_class
            chosen = np.random.choice(
                idxs,
                size=self.samples_per_class,
                replace=need_replacement
            )
            batch_indices.extend(chosen.tolist())
        return batch_indices

    def __iter__(self):
        # Finite epoch: previously this looped forever, which contradicted
        # __len__ and hung any consumer that did not break out manually.
        for _ in range(len(self)):
            yield self._sample_batch()

    def __len__(self):
        # approximate number of batches per epoch
        return max(1, len(self.labels) // self.batch_size)


# Wire dataset -> class-balanced batch sampler -> DataLoader. Because a
# batch_sampler is supplied, it alone determines the composition of each batch.
train_dataset = SpectraDataset(train_ms2, train_lblid)
train_sampler = ClassBalancedBatchSampler(
    train_lblid,
    BATCH_P_CLASSES,   # classes_per_batch
    BATCH_K_SPECTRA,   # samples_per_class
)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler)


# =========================
# 7. Model: Spectrum Encoder
# =========================
class SpectrumEncoder(nn.Module):
    """
    MLP that maps a binned MS2 spectrum to an L2-normalized embedding.

    Args:
        n_bins: number of input features per spectrum.
        emb_dim: size of the output embedding.
    """
    def __init__(self, n_bins, emb_dim=64):
        super().__init__()
        # Layer layout (and therefore state_dict keys net.0 / net.2 / net.4)
        # must stay fixed so saved checkpoints keep loading.
        self.net = nn.Sequential(
            nn.Linear(n_bins, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, emb_dim),
        )

    def forward(self, x):
        """Return unit-norm embeddings for the input spectra (last dim = emb_dim)."""
        z = self.net(x)
        # Sanitize on ANY non-finite value, not only NaN: a +/-inf entry would
        # otherwise reach F.normalize and come out as NaN (inf / inf).
        if not torch.isfinite(z).all():
            print("Warning: NaN/inf detected in encoder output, applying nan_to_num.")
            z = torch.nan_to_num(z, nan=0.0, posinf=1e6, neginf=-1e6)
        z = F.normalize(z, dim=-1)
        return z


# =========================
# 8. Safe supervised contrastive loss
# =========================
def supervised_contrastive_loss_safe(emb, labels, temperature=0.07):
    """
    Supervised contrastive loss with per-anchor numerical safeguards.

    With s = clamp((emb @ emb.T) / temperature, -50, 50), every anchor i that
    has at least one other sample with the same label contributes
        -log( sum_{p: same label, p != i} exp(s_ip) / sum_{a != i} exp(s_ia) ).
    Anchors without positives, or whose term is non-finite, are skipped; the
    result is the mean over the remaining anchors.

    emb    : (B, D) L2-normalized embeddings
    labels : (B,) int label ids
    temperature : divisor applied to the similarities before clamping

    Returns a scalar tensor. If no anchor contributes (B <= 1, no positives in
    the batch, or every term non-finite) the value is 0 but it stays attached
    to emb's autograd graph, so calling .backward() on it is valid.
    """
    device = emb.device
    B, _ = emb.shape

    def _graph_zero():
        # A bare torch.tensor(0.0) has no grad_fn and makes loss.backward()
        # raise. Derive the zero from emb instead; non-finite entries are
        # zeroed first so the product cannot become NaN.
        return torch.nan_to_num(emb, nan=0.0, posinf=0.0, neginf=0.0).sum() * 0.0

    if B <= 1:
        return _graph_zero()

    # Cosine similarity
    sim = emb @ emb.t()
    sim = torch.clamp(sim / temperature, min=-50.0, max=50.0)

    anchor_ids = torch.arange(B, device=device)  # loop-invariant, built once

    total_loss = 0.0
    valid_anchors = 0

    for i in range(B):
        sim_i = sim[i].clone()
        sim_i[i] = float("-inf")  # exclude the anchor itself from the softmax

        pos_mask_i = (labels == labels[i]) & (anchor_ids != i)
        if not pos_mask_i.any():
            continue

        finite_mask = torch.isfinite(sim_i)
        if not finite_mask.any():
            continue

        # Subtract the row maximum before exponentiating for stability.
        max_sim = sim_i[finite_mask].max()
        shifted = sim_i - max_sim
        shifted[~finite_mask] = float("-inf")

        exp_all = torch.exp(shifted)
        denom = exp_all.sum()
        if not torch.isfinite(denom) or denom <= 0:
            continue

        exp_pos = exp_all * pos_mask_i.float()
        num = exp_pos.sum()
        if not torch.isfinite(num) or num <= 0:
            continue

        loss_i = -torch.log(num / denom)
        if not torch.isfinite(loss_i):
            continue

        total_loss += loss_i
        valid_anchors += 1

    if valid_anchors == 0:
        return _graph_zero()

    loss = total_loss / valid_anchors
    return loss


# =========================
# 9. Training loop (contrastive spectrum encoder)
# =========================
# Build the spectrum encoder for the bin count found in the filtered data and
# optimize all of its parameters with Adam.
model = SpectrumEncoder(n_bins=n_bins, emb_dim=EMB_DIM).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# With a batch_sampler, len(train_loader) is the sampler's reported length.
steps_per_epoch = len(train_loader)
print("\nsteps_per_epoch:", steps_per_epoch)

# Sanity check: push one batch through the untrained model (no gradients) and
# report whether the embeddings contain NaN and what the starting loss is.
print("Running a sanity check on one batch...")
batch_example = next(iter(train_loader))
x_ex = batch_example["spectrum"].to(DEVICE)
y_ex = batch_example["label"].to(DEVICE)
with torch.no_grad():
    emb_ex = model(x_ex)
    print("Any NaN in emb_ex?", torch.isnan(emb_ex).any().item())
    loss_ex = supervised_contrastive_loss_safe(emb_ex, y_ex, temperature=TEMPERATURE)
    print("Initial loss (sanity check):", loss_ex.item())

print("\nStarting training (contrastive spectrum encoder)...")
spectrum_epoch_losses = []  # one average loss per epoch, in epoch order
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    effective_steps = 0  # counts only steps that actually updated the model

    for step, batch in enumerate(train_loader):
        if step >= steps_per_epoch:
            break  # cap the epoch length; the sampler may yield past its reported length

        x = batch["spectrum"].to(DEVICE)
        y = batch["label"].to(DEVICE)

        emb = model(x)
        loss = supervised_contrastive_loss_safe(emb, y, temperature=TEMPERATURE)

        # Skip the update entirely when the loss is not a finite number.
        if torch.isnan(loss) or not torch.isfinite(loss):
            print(f"NaN/inf loss at epoch {epoch+1}, step {step+1}. Skipping.")
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        effective_steps += 1

    # max(1, ...) avoids division by zero; an epoch where every step was
    # skipped is therefore recorded with an average loss of 0.0.
    avg_loss = running_loss / max(1, effective_steps)
    spectrum_epoch_losses.append(avg_loss)
    print(f"[Spectrum] Epoch {epoch+1}/{EPOCHS}  Avg loss: {avg_loss:.4f}")

print("Contrastive training done.")
print("Spectrum encoder losses per epoch:", spectrum_epoch_losses)

# Save spectrum loss curve as CSV (columns: 1-based epoch number, average loss).
# OUTPUT_DIR is defined earlier in the file.
spectrum_loss_df = pd.DataFrame({
    "epoch": np.arange(1, EPOCHS + 1),
    "loss": spectrum_epoch_losses
})
spectrum_loss_csv_path = os.path.join(OUTPUT_DIR, "spectrum_epoch_losses.csv")
spectrum_loss_df.to_csv(spectrum_loss_csv_path, index=False)
print(f"Spectrum epoch losses saved to {spectrum_loss_csv_path}")


# =========================
# 10. Build train embedding index
# =========================
# Embed every training spectrum with the trained encoder, in evaluation mode
# and without gradient tracking.
model.eval()
with torch.no_grad():
    train_emb = []  # per-chunk NumPy arrays, concatenated below
    batch_size_eval = 2048  # chunk size for inference only; unrelated to the training batch size
    for i in range(0, len(train_ms2), batch_size_eval):
        # Convert one chunk at a time so only that chunk is placed on DEVICE.
        batch = torch.tensor(
            train_ms2[i:i+batch_size_eval],
            dtype=torch.float32,
            device=DEVICE
        )
        z = model(batch)
        train_emb.append(z.cpu().numpy())
    # Rows stay in the same order as train_ms2 / train_lblid.
    train_emb = np.concatenate(train_emb, axis=0)

print("\nTrain embeddings shape:", train_emb.shape)
# Label id for each row of train_emb (same array object as train_lblid).
train_labels = train_lblid


# =========================
# 11. Save spectrum encoder model
# =========================
# Checkpoint file name encodes the per-label limits and the (clipped) rank
# range actually used, so different configurations do not overwrite each other.
MODEL_PATH = os.path.join(
    OUTPUT_DIR,
    f"contrastive_spectrum_encoder_Accession_"
    f"min{MIN_SPECTRA_PER_LABEL}_max{MAX_SPECTRA_PER_LABEL}_"
    f"ranks_{start_rank}_{end_rank}.pth"
)

# Besides the weights, store what is needed to rebuild the encoder
# (emb_dim, n_bins) and to interpret label ids (unique_label_block).
torch.save({
    "model_state_dict": model.state_dict(),
    "emb_dim": EMB_DIM,
    "n_bins": n_bins,
    "unique_label_block": list(unique_lbl_block),  # accession labels for this block; list position = label id
    "rank_start": start_rank,
    "rank_end": end_rank,
    "min_spectra_per_label": MIN_SPECTRA_PER_LABEL,
    "max_spectra_per_label": MAX_SPECTRA_PER_LABEL,
}, MODEL_PATH)

print(f"Contrastive spectrum model saved to {MODEL_PATH}")


# =========================
# 12. Save spectrum embedding index for retrieval
# =========================
# Map every spectrum's integer label id back to its accession string.
acc_per_spectrum = np.array(
    [str(unique_lbl_block[label]) for label in train_labels],
    dtype=object
)

# The three arrays are row-aligned: row i of each describes the same spectrum.
# NOTE: `accessions` is an object-dtype array, so reading it back requires
# np.load(..., allow_pickle=True).
INDEX_PATH = os.path.join(OUTPUT_DIR, "train_spectrum_embeddings.npz")
np.savez(
    INDEX_PATH,
    embeddings=train_emb,         # (N, EMB_DIM)
    labels=train_labels,          # (N,)
    accessions=acc_per_spectrum   # (N,)
)
print(f"Spectrum embedding index saved to {INDEX_PATH}")


# =========================
# 13. Load Accession → AA sequence from FASTA
# =========================
# FASTA_PATH and parse_fasta_to_dict are defined earlier in the file; the
# result is used below as a mapping accession -> amino-acid sequence.
print("\nParsing FASTA:", FASTA_PATH)
acc_to_seq = parse_fasta_to_dict(FASTA_PATH)
print(f"Found {len(acc_to_seq)} accessions with sequences in FASTA.")

# =========================
# 14. Build per-Accession target embeddings for PLM
# =========================
acc_block = np.array(label_str_block, dtype=object)   # acc per spectrum row in block

n_labels_plm = n_classes  # same as number of unique_lbl_block

# Parallel lists: entry k of each describes the same label.
label_ids = []
target_embs = []
seqs_for_labels = []

missing_seq = 0

for lab_id in range(n_labels_plm):
    acc = str(unique_lbl_block[lab_id])

    # indices of spectra belonging to this label
    idxs = np.where(train_labels == lab_id)[0]
    if len(idxs) == 0:
        continue

    # mean spectrum embedding for this Accession (target for regression)
    # The mean of unit vectors is generally not unit-norm.
    mean_emb = train_emb[idxs].mean(axis=0)

    # look up AA sequence from FASTA
    if acc not in acc_to_seq:
        missing_seq += 1
        continue

    seq_rep = acc_to_seq[acc]  # AA sequence string

    label_ids.append(lab_id)
    target_embs.append(mean_emb)
    seqs_for_labels.append(seq_rep)

# Fail with an actionable message instead of np.stack's generic
# "need at least one array to stack" when nothing could be matched.
if not target_embs:
    raise ValueError(
        "No LABEL in this block has both spectra and a sequence in the FASTA "
        f"({missing_seq} of {n_labels_plm} labels missing a sequence); "
        "cannot build PLM regression targets."
    )

label_ids = np.array(label_ids, dtype=int)
target_embs = np.stack(target_embs, axis=0).astype(np.float32)

print(f"\nPer-Accession targets for PLM:")
print("  #labels with both AA sequence and embedding:", len(label_ids))
print("  #labels missing sequence in FASTA:", missing_seq)
print("  target_embs shape:", target_embs.shape)


# =========================
# 15. Protein LM → MS2 embedding regressor
# =========================
# Load the tokenizer and pretrained encoder identified by PLM_NAME.
# NOTE: transformers may warn that the pooler weights are newly initialized;
# the regressor defined below reads `last_hidden_state`, not the pooler output.
print("\nLoading protein language model:", PLM_NAME)
tokenizer = AutoTokenizer.from_pretrained(PLM_NAME)
backbone = AutoModel.from_pretrained(PLM_NAME)


class ProteinToSpectrumRegressor(nn.Module):
    """
    Protein language-model backbone followed by a linear head that predicts a
    unit-norm embedding of size emb_dim.

    Args:
        backbone: module called with input_ids / attention_mask whose output
            exposes `last_hidden_state`, and whose `config.hidden_size` gives
            the hidden width.
        emb_dim: size of the predicted embedding.
        freeze_backbone: when True, all backbone parameters stop requiring
            gradients, so only the head is trained.
    """

    def __init__(self, backbone, emb_dim, freeze_backbone=True):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            # Turns off requires_grad on every backbone parameter at once.
            self.backbone.requires_grad_(False)

        self.reg_head = nn.Linear(backbone.config.hidden_size, emb_dim)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state                           # (B, L, H)
        first_token = hidden_states[:, 0]             # (B, H): position 0 of every sequence
        projected = self.reg_head(first_token)        # (B, emb_dim)
        # Unit-normalize so predictions are on the same scale as the encoder's outputs.
        return F.normalize(projected, dim=-1)


class ProteinSeqDataset(Dataset):
    """
    Dataset of (amino-acid sequence, target embedding) pairs.

    Sequences are tokenized lazily in __getitem__; targets are converted to a
    float32 tensor once at construction.
    """

    def __init__(self, sequences, targets, tokenizer, max_len=1024):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.sequences = list(sequences)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence, target = self.sequences[idx], self.targets[idx]

        # Ask the tokenizer for fixed-length output (pad/truncate to max_len)
        # so items can be stacked by the default collate function.
        encoded = self.tokenizer(
            sequence,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The tokenizer output has a leading batch dimension of 1; drop it.
        item = {
            key: encoded[key].squeeze(0)
            for key in ("input_ids", "attention_mask")
        }
        item["target"] = target
        return item


def cosine_regression_loss(pred, target):
    """
    Cosine-distance loss: 1 minus the batch-mean cosine similarity, so it is
    0 for perfectly aligned vectors and 2 for exactly opposite ones.
    pred, target: (B, D)
    """
    unit_pred, unit_target = (F.normalize(t, dim=-1) for t in (pred, target))
    similarity = torch.sum(unit_pred * unit_target, dim=-1)  # (B,)
    return 1.0 - similarity.mean()


# Regressor: PLM backbone + linear head, moved to the training device.
plm_model = ProteinToSpectrumRegressor(
    backbone, emb_dim=EMB_DIM, freeze_backbone=FREEZE_BACKBONE
)
plm_model = plm_model.to(DEVICE)

# One item per label: its AA sequence and its mean spectrum embedding.
plm_dataset = ProteinSeqDataset(
    seqs_for_labels, target_embs, tokenizer, max_len=MAX_SEQ_LEN
)

# Reshuffle the labels every epoch.
plm_loader = DataLoader(plm_dataset, shuffle=True, batch_size=PLM_BATCH_SIZE)


# =========================
# 16. Train PLM → MS2 regressor
# =========================
# All parameters are handed to AdamW; when the backbone is frozen its
# parameters never receive gradients, so in effect only the head is updated.
optimizer_plm = torch.optim.AdamW(plm_model.parameters(), lr=PLM_LR)

print("\nStarting PLM fine-tuning to predict MS2 embeddings...")
plm_epoch_losses = []  # one average loss per epoch, in epoch order
for epoch in range(PLM_EPOCHS):
    # NOTE(review): train() also switches the backbone to training mode, even
    # when it is frozen — confirm that is intended if the backbone uses dropout.
    plm_model.train()
    running_loss = 0.0
    n_steps = 0

    for batch in plm_loader:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        target = batch["target"].to(DEVICE)

        # Predict an embedding from the sequence and pull it towards the
        # label's mean spectrum embedding (cosine distance).
        pred = plm_model(input_ids=input_ids, attention_mask=attention_mask)
        loss = cosine_regression_loss(pred, target)

        optimizer_plm.zero_grad()
        loss.backward()
        optimizer_plm.step()

        running_loss += loss.item()
        n_steps += 1

    # max(1, ...) avoids division by zero when the loader yields no batches.
    avg_loss = running_loss / max(1, n_steps)
    plm_epoch_losses.append(avg_loss)
    print(f"[PLM] Epoch {epoch+1}/{PLM_EPOCHS}  Avg loss: {avg_loss:.4f}")

print("PLM fine-tuning done.")
print("PLM regressor losses per epoch:", plm_epoch_losses)

# Save PLM loss curve as CSV (columns: 1-based epoch number, average loss)
plm_loss_df = pd.DataFrame({
    "epoch": np.arange(1, PLM_EPOCHS + 1),
    "loss": plm_epoch_losses
})
plm_loss_csv_path = os.path.join(OUTPUT_DIR, "plm_epoch_losses.csv")
plm_loss_df.to_csv(plm_loss_csv_path, index=False)
print(f"PLM epoch losses saved to {plm_loss_csv_path}")


# =========================
# 17. Save Protein→MS2 embedding model
# =========================
# "/" in the model name is replaced so the name is usable as a single file
# name; the embedding size and rank range are encoded as well.
PLM_MODEL_PATH = os.path.join(
    OUTPUT_DIR,
    f"protein_to_ms2_plm_{PLM_NAME.replace('/', '_')}_"
    f"emb{EMB_DIM}_ranks_{start_rank}_{end_rank}.pth"
)

# The state dict covers the whole regressor (backbone and head). `label_ids`
# and `accessions` are parallel lists naming the labels that had a FASTA
# sequence and were therefore used as training targets.
torch.save({
    "plm_name": PLM_NAME,
    "model_state_dict": plm_model.state_dict(),
    "emb_dim": EMB_DIM,
    "rank_start": start_rank,
    "rank_end": end_rank,
    "label_ids": label_ids.tolist(),
    "accessions": [str(unique_lbl_block[i]) for i in label_ids],
}, PLM_MODEL_PATH)

print(f"Protein→MS2 embedding model saved to {PLM_MODEL_PATH}")


# =========================
# 18. Function: embed a sequence with trained PLM
# =========================
def embed_sequence(seq: str):
    """
    Embed one amino-acid sequence with the trained protein->MS2 regressor.

    Uses the module-level `tokenizer`, `plm_model`, `MAX_SEQ_LEN` and `DEVICE`.
    As a side effect the regressor is left in evaluation mode.

    Returns:
        NumPy vector of shape (EMB_DIM,).
    """
    tokens = tokenizer(
        seq,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    model_inputs = {
        name: tokens[name].to(DEVICE)
        for name in ("input_ids", "attention_mask")
    }

    plm_model.eval()
    with torch.no_grad():
        embedding = plm_model(**model_inputs)

    # The batch holds a single sequence; return its row as a NumPy array.
    return embedding[0].cpu().numpy()   # (EMB_DIM,)


# =========================
# 19. Quick sanity check on a test sequence
# =========================
# Hard-coded amino-acid sequence used only to smoke-test embed_sequence();
# the two adjacent string literals are concatenated into one sequence.
test_seq = (
    "MVLSEGEWQLVLHVWAKVEADVAGHGQDILIRLFKSHPETLEKFDRFKHLKTE"
    "AEMKASEDLKKHGATVLTALGGILKKKGKH"
)

# Expect a vector of length EMB_DIM; print the first values for a quick look.
z_seq = embed_sequence(test_seq)
print("\nTest sequence embedding shape:", z_seq.shape)
print("First 10 dims:", z_seq[:10])


Using device: cpu
Output directory: F:\20251201\result
Configured LABEL rank range: 1 .. 11 (1-based, inclusive)
MIN_SPECTRA_PER_LABEL = 1000
MAX_SPECTRA_PER_LABEL = 100

Loading HDF5: F:\20251115\spectra_h5\combined_annotated.h5
Keys inside the H5 file:
 - Accession
 - MASS
 - PFR
 - file_name
 - group_name
 - ms2_lib
 - precursor_mz
 - rt_min
 - scan
 - sequence
ms2_lib shape: (2202567, 1600)
Accession array shape: (2202567,)
Total rows in H5: 2202567
Rows with non-empty Accession (used for supervised training): 398060
Any NaN left in ms2_lib? False
Any inf left in ms2_lib? False
Spectra shape before per-label construction: (N, n_bins) = (398060, 1600)
Example raw labels (Accession): ['P68363-1' 'P02671-1' 'P02671-1' 'P02671-1' 'P02671-1']

Applying per-label limits:
Labels dropped for having < 1000 spectra: 1757
Labels capped at 100 spectra: 58
Total spectra after per-label MIN/MAX filter: 5800
Spectra shape after per-label MIN/MAX: (N, n_bins) = (5800, 1600)
Total unique labels bef

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Starting PLM fine-tuning to predict MS2 embeddings...
[PLM] Epoch 1/20  Avg loss: 1.0140
[PLM] Epoch 2/20  Avg loss: 0.9926
[PLM] Epoch 3/20  Avg loss: 0.9872
[PLM] Epoch 4/20  Avg loss: 0.9666
[PLM] Epoch 5/20  Avg loss: 0.9367
[PLM] Epoch 6/20  Avg loss: 1.0012
[PLM] Epoch 7/20  Avg loss: 0.9849
[PLM] Epoch 8/20  Avg loss: 0.8727
[PLM] Epoch 9/20  Avg loss: 0.9211
[PLM] Epoch 10/20  Avg loss: 0.9400
[PLM] Epoch 11/20  Avg loss: 0.9709
[PLM] Epoch 12/20  Avg loss: 0.9186
[PLM] Epoch 13/20  Avg loss: 0.8627
[PLM] Epoch 14/20  Avg loss: 0.9403
[PLM] Epoch 15/20  Avg loss: 0.8401
[PLM] Epoch 16/20  Avg loss: 0.8725
[PLM] Epoch 17/20  Avg loss: 0.9275
[PLM] Epoch 18/20  Avg loss: 0.8069
[PLM] Epoch 19/20  Avg loss: 0.8067
[PLM] Epoch 20/20  Avg loss: 0.8106
PLM fine-tuning done.
PLM regressor losses per epoch: [1.0140445232391357, 0.992554098367691, 0.9872357547283173, 0.9666179716587067, 0.9366526007652283, 1.0011983811855316, 0.9849086701869965, 0.8726727366447449, 0.9210588932037354, 