In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import ast

# 1. Data Preprocessing

In [2]:
class EpitopeDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]
    
AA_VOCAB = {aa: idx+1 for idx, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
AA_VOCAB

{'A': 1,
 'C': 2,
 'D': 3,
 'E': 4,
 'F': 5,
 'G': 6,
 'H': 7,
 'I': 8,
 'K': 9,
 'L': 10,
 'M': 11,
 'N': 12,
 'P': 13,
 'Q': 14,
 'R': 15,
 'S': 16,
 'T': 17,
 'V': 18,
 'W': 19,
 'Y': 20}

In [3]:
def preprocess_data(data_path):
    # read data
    df = pd.read_csv(data_path, sep='|')
    df['Y'] = df['Y'].apply(ast.literal_eval)
    
    # convert aa to index
    X = [[AA_VOCAB.get(aa, 0) for aa in seq] for seq in df['Antigen']]

    # residues to one-hot
    Y = torch.zeros(len(df), max(len(seq) for seq in X))
    for i, labels in enumerate(df['Y']):
        for pos in labels:
            Y[i, pos] = 1

    return EpitopeDataset(X, Y)

In [4]:
dataset = preprocess_data('raw_data/sample_epitope.csv')
dataset

<__main__.EpitopeDataset at 0x1231e2ec0>

# 2. Model Definition

In [9]:
class EpitopeModel(nn.Module):
    def __init__(self, vocab_size, max_len, hidden_size=64, device='cpu'):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, 32)
        self.lstm = nn.LSTM(32, hidden_size, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size * 2, 1)
        self.sigmoid = nn.Sigmoid()
        self.device = device

    def forward(self, x):
        embed = self.embedding(x)
        lstm_out, _ = self.lstm(embed)
        dropout = self.dropout(lstm_out)
        logits = self.fc(dropout)
        return self.sigmoid(logits)

# 3. Training and Evaluation

In [10]:
def train_epitope_model(dataset, device, epochs=10, batch_size=32):
    train_data, val_data = torch.utils.data.random_split(dataset, [int(len(dataset) * 0.8), len(dataset) - int(len(dataset) * 0.8)])

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size)

    model = EpitopeModel(len(aa_vocab) + 1, dataset.X.size(-1)).to(device)
    optimizer = optim.Adam(model.parameters())
    criterion = nn.BCELoss(reduction='none')

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for X, Y in train_loader:
            X, Y = X.to(device), Y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = criterion(output, Y.float()).mean()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        with torch.no_grad():
            for X, Y in val_loader:
                X, Y = X.to(device), Y.to(device)
                output = model(X)
                loss = criterion(output, Y.float()).mean()
                val_loss += loss.item()

        print(f"Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader)}, Val Loss: {val_loss/len(val_loader)}")

    return model

def predict_epitopes(model, sequence, aa_vocab, threshold=0.5):
    model.eval()
    with torch.no_grad():
        seq_ints = torch.tensor([[aa_vocab.get(aa, 0) for aa in sequence]]).to(model.device)
        output = model(seq_ints)
        epitope_positions = (output > threshold).nonzero(as_tuple=True)[1].tolist()
    return epitope_positions

In [11]:

# Example usage
dataset = preprocess_data('raw_data/epitope.csv')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = train_epitope_model(dataset, device)
model

                                             Antigen  \
0  MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNW...   
1  MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNW...   

                                             Y  
0  37,30,36,39,120,69,63,42,44,25,123,60,13,65  
1  37,30,36,39,120,69,63,42,44,25,123,60,13,65  
0    37,30,36,39,120,69,63,42,44,25,123,60,13,65
1    37,30,36,39,120,69,63,42,44,25,123,60,13,65
Name: Y, dtype: object


ValueError: Using a target size (torch.Size([1, 147])) that is different to the input size (torch.Size([1, 147, 1])) is deprecated. Please ensure they have the same size.

In [13]:
test_sequence = "MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCAKKIVSDGNGMNAWVAWRNRCKGTDVQAWIRGCRL"
predicted_epitopes = predict_epitopes(model, test_sequence, aa_vocab)
print(f"Predicted epitope positions: {predicted_epitopes}")

Predicted epitope positions: [0]
