In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import r2_score, mean_squared_error

# ---------- 1. Load Data ----------
def load_dataset(file_path):
    """Read a headerless, tab-separated table and label its four columns.

    Args:
        file_path: Location of the tab-delimited file (anything accepted by
            ``pandas.read_csv``). The first line is treated as data, not as
            a header.

    Returns:
        A DataFrame whose columns are named, in file order, ``peptide``,
        ``binding_score``, ``allele`` and ``pseudo_seq``.

    Raises:
        ValueError: If the parsed table does not have exactly four columns
            (raised by pandas when the labels are assigned).
    """
    column_names = ["peptide", "binding_score", "allele", "pseudo_seq"]
    frame = pd.read_csv(file_path, header=None, sep="\t")
    frame.columns = column_names
    return frame

# Load the training and test tables; both paths are relative to the
# current working directory.
train_df, test_df = (
    load_dataset(path) for path in ("train_BA1.txt", "test_BA1.txt")
)

# ---------- 2. Encode Sequences ----------
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
# Residues are numbered from 1 so that index 0 stays free for padding.
aa_to_idx = dict(zip(AMINO_ACIDS, range(1, len(AMINO_ACIDS) + 1)))


def encode_sequence(seq, max_len=15):
    """Turn a residue string into a fixed-length list of integer indices.

    Args:
        seq: Sequence of single-letter residue codes.
        max_len: Length of the returned list. Input beyond this length is
            dropped; shorter input is right-padded with 0.

    Returns:
        A list of ``max_len`` ints. Each residue found in ``aa_to_idx`` maps
        to its 1-based index; any other character maps to 0, the same value
        used for padding.
    """
    window = seq[:max_len]
    codes = [aa_to_idx.get(residue, 0) for residue in window]
    padding = [0] * (max_len - len(codes))
    return codes + padding

# Encode peptides and pseudo sequences
# Fixed widths of the encoded sequences: encode_sequence truncates anything
# longer and zero-pads anything shorter.
# NOTE(review): confirm that 10 positions cover the whole pseudo sequence in
# these files; residues past max_hla_len are silently dropped.
max_pep_len = 15
max_hla_len = 10

# Extra keyword arguments given to Series.apply are forwarded to the function,
# so encode_sequence is called as encode_sequence(value, max_len=...).
train_df["peptide_enc"] = train_df["peptide"].apply(encode_sequence, max_len=max_pep_len)
train_df["hla_enc"] = train_df["pseudo_seq"].apply(encode_sequence, max_len=max_hla_len)

test_df["peptide_enc"] = test_df["peptide"].apply(encode_sequence, max_len=max_pep_len)
test_df["hla_enc"] = test_df["pseudo_seq"].apply(encode_sequence, max_len=max_hla_len)

# ---------- 3. Build Dataset Class ----------
class MHCDataset(Dataset):
    """In-memory dataset of (peptide indices, HLA indices, score) triples.

    The constructor reads three columns from ``df``: ``peptide_enc`` and
    ``hla_enc`` (each row a list of integer indices, stacked row-wise into a
    2-D ``torch.long`` tensor) and ``binding_score`` (converted to a
    ``torch.float32`` tensor).
    """

    def __init__(self, df):
        def stack_indices(column):
            # One row per sample; rows in a column must share a length for
            # np.vstack to succeed.
            return torch.tensor(np.vstack(df[column]), dtype=torch.long)

        self.peptides = stack_indices("peptide_enc")
        self.hlas = stack_indices("hla_enc")
        self.scores = torch.tensor(df["binding_score"].values, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples (one per score)."""
        return self.scores.shape[0]

    def __getitem__(self, idx):
        """Return ``(peptide, hla, score)`` tensors for sample ``idx``."""
        sample = (self.peptides[idx], self.hlas[idx], self.scores[idx])
        return sample

# Wrap both frames as tensor datasets.
train_data, test_data = MHCDataset(train_df), MHCDataset(test_df)

# Batches of 64; only the training data is reshuffled each epoch so that
# test predictions stay in file order.
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=64)
test_loader = DataLoader(dataset=test_data, shuffle=False, batch_size=64)

# ---------- 4. Model ----------
class MHCModel(nn.Module):
    """Two-branch embedding + 1-D convolution regressor for peptide/HLA pairs.

    Each input (peptide indices and HLA indices) is embedded, passed through
    its own ``Conv1d`` layer, and max-pooled over the sequence axis. The two
    pooled feature vectors are concatenated and mapped to one scalar per
    sample by a small fully connected head.

    Args:
        vocab_size: Number of rows in each embedding table. Index 0 is the
            padding index, so its embedding is fixed at zero.
        emb_dim: Width of each embedding vector.
        hidden: Number of output channels of each convolution.
    """

    def __init__(self, vocab_size=21, emb_dim=32, hidden=64):
        super().__init__()
        self.pep_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.hla_emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # kernel_size=3 with padding=1 keeps the sequence length unchanged.
        self.conv_pep = nn.Conv1d(emb_dim, hidden, kernel_size=3, padding=1)
        self.conv_hla = nn.Conv1d(emb_dim, hidden, kernel_size=3, padding=1)
        self.fc = nn.Sequential(
            nn.Linear(hidden * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, pep, hla):
        """Predict one score per sample.

        Args:
            pep: Integer tensor of shape ``(batch, pep_len)``.
            hla: Integer tensor of shape ``(batch, hla_len)``.

        Returns:
            Float tensor of shape ``(batch,)``, including when ``batch == 1``.
        """
        # Embedding gives (batch, len, emb_dim); Conv1d wants channels first.
        pep = self.pep_emb(pep).permute(0, 2, 1)
        hla = self.hla_emb(hla).permute(0, 2, 1)
        # Global max-pool over the sequence axis -> (batch, hidden).
        pep_feat = torch.max(self.conv_pep(pep), dim=2)[0]
        hla_feat = torch.max(self.conv_hla(hla), dim=2)[0]
        x = torch.cat((pep_feat, hla_feat), dim=1)
        # Drop only the trailing feature axis. A bare squeeze() would also
        # remove the batch axis for a single-sample batch and return a 0-d
        # tensor, which cannot be iterated and mismatches the target shape.
        return self.fc(x).squeeze(-1)

# Default-sized network, trained with Adam against a mean-squared-error loss.
model = MHCModel()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# ---------- 5. Train ----------
for epoch in range(20):
    model.train()
    # Sum of the per-batch mean losses seen during this epoch.
    total_loss = 0.0
    for pep_batch, hla_batch, target in train_loader:
        preds = model(pep_batch, hla_batch)
        loss = criterion(preds, target)

        # Clear gradients left from the previous step, then update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    # Reported value is the average of the batch losses, not a per-sample mean.
    print(f"Epoch {epoch+1}: Train Loss = {total_loss/len(train_loader):.4f}")

# ---------- 6. Evaluate ----------
# Switch the module to evaluation mode and collect predictions without
# tracking gradients.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for pep, hla, y in test_loader:
        preds = model(pep, hla)
        # Flatten to 1-D before extending. If the final batch holds a single
        # sample and the model returns a 0-d tensor, extending a list with a
        # 0-d array raises "iteration over a 0-d array".
        y_true.extend(y.reshape(-1).numpy())
        y_pred.extend(preds.reshape(-1).numpy())

# Coefficient of determination and root-mean-squared error on the test set.
r2 = r2_score(y_true, y_pred)
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
print(f"Test R²: {r2:.3f}, RMSE: {rmse:.3f}")


Epoch 1: Train Loss = 0.0512
Epoch 2: Train Loss = 0.0472
Epoch 3: Train Loss = 0.0459
Epoch 4: Train Loss = 0.0450
Epoch 5: Train Loss = 0.0444
Epoch 6: Train Loss = 0.0440
Epoch 7: Train Loss = 0.0434
Epoch 8: Train Loss = 0.0433
Epoch 9: Train Loss = 0.0428
Epoch 10: Train Loss = 0.0427
Epoch 11: Train Loss = 0.0424
Epoch 12: Train Loss = 0.0423
Epoch 13: Train Loss = 0.0420
Epoch 14: Train Loss = 0.0420
Epoch 15: Train Loss = 0.0419
Epoch 16: Train Loss = 0.0417
Epoch 17: Train Loss = 0.0416
Epoch 18: Train Loss = 0.0415
Epoch 19: Train Loss = 0.0413
Epoch 20: Train Loss = 0.0413
Test R²: 0.213, RMSE: 0.236
