<a href="https://colab.research.google.com/github/benjaminnigjeh/keyProteoforms/blob/main/Generalized_Transformer_Classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries if needed
# !pip install torch matplotlib scikit-learn

import torch
import torch.nn as nn
import math
import numpy as np
import random
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# --- Positional Encoding ---
class PositionalEncoding1D(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``. The table is
    precomputed once for ``max_len`` positions and is not trainable.

    Args:
        d_model: Feature size of the inputs. Odd values are supported (the
            final even column then has no cosine partner).
        max_len: Number of positions covered by the precomputed table; inputs
            longer than this cannot be encoded.
    """

    def __init__(self, d_model, max_len=2000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # One frequency per even column: exp(-2k * ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns but ceil(d_model / 2) frequencies;
        # trim so an odd d_model does not raise a shape-mismatch error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A leading axis of size 1 lets the table broadcast over the batch.
        # As a buffer it follows .to(device) and is stored in state_dict, but
        # it is not returned by parameters(), so optimizers never update it.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor with the sequence on dim 1 and ``d_model`` features on
                the last dim, e.g. ``(batch, seq, d_model)``. ``seq`` must not
                exceed ``max_len``.
        """
        return x + self.pe[:, :x.size(1)]

# --- Transformer Model ---
class SpectrumToSequence(nn.Module):
    """Map one flat spectrum vector to per-position class logits and shifts.

    The spectrum is projected to a single ``d_model`` embedding. That embedding
    is broadcast to each of the ``seq_len`` positions, made position-aware with
    sinusoidal encodings, and passed through a Transformer encoder. Two heads
    then read every position:

    * ``status_head`` -> 3 logits per position.
    * ``delta_head``  -> 3 regression outputs per position.

    Args:
        input_dim: Length of each input spectrum vector.
        seq_len: Number of output positions. ``forward`` uses this value, not a
            module-level constant.
        d_model: Encoder width (default 256).
        nhead: Attention heads per encoder layer (default 8).
        dim_feedforward: Hidden size of each layer's feed-forward block
            (default 512).
        num_layers: Number of stacked encoder layers (default 3).

    The defaults reproduce the original fixed architecture exactly.
    """

    def __init__(self, input_dim, seq_len, d_model=256, nhead=8, dim_feedforward=512, num_layers=3):
        super().__init__()
        self.seq_len = seq_len
        self.input_proj = nn.Linear(input_dim, d_model)
        self.encoder_pos_enc = PositionalEncoding1D(d_model, max_len=seq_len)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.status_head = nn.Linear(d_model, 3)
        self.delta_head = nn.Sequential(nn.Linear(d_model, d_model // 2), nn.ReLU(), nn.Linear(d_model // 2, 3))

    def forward(self, spectra):
        """Run the model on a batch of spectra.

        Args:
            spectra: presumably ``(batch, input_dim)``. The ``unsqueeze(1)``
                below assumes a 2-D input (TODO confirm for other callers).

        Returns:
            ``(status_logits, delta_preds)``, each ``(batch, seq_len, 3)`` for a
            2-D input.
        """
        # (batch, d_model) -> (batch, 1, d_model) -> broadcast view over seq_len.
        # expand avoids materialising copies; the positional-encoding addition
        # produces a fresh tensor, so no aliasing reaches the encoder.
        x = self.input_proj(spectra).unsqueeze(1).expand(-1, self.seq_len, -1)
        x = self.encoder_pos_enc(x)
        x = self.encoder(x)
        return self.status_head(x), self.delta_head(x)

# --- Dataset ---
# Number of residues in the synthetic protein (and positions the model predicts).
SEQ_LEN = 140
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
# Residue masses in Da; the values match the standard monoisotopic residue
# masses of the 20 amino acids (unmodified Cys).
AA_MASS = {
    'A': 71.03711, 'C': 103.00919, 'D': 115.02694, 'E': 129.04259, 'F': 147.06841,
    'G': 57.02146, 'H': 137.05891, 'I': 113.08406, 'K': 128.09496, 'L': 113.08406,
    'M': 131.04049, 'N': 114.04293, 'P': 97.05276, 'Q': 128.05858, 'R': 156.10111,
    'S': 87.03203, 'T': 101.04768, 'V': 99.06841, 'W': 186.07931, 'Y': 163.06333
}
# The synthetic "canonical" sequence shared by every dataset sample. It is
# drawn from a private, seeded RNG so it is identical across runs (the
# datasets are seeded too) and the global `random` state is left untouched.
CANONICAL_SEQ_SEED = 0
canonical_seq = random.Random(CANONICAL_SEQ_SEED).choices(AMINO_ACIDS, k=SEQ_LEN)
# PTM mass shifts in Da added to a residue when it carries the modification.
PTM_MASS_SHIFTS = {'oxidation': 15.9949, 'deamidation': 0.9840, 'phosphorylation': 79.9663}

class PTMDataset(Dataset):
    """Synthetic per-residue PTM dataset built around the global ``canonical_seq``.

    Each item is ``(spectrum, labels, deltas)``:

    * ``spectrum`` - float32 tensor of length ``spectrum_len``. Residue ``i``
      contributes to bin ``i * bin_width``:
      - an intact residue adds ``AA_MASS * 5 + U(0, 1)``;
      - a modified residue adds ``shift * 5 + U(0, 1)``;
      - a missing residue adds nothing.
      Gaussian noise (std 0.5) is then added and the result is clipped at 0.
    * ``labels`` - int64 tensor of length ``SEQ_LEN``: 0 = intact,
      1 = modified, 2 = missing.
    * ``deltas`` - float32 tensor of length ``SEQ_LEN``: the PTM mass shift at
      modified positions, 0.0 elsewhere.

    Every sample has 3-8 modified positions and 1-4 missing positions, and the
    two sets are disjoint. Assumes ``len(canonical_seq) == SEQ_LEN``.

    Args:
        n_samples: Number of samples generated eagerly in the constructor.
        seed: Seed for this dataset's private Python and NumPy generators. The
            samples are the same ones that seeding the global ``random`` /
            ``np.random`` state would yield, but global state is not modified.
        spectrum_len: Length of each spectrum vector (default 2000).
        bin_width: Bin spacing between consecutive residues (default 10).

    Raises:
        ValueError: If the last residue's bin ``(SEQ_LEN - 1) * bin_width``
            does not fit inside ``spectrum_len``.
    """

    def __init__(self, n_samples=500, seed=42, spectrum_len=2000, bin_width=10):
        # Residue i is written to bin i * bin_width; make sure the last one fits.
        last_bin = (SEQ_LEN - 1) * bin_width
        if last_bin >= spectrum_len:
            raise ValueError(
                f"spectrum_len={spectrum_len} is too short: residue {SEQ_LEN - 1} "
                f"maps to bin {last_bin} (bin_width={bin_width})"
            )
        self.spectra, self.labels, self.deltas = [], [], []
        # random.Random(seed) / np.random.RandomState(seed) produce the same
        # streams as random.seed(seed) / np.random.seed(seed) followed by the
        # module-level calls, without the hidden global side effect.
        py_rng = random.Random(seed)
        np_rng = np.random.RandomState(seed)
        ptm_types = list(PTM_MASS_SHIFTS)  # loop-invariant, same key order as before
        for _ in range(n_samples):
            spectrum = np.zeros(spectrum_len)
            label = [0] * SEQ_LEN
            delta_mass = [0.0] * SEQ_LEN

            # 1) Modified residues: the peak encodes only the PTM shift.
            mod_sites = py_rng.sample(range(SEQ_LEN), k=py_rng.randint(3, 8))
            for site in mod_sites:
                shift = PTM_MASS_SHIFTS[py_rng.choice(ptm_types)]
                label[site] = 1
                delta_mass[site] = shift
                spectrum[site * bin_width] += shift * 5 + np_rng.rand()

            # 2) Missing residues, chosen among the unmodified ones: no peak.
            candidates = [i for i in range(SEQ_LEN) if label[i] != 1]
            missing_sites = py_rng.sample(candidates, k=py_rng.randint(1, 4))
            for site in missing_sites:
                label[site] = 2
                delta_mass[site] = 0.0

            # 3) Intact residues: peak at the residue mass.
            for i, aa in enumerate(canonical_seq):
                if label[i] == 0:
                    spectrum[i * bin_width] += AA_MASS[aa] * 5 + np_rng.rand()

            # Additive Gaussian noise, then clip so intensities stay non-negative.
            spectrum += np_rng.normal(0, 0.5, size=spectrum.shape)
            self.spectra.append(torch.tensor(np.clip(spectrum, 0, None), dtype=torch.float32))
            self.labels.append(torch.tensor(label, dtype=torch.long))
            self.deltas.append(torch.tensor(delta_mass, dtype=torch.float32))

    def __len__(self):
        """Return the number of pre-generated samples."""
        return len(self.spectra)

    def __getitem__(self, idx):
        """Return ``(spectrum, labels, deltas)`` for sample ``idx``."""
        return self.spectra[idx], self.labels[idx], self.deltas[idx]

# --- Training ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The datasets below are explicitly seeded. Seed torch as well, so that weight
# initialisation, dropout inside the encoder layers and the DataLoader shuffle
# order are reproducible too. torch.manual_seed seeds the RNG on all devices.
TORCH_SEED = 0
torch.manual_seed(TORCH_SEED)

model = SpectrumToSequence(2000, SEQ_LEN).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# NOTE(review): each sample has only 4-12 non-intact positions out of SEQ_LEN
# (3-8 modified + 1-4 missing), so class 0 is >90% of the targets for this
# unweighted loss. Consider class weights if minority-class recall matters.
loss_cls = nn.CrossEntropyLoss()
loss_delta = nn.MSELoss()

train_data = PTMDataset(n_samples=500, seed=1)
test_data = PTMDataset(n_samples=100, seed=123)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)

NUM_EPOCHS = 100

# Training loop: joint objective = per-position classification (cross-entropy)
# + mass-shift regression (MSE) restricted to modified positions.
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for spectra, status_labels, delta_labels in train_loader:
        spectra = spectra.to(DEVICE)
        status_labels = status_labels.to(DEVICE)
        delta_labels = delta_labels.to(DEVICE)
        optimizer.zero_grad()
        status_logits, delta_preds = model(spectra)
        # Flatten (batch, seq, classes) -> (batch*seq, classes) for
        # CrossEntropyLoss. reshape (unlike view) also accepts non-contiguous
        # tensors, and the class count is read from the logits, not hard-coded.
        num_classes = status_logits.size(-1)
        loss1 = loss_cls(status_logits.reshape(-1, num_classes), status_labels.reshape(-1))
        # Only modified positions (label 1) have a meaningful target shift.
        # Channel 0 of the delta head is the one that is supervised.
        mod_mask = status_labels == 1
        if mod_mask.any():
            pred_shift = delta_preds[:, :, 0][mod_mask]
            true_shift = delta_labels[mod_mask]
            loss2 = loss_delta(pred_shift, true_shift)
        else:
            loss2 = torch.tensor(0.0, device=DEVICE)
        # NOTE(review): shift targets are raw Da values (up to ~80), so the MSE
        # term may outweigh the cross-entropy term; consider scaling/weighting.
        loss = loss1 + loss2
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# --- Evaluation ---
model.eval()
all_preds, all_true = [], []
delta_mae = 0    # running sum of |pred - true| shift over modified positions
delta_count = 0  # number of modified positions contributing to delta_mae

with torch.no_grad():
    for spectra, status_labels, delta_labels in test_loader:
        spectra, status_labels, delta_labels = spectra.to(DEVICE), status_labels.to(DEVICE), delta_labels.to(DEVICE)
        status_logits, delta_preds = model(spectra)
        preds = status_logits.argmax(dim=-1)
        all_preds.extend(preds.view(-1).cpu().numpy())
        all_true.extend(status_labels.view(-1).cpu().numpy())
        # The shift error is measured at positions whose TRUE label is
        # "modified", regardless of the predicted class.
        mod_mask = (status_labels == 1)
        if mod_mask.any():
            pred_shift = delta_preds[:, :, 0][mod_mask]
            true_shift = delta_labels[mod_mask]
            delta_mae += (pred_shift - true_shift).abs().sum().item()
            delta_count += pred_shift.numel()

test_acc = np.mean(np.array(all_preds) == np.array(all_true))
test_mae = delta_mae / delta_count if delta_count > 0 else 0.0
print(f"\n✅ Test Accuracy: {test_acc*100:.2f}%")
print(f"✅ Delta Mass MAE: {test_mae:.4f} Da")

# --- Confusion Matrix ---
class_names = ["Intact", "Modified", "Missing"]
cm = confusion_matrix(all_true, all_preds, labels=[0, 1, 2])

# Most positions are "Intact", so per-position accuracy can look high even if
# the minority classes are never predicted. Also report per-class recall
# (row-normalised diagonal) and their mean over classes that occur.
support = cm.sum(axis=1)
per_class_recall = np.divide(
    cm.diagonal(), support, out=np.zeros(len(support)), where=support > 0
)
for name, recall, n in zip(class_names, per_class_recall, support):
    print(f"   {name:<8} recall: {recall*100:6.2f}%  (n={n})")
present = support > 0
balanced_acc = per_class_recall[present].mean() if present.any() else 0.0
print(f"✅ Balanced Accuracy: {balanced_acc*100:.2f}%")

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()


In [6]:
# Inspect per-position predictions for a single test sample.
model.eval()

# Take the first test sample and add a batch dimension of size 1.
spectrum, status_label, delta_label = test_data[0]
spectrum = spectrum.unsqueeze(0).to(DEVICE)

with torch.no_grad():
    status_logits, delta_preds = model(spectrum)

# Select batch element 0 explicitly instead of squeeze(): squeeze() would also
# drop the sequence axis if SEQ_LEN were 1 and break the indexing below.
status_pred = status_logits[0].argmax(dim=-1).cpu().numpy()  # shape [SEQ_LEN]
# Channel 0 of the delta head is the supervised mass shift. The training loop
# only supervises it at modified positions (label 1), so predicted values at
# other positions are not meaningful.
delta_pred = delta_preds[0, :, 0].cpu().numpy()  # shape [SEQ_LEN]

# Ground truth as numpy for printing.
status_label = status_label.numpy()
delta_label = delta_label.numpy()

# Print every position (labels: 0 = intact, 1 = modified, 2 = missing).
for i in range(SEQ_LEN):
    print(f"Pos {i:3}: True={status_label[i]}  Pred={status_pred[i]}  "
          f"ΔMass True={delta_label[i]:6.2f}  Pred={delta_pred[i]:6.2f}")


Pos   0: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   1: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   2: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   3: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   4: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   5: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   6: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   7: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   8: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos   9: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  10: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  11: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  12: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  13: True=2  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  14: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  15: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  16: True=0  Pred=0  ΔMass True=  0.00  Pred= 33.19
Pos  17: True=0  Pred=0  ΔMass True=  0.00  Pred